diff --git a/.bazelrc b/.bazelrc deleted file mode 100644 index d5d20309df82498a552df759e3d200a914a4cfb7..0000000000000000000000000000000000000000 --- a/.bazelrc +++ /dev/null @@ -1,88 +0,0 @@ -# Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the -# target CPU to build transient dependencies correctly. See -# https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu -build:android --crosstool_top=//external:android/crosstool -build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain -build:android_arm --config=android -build:android_arm --cpu=armeabi-v7a -build:android_arm --fat_apk_cpu=armeabi-v7a -build:android_arm64 --config=android -build:android_arm64 --cpu=arm64-v8a -build:android_arm64 --fat_apk_cpu=arm64-v8a - -# Config to use a mostly-static build and disable modular op registration -# support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python). -# By default, TensorFlow will build with a dependence on -# //tensorflow:libtensorflow_framework.so. -build:monolithic --define framework_shared_object=false - -# For projects which use TensorFlow as part of a Bazel build process, putting -# nothing in a bazelrc will default to a monolithic build. The following line -# opts in to modular op registration support by default. -build --define framework_shared_object=true - -# Please note that MKL on MacOS or windows is still not supported. -# If you would like to use a local MKL instead of downloading, please set the -# environment variable "TF_MKL_ROOT" every time before build. -build:mkl --define=build_with_mkl=true --define=enable_mkl=true -build:mkl -c opt - -# This config option is used to enable MKL-DNN open source library only, -# without depending on MKL binary version. 
-build:mkl_open_source_only --define=build_with_mkl_dnn_only=true -build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true - -build:download_clang --crosstool_top=@local_config_download_clang//:toolchain -build:download_clang --define=using_clang=true -# Instruct clang to use LLD for linking. -# This only works with GPU builds currently, since Bazel sets -B/usr/bin in -# auto-generated CPU crosstool, forcing /usr/bin/ld.lld to be preferred over -# the downloaded one. -build:download_clang_use_lld --linkopt='-fuse-ld=lld' - -build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain -build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true - -build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain -build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true - -build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain -build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true - -build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain -build:sycl --define=using_sycl=true --define=using_trisycl=false - -build:sycl_nodouble --crosstool_top=@local_config_sycl//crosstool:toolchain -build:sycl_nodouble --define=using_sycl=true --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE - -build:sycl_asan --crosstool_top=@local_config_sycl//crosstool:toolchain -build:sycl_asan --define=using_sycl=true --define=using_trisycl=false --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address - -build:sycl_trisycl --crosstool_top=@local_config_sycl//crosstool:toolchain -build:sycl_trisycl --define=using_sycl=true --define=using_trisycl=true - -# Options extracted from configure script -build:gdr --define=with_gdr_support=true -build:ngraph --define=with_ngraph_support=true -build:verbs --define=with_verbs_support=true - -build --define=use_fast_cpp_protos=true -build 
--define=allow_oversize_protos=true -build --define=grpc_no_ares=true - -build --spawn_strategy=standalone -build --genrule_strategy=standalone -build -c opt - -# Other build flags. -build --define=grpc_no_ares=true - -# Modular TF build options -build:dynamic_kernels --define=dynamic_loaded_kernels=true - -# Default paths for TF_SYSTEM_LIBS -build --define=PREFIX=/usr -build --define=LIBDIR=$(PREFIX)/lib -build --define=INCLUDEDIR=$(PREFIX)/include - -# Do not commit the tf_configure.bazelrc line diff --git a/.github/ISSUE_TEMPLATE/00-bug-performance-issue.md b/.github/ISSUE_TEMPLATE/00-bug-performance-issue.md new file mode 100644 index 0000000000000000000000000000000000000000..34ba4cf96017bb0dc15e74eee5d6ce211cf1058d --- /dev/null +++ b/.github/ISSUE_TEMPLATE/00-bug-performance-issue.md @@ -0,0 +1,34 @@ +--- +name: Bug/Performance Issue +about: Use this template for reporting a bug or a performance issue. + +--- + +Please make sure that this is a bug. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:bug_template + +**System information** +- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): +- Mobile device (e.g. 
iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: +- TensorFlow installed from (source or binary): +- TensorFlow version (use command below): +- Python version: +- Bazel version (if compiling from source): +- GCC/Compiler version (if compiling from source): +- CUDA/cuDNN version: +- GPU model and memory: + + +You can collect some of this information using our environment capture [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh) +You can also obtain the TensorFlow version with +python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)" + +**Describe the current behavior** + +**Describe the expected behavior** + +**Code to reproduce the issue** +Provide a reproducible test case that is the bare minimum necessary to generate the problem. + +**Other info / logs** +Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. diff --git a/.github/ISSUE_TEMPLATE/10-build-installation-issue.md b/.github/ISSUE_TEMPLATE/10-build-installation-issue.md new file mode 100644 index 0000000000000000000000000000000000000000..99c2fe61271fb51cce8aaf94d06d9d4a633aede4 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/10-build-installation-issue.md @@ -0,0 +1,29 @@ +--- +name: Build/Installation Issue +about: Use this template for build/installation issues + +--- + +Please make sure that this is a build/installation issue. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:build_template + +**System information** +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): +- Mobile device (e.g. 
iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: +- TensorFlow installed from (source or binary): +- TensorFlow version: +- Python version: +- Installed using virtualenv? pip? conda?: +- Bazel version (if compiling from source): +- GCC/Compiler version (if compiling from source): +- CUDA/cuDNN version: +- GPU model and memory: + + + +**Describe the problem** + +**Provide the exact sequence of commands / steps that you executed before running into the problem** + + +**Any other info / logs** +Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. diff --git a/.github/ISSUE_TEMPLATE/20-documentation-issue.md b/.github/ISSUE_TEMPLATE/20-documentation-issue.md new file mode 100644 index 0000000000000000000000000000000000000000..7123ca6d6c507315dd3470e1813ac9dd17ba8fcd --- /dev/null +++ b/.github/ISSUE_TEMPLATE/20-documentation-issue.md @@ -0,0 +1,17 @@ +--- +name: Documentation Issue +about: Use this template for documentation related issues + +--- + +Please make sure that this is a documentation issue. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:doc_template + + +**System information** +- TensorFlow version: +- Doc Link: + + +**Describe the documentation issue** + +**We welcome contributions by users. 
Will you be able to submit a PR (use the [doc style guide](https://www.tensorflow.org/community/documentation)) to fix the doc Issue?** diff --git a/.github/ISSUE_TEMPLATE/30-feature-request.md b/.github/ISSUE_TEMPLATE/30-feature-request.md new file mode 100644 index 0000000000000000000000000000000000000000..71df2e5e49f9e42a23a8c453da5335cfbbbb6211 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/30-feature-request.md @@ -0,0 +1,22 @@ +--- +name: Feature Request +about: Use this template for raising a feature request + +--- + +Please make sure that this is a feature request. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:feature_template + + +**System information** +- TensorFlow version (you are using): +- Are you willing to contribute it (Yes/No): + + + +**Describe the feature and the current behavior/state.** + +**Will this change the current api? How?** + +**Who will benefit from this feature?** + +**Any Other info.** diff --git a/.github/ISSUE_TEMPLATE/40-tflite-op-request.md b/.github/ISSUE_TEMPLATE/40-tflite-op-request.md new file mode 100644 index 0000000000000000000000000000000000000000..7b391279e479ade4ed5327728f19be8752e11507 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/40-tflite-op-request.md @@ -0,0 +1,24 @@ +--- +name: TensorFlow Lite Op Request +about: Use this template for reporting ops you are using or missing. + +--- + + +**System information** +- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): +- TensorFlow installed from (source or binary): +- TensorFlow version (or github SHA if from source): + + +**Provide the text output from tflite_convert** + +``` +# Copy and paste here +``` + +Also, please include a link to a GraphDef or the model if possible. + +**Any other info / logs** + +Include any logs or source code that would be helpful to diagnose the problem. 
If including tracebacks, please include the full traceback. Large logs and files should be attached. diff --git a/.github/ISSUE_TEMPLATE/50-other-issues.md b/.github/ISSUE_TEMPLATE/50-other-issues.md new file mode 100644 index 0000000000000000000000000000000000000000..2d78d9818bb69ebc7b0807afe5297051494c991e --- /dev/null +++ b/.github/ISSUE_TEMPLATE/50-other-issues.md @@ -0,0 +1,13 @@ +--- +name: Other Issues +about: Use this template for any other non-support related issues + +--- + +This template is for miscellaneous issues not covered by the other issue categories. + +For questions on how to work with TensorFlow, or support for problems that are not verified bugs in TensorFlow, please go to [StackOverflow](https://stackoverflow.com/questions/tagged/tensorflow). + +If you are reporting a vulnerability, please use the [dedicated reporting process](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md). + +For high-level discussions about TensorFlow, please post to discuss@tensorflow.org. For questions about the development or internal workings of TensorFlow, or if you would like to know how to contribute to TensorFlow, please post to developers@tensorflow.org. 
diff --git a/.gitignore b/.gitignore index cb65f447d4a551266e237714a16d71b58bcfc51d..90324058600bee46af56e49028977971848a80de 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ .DS_Store .ipynb_checkpoints node_modules +/.bazelrc /.tf_configure.bazelrc /bazel-* /bazel_pip @@ -23,10 +24,10 @@ Pods Podfile.lock *.pbxproj *.xcworkspacedata -/tensorflow/contrib/lite/downloads/** -/tensorflow/contrib/lite/gen/** -/tensorflow/contrib/lite/examples/ios/simple/data/*.txt -/tensorflow/contrib/lite/examples/ios/simple/data/*.tflite +/tensorflow/lite/tools/make/downloads/** +/tensorflow/lite/gen/** +/tensorflow/lite/examples/ios/simple/data/*.txt +/tensorflow/lite/examples/ios/simple/data/*.tflite xcuserdata/** /api_init_files_list.txt /estimator_api_init_files_list.txt diff --git a/BUILD b/BUILD index 4bf647e47aa56cff0b3fd5af7d5df99d8b70549b..1200cf5f7103cad12ab9693c339c372f4f3bc0fb 100644 --- a/BUILD +++ b/BUILD @@ -2,5 +2,7 @@ exports_files( [ "LICENSE", "ACKNOWLEDGEMENTS", + "configure", + "configure.py", ], ) diff --git a/CODEOWNERS b/CODEOWNERS index 94cc865479cd6ab5cdb589490d3a2d650f06b160..cb3fa2312405ce44d5dfc30ea4164740f436e07e 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,6 +1,7 @@ # Where component owners are known, add them here. 
/tenosrflow/core/debug @caisq +/tensorflow/core/nccl/ @azaks2 @chsigg /tensorflow/core/platform/windows/ @mrry /tensorflow/core/platform/s3 @yongtang /tensorflow/go @asimshankar @@ -46,18 +47,17 @@ /tensorflow/contrib/losses/ @alextp @ispirmustafa /tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg /tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa -/tensorflow/contrib/nccl/ @cwhipkey @zheng-xq /tensorflow/contrib/opt/ @strategist333 @alextp /tensorflow/contrib/pi_examples/ @maciekcc /tensorflow/contrib/quantization/ @petewarden /tensorflow/contrib/rnn/ @ebrevdo @scottzhu -/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenl +/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenlavoie /tensorflow/contrib/seq2seq/ @ebrevdo @lmthang /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh /tensorflow/contrib/slim/ @sguada @thenbasilmanran /tensorflow/contrib/stateless/ @girving @alextp /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank -/tensorflow/contrib/tensorrt/ @aaroey +/tensorflow/contrib/tensorrt/ @aaroey @smit-hinsu @azaks2 # NEED OWNER: /tensorflow/contrib/testing/ /tensorflow/contrib/timeseries/ @allenlavoie /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 5fff9d05a1c589636bc9c711e6eb7cc4aba86b2f..a4647020ff76830badd75f3d3f76a41a637159bb 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -7,19 +7,22 @@ In the interest of fostering an open and welcoming environment, we as contributo Examples of behavior that contributes to creating a positive environment include: -* Using welcoming and inclusive language -* Being respectful of differing viewpoints and experiences -* Gracefully accepting constructive criticism -* Focusing on what is best for the community -* Showing empathy towards other community members +* Using welcoming and inclusive language. +* Being respectful of differing viewpoints and experiences. 
+* Gracefully accepting constructive criticism. +* Focusing on what is best for the community. +* Showing empathy towards other community members. Examples of unacceptable behavior by participants include: -* The use of sexualized language or imagery and unwelcome sexual attention or advances -* Trolling, insulting/derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or electronic address, without explicit permission -* Conduct which could reasonably be considered inappropriate for the forum in which it occurs. +* The use of sexualized language or imagery and unwelcome sexual attention or + advances. +* Trolling, insulting/derogatory comments, and personal or political attacks. +* Public or private harassment. +* Publishing others' private information, such as a physical or electronic + address, without explicit permission. +* Conduct which could reasonably be considered inappropriate for the forum in + which it occurs. All TensorFlow forums and spaces are meant for professional interactions, and any behavior which could reasonably be considered inappropriate in a professional setting is unacceptable. @@ -48,10 +51,12 @@ However, for the vast majority of issues, we aim to empower individuals to first If you are experiencing or witnessing conflict, we ask you to use the following escalation strategy to address the conflict: -1. Address the perceived conflict directly with those involved, preferably in a real-time medium. -2. If this fails, get a third party (e.g. a mutual friend, and/or someone with background on the issue, but not involved in conflict) to intercede. -3. If you are still unable to resolve the conflict, and you believe it rises to harassment or another code of conduct violation, report it. - +1. Address the perceived conflict directly with those involved, preferably in a + real-time medium. +2. If this fails, get a third party (e.g. 
a mutual friend, and/or someone with + background on the issue, but not involved in the conflict) to intercede. +3. If you are still unable to resolve the conflict, and you believe it rises to + harassment or another code of conduct violation, report it. ## Reporting Violations diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f598999f351c10f8bd01dfbd3ad8897f19d570e8..4a296f265f7b9521c46d350cec26ff199f43eb6c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,8 +31,12 @@ Follow either of the two links above to access the appropriate CLA and instructi If you have improvements to TensorFlow, send us your pull requests! For those just getting started, Github has a [howto](https://help.github.com/articles/using-pull-requests/). -TensorFlow team members will be assigned to review your pull requests. Once the pull requests are approved and pass continuous integration checks, we will merge the pull requests. -For some pull requests, we will apply the patch for each pull request to our internal version control system first, and export the change out as a new commit later, at which point the original pull request will be closed. The commits in the pull request will be squashed into a single commit with the pull request creator as the author. These pull requests will be labeled as pending merge internally. +TensorFlow team members will be assigned to review your pull requests. Once the +pull requests are approved and pass continuous integration checks, a TensorFlow +team member will apply the `ready to pull` label to your change. This means we are +working on getting your pull request submitted to our internal repository. After +the change has been submitted internally, your pull request will be merged +automatically on GitHub. If you want to contribute but you're not sure where to start, take a look at the [issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome). 
diff --git a/ISSUES.md b/ISSUES.md new file mode 100644 index 0000000000000000000000000000000000000000..2b330e8e0a8a3f64753cfb7a2e2362222439312d --- /dev/null +++ b/ISSUES.md @@ -0,0 +1,9 @@ +If you open a GitHub Issue, here is our policy: 1. It must be a bug/performance +issue or a feature request or a build issue or a documentation issue (for small +doc fixes please send a PR instead). 2. Make sure the Issue Template is filled +out. 3. The issue should be related to the repo it is created in. + +**Here's why we have this policy:** We want to focus on the work that benefits +the whole community, e.g., fixing bugs and adding features. Individual support +should be sought on StackOverflow or other non-GitHub channels. It helps us to +address bugs and feature requests in a timely manner. diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index 52faed9297cfcaf8c93bb9c79686c9258a53c560..b3d84ad8c948df9459a8e8afb029785d6f6ad335 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -29,9 +29,11 @@ You can collect some of this information using our environment capture script: https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh -You can obtain the TensorFlow version with +You can obtain the TensorFlow version with: +```bash python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)" +``` ### Describe the problem Describe the problem clearly here. Be sure to convey here why it's a bug in TensorFlow or a feature request. diff --git a/README.md b/README.md index 57efb876c9afaf9fe76c4ced4e6a1572e9241edf..044174947a094d43a51f7140dd40ec0f17801d40 100644 --- a/README.md +++ b/README.md @@ -9,12 +9,14 @@ |-----------------| | [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | -**TensorFlow** is an open source software library for numerical computation using -data flow graphs. 
The graph nodes represent mathematical operations, while +**TensorFlow** is an open source software library for numerical computation +using data flow graphs. The graph nodes represent mathematical operations, while the graph edges represent the multidimensional data arrays (tensors) that flow -between them. This flexible architecture enables you to deploy computation to one -or more CPUs or GPUs in a desktop, server, or mobile device without rewriting -code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard), a data visualization toolkit. +between them. This flexible architecture enables you to deploy computation to +one or more CPUs or GPUs in a desktop, server, or mobile device without +rewriting code. TensorFlow also includes +[TensorBoard](https://github.com/tensorflow/tensorboard), a data visualization +toolkit. TensorFlow was originally developed by researchers and engineers working on the Google Brain team within Google's Machine Intelligence Research @@ -29,7 +31,21 @@ subscribing to [announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce). ## Installation -*See [Installing TensorFlow](https://www.tensorflow.org/install) for instructions on how to install our release binaries or how to build from source.* + +To install the current release for CPU-only: + +``` +pip install tensorflow +``` + +Use the GPU package for CUDA-enabled GPU cards: + +``` +pip install tensorflow-gpu +``` + +*See [Installing TensorFlow](https://www.tensorflow.org/install) for detailed +instructions, and how to build from source.* People who are a little more adventurous can also try our nightly binaries: @@ -65,9 +81,10 @@ guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's uphold this code.** **We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for -tracking requests and bugs. 
So please see -[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions -and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** +tracking requests and bugs, so please see +[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) +for general questions and discussion, and please direct specific questions to +[Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** The TensorFlow project strives to abide by generally accepted best practices in open-source software development: @@ -93,25 +110,27 @@ The TensorFlow project strives to abide by generally accepted best practices in ### Community Supported Builds -| Build Type | Status | Artifacts | -| --- | --- | --- | -| **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | -| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | -| **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA | -| **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | -| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.10.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp27-cp27mu-linux_x86_64.whl)
[1.10.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl)
[1.10.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp36-cp36m-linux_x86_64.whl) | - +Build Type | Status | Artifacts +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- +**IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA +**IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | TBA +**IBM ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) +**IBM ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) +**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) +**Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.4
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.11.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp27-cp27mu-linux_x86_64.whl)
[1.11.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp34-cp34m-linux_x86_64.whl)
[1.11.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp35-cp35m-linux_x86_64.whl)
[1.11.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp36-cp36m-linux_x86_64.whl) ## For more information -* [TensorFlow Website](https://www.tensorflow.org) -* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/) -* [TensorFlow Model Zoo](https://github.com/tensorflow/models) -* [TensorFlow Twitter](https://twitter.com/tensorflow) -* [TensorFlow Blog](https://medium.com/tensorflow) -* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) -* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) -* [TensorFlow White Papers](https://www.tensorflow.org/about/bib) -* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) + +* [TensorFlow Website](https://www.tensorflow.org) +* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/) +* [TensorFlow Model Zoo](https://github.com/tensorflow/models) +* [TensorFlow Twitter](https://twitter.com/tensorflow) +* [TensorFlow Blog](https://medium.com/tensorflow) +* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) +* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) +* [TensorFlow White Papers](https://www.tensorflow.org/about/bib) +* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) +* [TensorFlow Visualization Toolkit](https://github.com/tensorflow/tensorboard) Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate. diff --git a/RELEASE.md b/RELEASE.md index 20e1d9217b7684e696d0abf427eef9ab9548d1b7..b13b071bd6cf4d3a260c8e248a67d23e1a688498 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,74 @@ +# Release 1.12.0 + +## Major Features and Improvements + +* Keras models can now be directly exported to the SavedModel + format(`tf.contrib.saved_model.save_keras_model()`) and used with Tensorflow + Serving. 
+* Keras models now support evaluating with a `tf.data.Dataset`. +* TensorFlow binaries are built with XLA support linked in by default. + +## Bug Fixes and Other Changes + +* tf.data: + * tf.data users can now represent, get, and set options of TensorFlow + input pipelines using `tf.data.Options()`, `tf.data.Dataset.options()`, + and `tf.data.Dataset.with_options()` respectively. + * New `tf.data.Dataset.reduce()` API allows users to reduce a finite + dataset to a single element using a user-provided reduce function. + * New `tf.data.Dataset.window()` API allows users to create finite windows + of input dataset; when combined with the `tf.data.Dataset.reduce()` API, + this allows users to implement customized batching. + * All C++ code moves to the `tensorflow::data` namespace. + * Add support for `num_parallel_calls` to `tf.data.Dataset.interleave`. +* `tf.contrib`: + * Remove `tf.contrib.linalg`. `tf.linalg` should be used instead. + * Replace any calls to `tf.contrib.get_signature_def_by_key(metagraph_def, + signature_def_key)` with + `meta_graph_def.signature_def[signature_def_key]`. Catching a ValueError + exception thrown by `tf.contrib.get_signature_def_by_key` should be + replaced by catching a KeyError exception. +* `tf.contrib.data` + * Deprecate, and replace by tf.data.experimental. +* Other: + * Instead of jemalloc, revert back to using system malloc since it + simplifies build and has comparable performance. + * Remove integer types from `tf.nn.softplus` and `tf.nn.softsign` OpDefs. + This is a bugfix; these ops were never meant to support integers. + * Allow subslicing Tensors with a single dimension. + * Add option to calculate string length in Unicode characters + * Add functionality to SubSlice a tensor. + * Add searchsorted (ie lower/upper_bound) op. + * Add model explainability to Boosted Trees. 
+ * Support negative positions for tf.substr + * There was previously a bug in the bijector_impl where the + _reduce_jacobian_det_over_event does not handle scalar ILDJ + implementations properly. + * In tf eager execution, allow re-entering a GradientTape context + * Add tf_api_version flag. If --define=tf_api_version=2 flag is passed in, + then bazel will build TensorFlow API version 2.0. Note that TensorFlow + 2.0 is under active development and has no guarantees at this point. + * Add additional compression options to TfRecordWriter + * Performance improvements for regex full match operations. + * Replace tf.GraphKeys.VARIABLES with `tf.GraphKeys.GLOBAL_VARIABLES` + * Remove unused dynamic learning rate support. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +(David) Siu-Kei Muk, Ag Ramesh, Anton Dmitriev, Artem Sobolev, Avijit-Nervana, +Bairen Yi, Bruno Goncalves, By Shen, candy.dc, Cheng Chen, Clayne Robison, +coder3101, Dao Zhang, Elms, Fei Hu, feiquan, Geoffrey Irving, Guozhong Zhuang, +hellcom, Hoeseong Kim, imsheridan, Jason Furmanek, Jason Zaman, Jenny Sahng, +jiefangxuanyan, Johannes Bannhofer, Jonathan Homer, Koan-Sin Tan, kouml, Loo +Rong Jie, Lukas Geiger, manipopopo, Ming Li, Moritz KröGer, Naurril, Niranjan +Hasabnis, Pan Daoxin, Peng Yu, pengwa, rasmi, Roger Xin, Roland Fernandez, Sami +Kama, Samuel Matzek, Sangjung Woo, Sergei Lebedev, Sergii Khomenko, shaohua, +Shaohua Zhang, Shujian2015, Sunitha Kambhampati, tomguluson92, ViníCius Camargo, +wangsiyu, weidankong, Wen-Heng (Jack) Chung, William D. Irons, Xin Jin, Yan +Facai (颜发才), Yanbo Liang, Yash Katariya, Yong Tang, 在原佐为 + # Release 1.11.0 ## Major Features and Improvements @@ -20,51 +91,84 @@ ## Bug Fixes and Other Changes -* C++: - * Changed the signature of SessionFactory::NewSession so that it can return a meaningful error message on failure. 
-* tf.data: - * Remove `num_parallel_parser_calls` argument from `tf.contrib.data.make_csv_dataset()`. [tf.data] Remove `num_parallel_parser_calls` argument from `tf.contrib.data.make_csv_dataset()`. - * `tf.data.Dataset.list_files()` raises an exception at initialization time if the argument matches no files. - * Renamed BigTable class to BigtableTable for clarity - * Document use of the Cloud Bigtable API - * Adding `tf.contrib.data.reduce_dataset` which can be used to reduce a dataset to a single element. - * Generalization of `tf.contrib.data.sliding_window_batch`. -* INC: - * Runtime improvements to triangular solve. -* `tf.contrib`: - * Add an `implementation` argument to `tf.keras.layers.LocallyConnected2D` and `tf.keras.layers.LocallyConnected1D`. The new mode (`implementation=2`) performs forward pass as a single dense matrix multiplication, allowing dramatic speedups in certain scenarios (but worse performance in others - see docstring). The option also allows to use `padding=same`. - * Add documentation clarifying the differences between tf.fill and tf.constant. - * Add experimental IndexedDatasets. - * Add selective registration target using the lite proto runtime. - * Add simple Tensor and DataType classes to TensorFlow Lite Java - * Add support for bitcasting to/from uint32 and uint64. - * Added a subclass of Estimator that can be created from a SavedModel (SavedModelEstimator). - * Adds leaf index modes as an argument. - * Allow a different output shape from the input in tf.contrib.image.transform. - * Change the state_size order of the StackedRNNCell to be natural order. To keep the existing behavior, user can add reverse_state_order=True when constructing the StackedRNNCells. - * Deprecate self.test_session() in favor of self.session() or self.cached_session(). 
- * Directly import tensor.proto.h (the transitive import will be removed from tensor.h soon) - * Estimator.train() now supports tf.contrib.summary.\* summaries out of the box; each call to .train() will now create a separate tfevents file rather than re-using a shared one. - * Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term should not end up in the accumulator. - * Fix toco compilation/execution on Windows - * GoogleZoneProvider class added to detect which Google Cloud Engine zone tensorflow is running in. - * It is now safe to call any of the C API's TF_Delete\* functions on nullptr - * Log some errors on Android to logcat - * Match FakeQuant numerics in TFLite to improve accuracy of TFLite quantized inference models. - * Optional bucket location check for the GCS Filesystem. - * Performance enhancements for StringSplitOp & StringSplitV2Op. - * Performance improvements for regex replace operations. - * TFRecordWriter now raises an error if .write() fails. - * TPU: More helpful error messages in TPUClusterResolvers. - * The legacy_init_op argument to SavedModelBuilder methods for adding MetaGraphs has been deprecated. Please use the equivalent main_op argument instead. As part of this, we now explicitly check for a single main_op or legacy_init_op at the time of SavedModel building, whereas the check on main_op was previously only done at load time. - * The protocol used for Estimator training is now configurable in RunConfig. - * Triangular solve performance improvements. - * Unify RNN cell interface between TF and Keras. Add new get_initial_state() to Keras and TF RNN cell, which will use to replace the existing zero_state() method. - * Update initialization of variables in Keras. - * Updates to "constrained_optimization" in tensorflow/contrib. - * boosted trees: adding pruning mode - * tf.train.Checkpoint does not delete old checkpoints by default. - * tfdbg: Limit the total disk space occupied by dumped tensor data to 100 GBytes. 
Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow adjustment of this upper limit. +* C++: + * Changed the signature of SessionFactory::NewSession so that it can + return a meaningful error message on failure. +* tf.data: + * Remove `num_parallel_parser_calls` argument from + `tf.contrib.data.make_csv_dataset()`. [tf.data] Remove + `num_parallel_parser_calls` argument from + `tf.contrib.data.make_csv_dataset()`. + * `tf.data.Dataset.list_files()` raises an exception at initialization + time if the argument matches no files. + * Renamed BigTable class to BigtableTable for clarity + * Document use of the Cloud Bigtable API + * Add `tf.contrib.data.reduce_dataset` which can be used to reduce a + dataset to a single element. + * Generalization of `tf.contrib.data.sliding_window_batch`. +* INC: + * Runtime improvements to triangular solve. +* `tf.contrib`: + * Add an `implementation` argument to `tf.keras.layers.LocallyConnected2D` + and `tf.keras.layers.LocallyConnected1D`. The new mode + (`implementation=2`) performs forward pass as a single dense matrix + multiplication, allowing dramatic speedups in certain scenarios (but + worse performance in others - see docstring). The option also allows to + use `padding=same`. + * Add documentation clarifying the differences between tf.fill and + tf.constant. + * Add experimental IndexedDatasets. + * Add selective registration target using the lite proto runtime. + * Add simple Tensor and DataType classes to TensorFlow Lite Java + * Add support for bitcasting to/from uint32 and uint64. + * Added a subclass of Estimator that can be created from a SavedModel + (SavedModelEstimator). + * Adds leaf index modes as an argument. + * Allow a different output shape from the input in + tf.contrib.image.transform. + * Change the state_size order of the StackedRNNCell to be natural order. + To keep the existing behavior, user can add reverse_state_order=True + when constructing the StackedRNNCells. 
+ * Deprecate self.test_session() in favor of self.session() or + self.cached_session(). + * Directly import tensor.proto.h (the transitive import will be removed + from tensor.h soon) + * Estimator.train() now supports tf.contrib.summary.\* summaries out of + the box; each call to .train() will now create a separate tfevents file + rather than re-using a shared one. + * Fix FTRL L2-shrinkage behavior: the gradient from the L2 shrinkage term + should not end up in the accumulator. + * Fix toco compilation/execution on Windows + * GoogleZoneProvider class added to detect which Google Cloud Engine zone + tensorflow is running in. + * It is now safe to call any of the C API's TF_Delete\* functions on + nullptr + * Log some errors on Android to logcat + * Match FakeQuant numerics in TFLite to improve accuracy of TFLite + quantized inference models. + * Optional bucket location check for the GCS Filesystem. + * Performance enhancements for StringSplitOp & StringSplitV2Op. + * Performance improvements for regex replace operations. + * TFRecordWriter now raises an error if .write() fails. + * TPU: More helpful error messages in TPUClusterResolvers. + * The legacy_init_op argument to SavedModelBuilder methods for adding + MetaGraphs has been deprecated. Please use the equivalent main_op + argument instead. As part of this, we now explicitly check for a single + main_op or legacy_init_op at the time of SavedModel building, whereas + the check on main_op was previously only done at load time. + * The protocol used for Estimator training is now configurable in + RunConfig. + * Triangular solve performance improvements. + * Unify RNN cell interface between TF and Keras. Add new + get_initial_state() to Keras and TF RNN cell, which will be used to replace + the existing zero_state() method. + * Update initialization of variables in Keras. + * Updates to "constrained_optimization" in tensorflow/contrib. 
+ * boosted trees: adding pruning mode + * tf.train.Checkpoint does not delete old checkpoints by default. + * tfdbg: Limit the total disk space occupied by dumped tensor data to 100 + GBytes. Add environment variable `TFDBG_DISK_BYTES_LIMIT` to allow + adjustment of this upper limit. ## Thanks to our Contributors @@ -154,8 +258,8 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A * Update `tf.keras` to the Keras 2.1.6 API. * Added [`tf.keras.layers.CuDNNGRU`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNGRU) and [`tf.keras.layers.CuDNNLSTM`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNLSTM) layers. [Try it](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb?linkId=53292082). * Adding support of core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/boosted_trees). -* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/lite) - for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/README.md) +* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/lite) + for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/toco/README.md) has been expanded, and the command line interface (AKA: `toco`, `tflite_convert`) is once again included in the standard `pip` installation. 
* Improved data-loading and text processing with: @@ -458,7 +562,7 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田 ## Major Features And Improvements * [Eager execution](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/eager) preview version is now available. -* [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/lite) +* [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/lite) dev preview is now available. * CUDA 9.0 and cuDNN 7 support. * Accelerated Linear Algebra (XLA): @@ -805,7 +909,7 @@ See also [TensorBoard 0.1.4](https://github.com/tensorflow/tensorboard/releases/ * Adds tf.contrib.nn.rank_sampled_softmax_loss, a sampled-softmax variant that can improve rank loss. * `tf.contrib.metrics`.{streaming_covariance,streaming_pearson_correlation} modified to return nan when they have seen less or equal to 1 unit of weight. * Adds time series models to contrib. See contrib/timeseries/README.md for details. -* Adds FULLY_CONNECTED Op to tensorflow/contrib/lite/schema.fbs +* Adds FULLY_CONNECTED Op to tensorflow/lite/schema.fbs ## Known Issues * Tensorflow_gpu compilation fails with Bazel 0.5.3. 
diff --git a/WORKSPACE b/WORKSPACE index 17961829a605c2d1f2d2ba86a7c30c47618c139b..7cc08e0164a202581ad7ebbe107a9e19410e70e4 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,5 +1,7 @@ workspace(name = "org_tensorflow") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + http_archive( name = "io_bazel_rules_closure", sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae", @@ -14,6 +16,33 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") closure_repositories() +http_archive( + name = "base_images_docker", + sha256 = "e2b1b7254270bb7605e814a9dbf6d1e4ae04a11136ff1714fbfdabe3f87f7cf9", + strip_prefix = "base-images-docker-12801524f867e657fbb5d1a74f31618aff181ac6", + urls = ["https://github.com/GoogleCloudPlatform/base-images-docker/archive/12801524f867e657fbb5d1a74f31618aff181ac6.tar.gz"], +) + +http_archive( + name = "bazel_toolchains", + sha256 = "15b5858b1b5541ec44df31b94c3b8672815b31d71215a98398761ea9f4c4eedb", + strip_prefix = "bazel-toolchains-6200b238c9c2d137c0d9a7262c80cc71d98e692b", + urls = [ + "https://github.com/bazelbuild/bazel-toolchains/archive/6200b238c9c2d137c0d9a7262c80cc71d98e692b.tar.gz", + ], +) + +http_archive( + name = "io_bazel_rules_docker", + sha256 = "29d109605e0d6f9c892584f07275b8c9260803bf0c6fcb7de2623b2bedc910bd", + strip_prefix = "rules_docker-0.5.1", + urls = ["https://github.com/bazelbuild/rules_docker/archive/v0.5.1.tar.gz"], +) + +load("//third_party/toolchains/preconfig/generate:workspace.bzl", "remote_config_workspace") + +remote_config_workspace() + # We must check the bazel version before trying to parse any other BUILD # files, in case the parsing of those build files depends on the bazel # version we require here. @@ -30,9 +59,9 @@ android_workspace() # Please add all new TensorFlow dependencies in workspace.bzl. 
tf_workspace() -new_http_archive( +http_archive( name = "inception_v1", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip", @@ -40,9 +69,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "mobile_ssd", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip", @@ -50,9 +79,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "mobile_multibox", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip", @@ -60,9 +89,9 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "stylize", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip", @@ -70,12 +99,13 @@ new_http_archive( ], ) -new_http_archive( +http_archive( name = "speech_commands", - build_file = "models.BUILD", + build_file = "//:models.BUILD", sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c", urls = [ "http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip", "http://download.tensorflow.org/models/speech_commands_v0.01.zip", ], ) + diff --git a/configure.py b/configure.py index a88fdb3555531a13300a0aabe36e2cc65a969daa..6c905a0be3d685b5921dfbc5bddfbe6471a82625 100644 --- a/configure.py +++ b/configure.py @@ -35,7 +35,6 @@ except ImportError: _DEFAULT_CUDA_VERSION = '9.0' 
_DEFAULT_CUDNN_VERSION = '7' -_DEFAULT_NCCL_VERSION = '2.2' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' @@ -44,7 +43,7 @@ _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' -_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16] +_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16, 17, 18] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 @@ -239,6 +238,13 @@ def setup_python(environ_cp): write_to_bazelrc('build --python_path=\"%s"' % python_bin_path) environ_cp['PYTHON_BIN_PATH'] = python_bin_path + # If choosen python_lib_path is from a path specified in the PYTHONPATH + # variable, need to tell bazel to include PYTHONPATH + if environ_cp.get('PYTHONPATH'): + python_paths = environ_cp.get('PYTHONPATH').split(':') + if python_lib_path in python_paths: + write_action_env_to_bazelrc('PYTHONPATH', environ_cp.get('PYTHONPATH')) + # Write tools/python_bin_path.sh with open( os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), @@ -384,7 +390,9 @@ def set_build_var(environ_cp, var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default))) environ_cp[var_name] = var if var == '1': - write_to_bazelrc('build --define %s=true' % option_name) + write_to_bazelrc( + 'build:%s --define %s=true' % (bazel_config_name, option_name)) + write_to_bazelrc('build --config=%s' % bazel_config_name) elif bazel_config_name is not None: # TODO(mikecase): Migrate all users of configure.py to use --config Bazel # options and not to set build configs through environment variables. @@ -444,11 +452,12 @@ def convert_version_to_int(version): return int(version_str) -def check_bazel_version(min_version): - """Check installed bazel version is at least min_version. 
+def check_bazel_version(min_version, max_version): + """Check installed bazel version is between min_version and max_version. Args: min_version: string for minimum bazel version. + max_version: string for maximum bazel version. Returns: The bazel version detected. @@ -466,6 +475,7 @@ def check_bazel_version(min_version): min_version_int = convert_version_to_int(min_version) curr_version_int = convert_version_to_int(curr_version) + max_version_int = convert_version_to_int(max_version) # Check if current bazel version can be detected properly. if not curr_version_int: @@ -479,6 +489,10 @@ def check_bazel_version(min_version): print('Please upgrade your bazel installation to version %s or higher to ' 'build TensorFlow!' % min_version) sys.exit(0) + if curr_version_int > max_version_int: + print('Please downgrade your bazel installation to version %s or lower to ' + 'build TensorFlow!' % max_version) + sys.exit(0) return curr_version @@ -496,7 +510,7 @@ def set_cc_opt_flags(environ_cp): elif is_windows(): default_cc_opt_flags = '/arch:AVX' else: - default_cc_opt_flags = '-march=native' + default_cc_opt_flags = '-march=native -Wno-sign-compare' question = ('Please specify optimization flags to use during compilation when' ' bazel option "--config=opt" is specified [Default is %s]: ' ) % default_cc_opt_flags @@ -858,7 +872,7 @@ def set_tf_cuda_version(environ_cp): cuda_toolkit_paths_full = [ os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths ] - if any([os.path.exists(x) for x in cuda_toolkit_paths_full]): + if any(os.path.exists(x) for x in cuda_toolkit_paths_full): break # Reset and retry @@ -1109,18 +1123,17 @@ def set_tf_nccl_install_path(environ_cp): raise ValueError('Currently NCCL is only supported on Linux platforms.') ask_nccl_version = ( - 'Please specify the NCCL version you want to use. If NCCL %s is not ' - 'installed, then you can use version 1.3 that can be fetched ' - 'automatically but it may have worse performance with multiple GPUs. 
' - '[Default is %s]: ') % (_DEFAULT_NCCL_VERSION, _DEFAULT_NCCL_VERSION) + 'Please specify the locally installed NCCL version you want to use. ' + '[Default is to use https://github.com/nvidia/nccl]: ') for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): tf_nccl_version = get_from_env_or_user_or_default( - environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, _DEFAULT_NCCL_VERSION) - tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1) + environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, '') - if tf_nccl_version == '1': - break # No need to get install path, NCCL 1 is a GitHub repo. + if not tf_nccl_version: + break # No need to get install path, building the open source code. + + tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1) # Look with ldconfig first if we can find the library in paths # like /usr/lib/x86_64-linux-gnu and the header file in the corresponding @@ -1182,6 +1195,7 @@ def set_tf_nccl_install_path(environ_cp): if is_windows() or is_cygwin(): nccl_install_path = cygpath(nccl_install_path) + nccl_lib_path = '' if is_windows(): nccl_lib_path = 'lib/x64/nccl.lib' elif is_linux(): @@ -1232,7 +1246,6 @@ def set_tf_nccl_install_path(environ_cp): environ_cp['TF_NCCL_VERSION'] = tf_nccl_version write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version) - def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. @@ -1418,11 +1431,16 @@ def set_mpi_home(environ_cp): def valid_mpi_path(mpi_home): exists = ( os.path.exists(os.path.join(mpi_home, 'include')) and - os.path.exists(os.path.join(mpi_home, 'lib'))) + (os.path.exists(os.path.join(mpi_home, 'lib')) or + os.path.exists(os.path.join(mpi_home, 'lib64')) or + os.path.exists(os.path.join(mpi_home, 'lib32')))) if not exists: - print('Invalid path to the MPI Toolkit. %s or %s cannot be found' % - (os.path.join(mpi_home, 'include'), - os.path.exists(os.path.join(mpi_home, 'lib')))) + print( + 'Invalid path to the MPI Toolkit. 
%s or %s or %s or %s cannot be found' + % (os.path.join(mpi_home, 'include'), + os.path.exists(os.path.join(mpi_home, 'lib')), + os.path.exists(os.path.join(mpi_home, 'lib64')), + os.path.exists(os.path.join(mpi_home, 'lib32')))) return exists _ = prompt_loop_or_load_from_env( @@ -1463,8 +1481,17 @@ def set_other_mpi_vars(environ_cp): if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')): symlink_force( os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so') + elif os.path.exists(os.path.join(mpi_home, 'lib64/libmpi.so')): + symlink_force( + os.path.join(mpi_home, 'lib64/libmpi.so'), 'third_party/mpi/libmpi.so') + elif os.path.exists(os.path.join(mpi_home, 'lib32/libmpi.so')): + symlink_force( + os.path.join(mpi_home, 'lib32/libmpi.so'), 'third_party/mpi/libmpi.so') + else: - raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home) + raise ValueError( + 'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' % + mpi_home, mpi_home, mpi_home) def set_system_libs_flag(environ_cp): @@ -1499,14 +1526,6 @@ def set_windows_build_flags(environ_cp): # TODO(pcloudy): Remove this flag when upgrading Bazel to 0.16.0 # Short object file path will be enabled by default. write_to_bazelrc('build --experimental_shortened_obj_file_path=true') - # When building zip file for some py_binary and py_test targets, don't - # include its dependencies. This is for: - # 1. Running python tests against the system installed TF pip package. - # 2. Avoiding redundant files in - # //tensorflow/tools/pip_package:simple_console_windows, - # which is a py_binary used during creating TF pip package. - # See https://github.com/tensorflow/tensorflow/issues/22390 - write_to_bazelrc('build --define=no_tensorflow_py_deps=true') if get_var( environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline', @@ -1546,9 +1565,12 @@ def main(): # environment variables. 
environ_cp = dict(os.environ) - check_bazel_version('0.15.0') + check_bazel_version('0.15.0', '0.20.0') reset_tf_configure_bazelrc() + # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later + write_to_bazelrc('import %workspace%/tools/bazel.rc') + cleanup_makefile() setup_python(environ_cp) @@ -1561,13 +1583,11 @@ def main(): # TODO(ibiryukov): Investigate using clang as a cpu or cuda compiler on # Windows. environ_cp['TF_DOWNLOAD_CLANG'] = '0' - environ_cp['TF_ENABLE_XLA'] = '0' environ_cp['TF_NEED_MPI'] = '0' environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0' if is_macos(): environ_cp['TF_NEED_TENSORRT'] = '0' - environ_cp['TF_ENABLE_XLA'] = '0' # The numpy package on ppc64le uses OpenBLAS which has multi-threading # issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at @@ -1576,10 +1596,9 @@ def main(): if is_ppc64le(): write_action_env_to_bazelrc('OMP_NUM_THREADS', 1) - set_build_var(environ_cp, 'TF_NEED_IGNITE', 'Apache Ignite', - 'with_ignite_support', True, 'ignite') + xla_enabled_by_default = is_linux() set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', - True, 'xla') + xla_enabled_by_default, 'xla') set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False) if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': @@ -1671,18 +1690,24 @@ def main(): create_android_ndk_rule(environ_cp) create_android_sdk_rule(environ_cp) - # On Windows, we don't have MKL support and the build is always monolithic. - # So no need to print the following message. - # TODO(pcloudy): remove the following if check when they make sense on Windows - if not is_windows(): - print('Preconfigured Bazel build configs. You can use any of the below by ' - 'adding "--config=<>" to your build command. 
See .bazelrc for more ' - 'details.') - config_info_line('mkl', 'Build with MKL support.') - config_info_line('monolithic', 'Config for mostly static monolithic build.') - config_info_line('gdr', 'Build with GDR support.') - config_info_line('verbs', 'Build with libverbs support.') - config_info_line('ngraph', 'Build with Intel nGraph support.') + print('Preconfigured Bazel build configs. You can use any of the below by ' + 'adding "--config=<>" to your build command. See .bazelrc for more ' + 'details.') + config_info_line('mkl', 'Build with MKL support.') + config_info_line('monolithic', 'Config for mostly static monolithic build.') + config_info_line('gdr', 'Build with GDR support.') + config_info_line('verbs', 'Build with libverbs support.') + config_info_line('ngraph', 'Build with Intel nGraph support.') + config_info_line('dynamic_kernels', + '(Experimental) Build kernels into separate shared objects.') + + print('Preconfigured Bazel build configs to DISABLE default on features:') + config_info_line('noaws', 'Disable AWS S3 filesystem support.') + config_info_line('nogcp', 'Disable GCP support.') + config_info_line('nohdfs', 'Disable HDFS support.') + config_info_line('noignite', 'Disable Apacha Ignite support.') + config_info_line('nokafka', 'Disable Apache Kafka support.') + config_info_line('nonccl', 'Disable NVIDIA NCCL support.') if __name__ == '__main__': diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 9b62a504525d5377d4836e92bdf0e46f7fc3ef38..fd4b94202aad24a82abef8abd16431f61a8326f0 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -43,6 +43,11 @@ TENSORFLOW_API_INIT_FILES_V2 = ( TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) ) +# @unused +TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT = ( + TENSORFLOW_API_INIT_FILES_V1 + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) +) + # Config setting used when building for products # which requires restricted licenses to be avoided. 
config_setting( @@ -209,12 +214,46 @@ config_setting( visibility = ["//visibility:public"], ) +# Features that are default ON are handled differently below. +# +config_setting( + name = "no_aws_support", + define_values = {"no_aws_support": "true"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "no_gcp_support", + define_values = {"no_gcp_support": "true"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "no_hdfs_support", + define_values = {"no_hdfs_support": "true"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "no_ignite_support", + define_values = {"no_ignite_support": "true"}, + visibility = ["//visibility:public"], +) + config_setting( - name = "with_ignite_support", - define_values = {"with_ignite_support": "true"}, + name = "no_kafka_support", + define_values = {"no_kafka_support": "true"}, visibility = ["//visibility:public"], ) +config_setting( + name = "no_nccl_support", + define_values = {"no_nccl_support": "true"}, + visibility = ["//visibility:public"], +) + +# Crosses between platforms and file system libraries not supported on those +# platforms due to limitations in nested select() statements. 
config_setting( name = "with_cuda_support_windows_override", define_values = {"using_cuda_nvcc": "true"}, @@ -322,8 +361,9 @@ package_group( "-//third_party/tensorflow/python/estimator", "//learning/meta_rank/...", "//tensorflow/...", - "//tensorflow_estimator/...", + "//tensorflow_estimator/contrib/...", "//tensorflow_fold/llgtm/...", + "//tensorflow_text/...", "//third_party/py/tensor2tensor/...", ], ) @@ -525,35 +565,45 @@ genrule( }), outs = ["__init__.py"], cmd = select({ - "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)", - "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)", + "api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS)", + "//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS)", }), ) gen_api_init_files( name = "tf_python_api_gen_v1", - srcs = ["api_template.__init__.py"], + srcs = [ + "api_template_v1.__init__.py", + "compat_template_v1.__init__.py", + ], api_version = 1, + compat_api_versions = [1], + compat_init_templates = ["compat_template_v1.__init__.py"], output_dir = "_api/v1/", - output_files = TENSORFLOW_API_INIT_FILES_V1, + output_files = TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT, output_package = "tensorflow._api.v1", - root_init_template = "api_template.__init__.py", + root_file_name = "v1.py", + root_init_template = "api_template_v1.__init__.py", ) gen_api_init_files( name = "tf_python_api_gen_v2", - srcs = ["api_template.__init__.py"], + srcs = [ + "api_template.__init__.py", + "compat_template_v1.__init__.py", + ], api_version = 2, compat_api_versions = [1], + compat_init_templates = ["compat_template_v1.__init__.py"], output_dir = "_api/v2/", output_files = TENSORFLOW_API_INIT_FILES_V2, output_package = "tensorflow._api.v2", + root_file_name = "v2.py", root_init_template = "api_template.__init__.py", ) py_library( name = "tensorflow_py", - srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ diff --git 
a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 2de740e145f93b151faf5c987808dbdf73fb4fd7..d81cf067eb07e88e2b8a86cf5643674235eb3f3b 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -21,41 +21,24 @@ from __future__ import print_function as _print_function import os as _os # pylint: disable=g-bad-import-order -from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import - -try: - # Add `estimator` attribute to allow access to estimator APIs via - # "tf.estimator..." - from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top - - # Add `estimator` to the __path__ to allow "from tensorflow.estimator..." - # style imports. - from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top - __path__ += [_os.path.dirname(estimator_api.__file__)] - del estimator_api -except (ImportError, AttributeError): - print('tf.estimator package not installed.') +from tensorflow.python.tools import component_api_helper as _component_api_helper +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) # API IMPORTS PLACEHOLDER -from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top -contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') -del LazyLoader -# The templated code that replaces the placeholder above sometimes -# sets the __all__ variable. If it does, we have to be sure to add -# "contrib". -if '__all__' in vars(): - vars()['__all__'].append('contrib') - -from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top -app.flags = flags # pylint: disable=undefined-variable - # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. 
-_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable +# We're using bitwise, but there's nothing special about that. +_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable if _tf_api_dir not in __path__: __path__.append(_tf_api_dir) +# Enable TF2 behaviors +from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top +_compat.enable_v2_behavior() + # These symbols appear because we import the python package which # in turn imports from tensorflow.core and tensorflow.python. They # must come from this module. So python adds these symbols for the @@ -66,7 +49,14 @@ try: del core except NameError: # Don't fail if these modules are not available. - # For e.g. we are using this file for compat.v1 module as well and - # 'python', 'core' directories are not under compat/v1. + # For e.g. this file will be originally placed under tensorflow/_api/v1 which + # does not have 'python', 'core' directories. Then, it will be copied + # to tensorflow/ which does have these two directories. + pass +# Similarly for compiler. Do it separately to make sure we do this even if the +# others don't exist. +try: + del compiler +except NameError: pass # pylint: enable=undefined-variable diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..65bdb6cb1b5e6fb0656a12b932d767aeacfccd29 --- /dev/null +++ b/tensorflow/api_template_v1.__init__.py @@ -0,0 +1,72 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Bring in all of the public TensorFlow interface into this module.""" + +from __future__ import absolute_import as _absolute_import +from __future__ import division as _division +from __future__ import print_function as _print_function + +import os as _os + +# pylint: disable=g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import + +from tensorflow.python.tools import component_api_helper as _component_api_helper +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) + +# API IMPORTS PLACEHOLDER + +from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top +contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') +del LazyLoader +# The templated code that replaces the placeholder above sometimes +# sets the __all__ variable. If it does, we have to be sure to add +# "contrib". +if '__all__' in vars(): + vars()['__all__'].append('contrib') + +from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +app.flags = flags # pylint: disable=undefined-variable + +# Make sure directory containing top level submodules is in +# the __path__ so that "from tensorflow.foo import bar" works. 
+_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable +if _tf_api_dir not in __path__: + __path__.append(_tf_api_dir) + + +# These symbols appear because we import the python package which +# in turn imports from tensorflow.core and tensorflow.python. They +# must come from this module. So python adds these symbols for the +# resolution to succeed. +# pylint: disable=undefined-variable +try: + del python + del core +except NameError: + # Don't fail if these modules are not available. + # For e.g. this file will be originally placed under tensorflow/_api/v1 which + # does not have 'python', 'core' directories. Then, it will be copied + # to tensorflow/ which does have these two directories. + pass +# Similarly for compiler. Do it separately to make sure we do this even if the +# others don't exist. +try: + del compiler +except NameError: + pass +# pylint: enable=undefined-variable diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 17e2e292eb19029d279bc12a8328edadf96f1bb8..25df970ecab0757f23465ab19e7f45de0c759458 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -6,11 +6,12 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", "tf_cc_test", - "tf_cuda_cc_test", "tf_copts", "tf_cuda_library", "tf_custom_op_library", + "tf_kernel_library", ) +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") # ----------------------------------------------------------------------------- # Public targets @@ -59,6 +60,7 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:op_gen_lib", + "//tensorflow/core/distributed_runtime:server_lib", ], }), ) @@ -94,6 +96,7 @@ tf_cuda_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/distributed_runtime:server_lib", ], }) + select({ "//tensorflow:with_xla_support": [ @@ -118,13 +121,15 @@ tf_cuda_library( ":c_api", ":c_api_internal", 
"//tensorflow/c/eager:c_api", - "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", + "//tensorflow/c/eager:c_api_internal", + "//tensorflow/compiler/jit:flags", "//tensorflow/contrib/tpu:all_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_platform", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:attr_builder", ], ) @@ -170,6 +175,60 @@ tf_cuda_library( ], ) +tf_cuda_library( + name = "env", + srcs = [ + "env.cc", + ], + hdrs = [ + "env.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + ":c_api", + ":tf_status_helper", + "//tensorflow/core:android_tensorflow_lib_lite", + "//tensorflow/core:platform_env", + "//tensorflow/core:lib", + ], + "//conditions:default": [ + ":c_api", + ":tf_status_helper", + "//tensorflow/core:framework", + "//tensorflow/core:platform_env", + "//tensorflow/core:lib", + ], + }) + [":c_api_internal"], +) + +tf_cuda_library( + name = "kernels", + srcs = [ + "kernels.cc", + ], + hdrs = [ + "kernels.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = select({ + "//tensorflow:android": [ + ":c_api", + ":c_api_internal", + ":tf_status_helper", + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + ":c_api", + ":c_api_internal", + ":tf_status_helper", + "//tensorflow/core:framework", + ], + }), +) + # ----------------------------------------------------------------------------- # Tests @@ -197,14 +256,18 @@ tf_cuda_cc_test( size = "small", srcs = ["c_api_test.cc"], data = [ - ":test_op.so", + ":test_op1.so", "//tensorflow/cc/saved_model:saved_model_half_plus_two", ], + kernels = [":test_op_kernel"], linkopts = select({ "//tensorflow:darwin": ["-headerpad_max_install_names"], "//conditions:default": [], }), - tags = ["noasan"], + tags = [ + "no_oss", # http://b/119522529 + "noasan", + ], # We must ensure 
that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), @@ -215,6 +278,7 @@ tf_cuda_cc_test( "//tensorflow/cc:grad_ops", "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/cc/saved_model:tag_constants", + "//tensorflow/compiler/jit", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:direct_session", "//tensorflow/core:framework", @@ -232,7 +296,7 @@ tf_cuda_cc_test( tf_cc_test( name = "c_api_experimental_test", - size = "small", + size = "medium", srcs = ["c_api_experimental_test.cc"], data = ["testdata/tf_record"], linkopts = select({ @@ -243,8 +307,11 @@ tf_cc_test( # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), deps = [ + ":c_api", ":c_api_experimental", ":c_test_util", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_test_util", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -281,8 +348,63 @@ tf_cc_test( ) tf_custom_op_library( - name = "test_op.so", + name = "test_op1.so", + srcs = ["test_op1.cc"], +) + +tf_kernel_library( + name = "test_op_kernel", srcs = ["test_op.cc"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +tf_cuda_cc_test( + name = "env_test", + size = "small", + srcs = ["env_test.cc"], + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + tags = ["noasan"], + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. 
+ # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api", + ":env", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cuda_cc_test( + name = "kernels_test", + size = "small", + srcs = ["kernels_test.cc"], + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + tags = ["noasan"], + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. + # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api", + ":kernels", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], ) # ----------------------------------------------------------------------------- diff --git a/tensorflow/c/README.md b/tensorflow/c/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b386998ceaf3e91daba04125fe83e2f3bdd508e5 --- /dev/null +++ b/tensorflow/c/README.md @@ -0,0 +1,7 @@ +# TensorFlow C API + +- See [www.tensorflow.org/install/lang_c](https://www.tensorflow.org/install/lang_c) +- Nightly builds: + - [Linux CPU-only](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow-cpu-linux-x86_64.tar.gz) + - [Linux GPU](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow-gpu-linux-x86_64.tar.gz) + - [MacOS CPU-only](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow-cpu-darwin-x86_64.tar.gz) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 79811ceae57e0bddeb2a6f32bad7003e14e23422..94d18eb8b04e3534be547aca5cfbb32da40ffbf6 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -136,16 +136,22 @@ const char* TF_Message(const TF_Status* s) { namespace { class TF_ManagedBuffer : public 
TensorBuffer { public: - void* data_; - size_t len_; - void (*deallocator_)(void* data, size_t len, void* arg); - void* deallocator_arg_; + TF_ManagedBuffer(void* data, size_t len, + void (*deallocator)(void* data, size_t len, void* arg), + void* deallocator_arg) + : TensorBuffer(data), + len_(len), + deallocator_(deallocator), + deallocator_arg_(deallocator_arg) {} + + const size_t len_; + void (*const deallocator_)(void* data, size_t len, void* arg); + void* const deallocator_arg_; ~TF_ManagedBuffer() override { - (*deallocator_)(data_, len_, deallocator_arg_); + (*deallocator_)(data(), len_, deallocator_arg_); } - void* data() const override { return data_; } size_t size() const override { return len_; } TensorBuffer* root_buffer() override { return this; } void FillAllocationDescription(AllocationDescription* proto) const override { @@ -199,8 +205,7 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, dimvec[i] = static_cast(dims[i]); } - TF_ManagedBuffer* buf = new TF_ManagedBuffer; - buf->len_ = len; + TF_ManagedBuffer* buf = nullptr; if (dtype != TF_STRING && dtype != TF_RESOURCE && tensorflow::DataTypeCanUseMemcpy(static_cast(dtype)) && reinterpret_cast(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) != @@ -212,17 +217,15 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, // // Other types have the same representation, so copy only if it is safe to // do so. - buf->data_ = allocate_tensor("TF_NewTensor", len); - std::memcpy(buf->data_, data, len); - buf->deallocator_ = deallocate_buffer; - buf->deallocator_arg_ = nullptr; + buf = new TF_ManagedBuffer(allocate_tensor("TF_NewTensor", len), len, + deallocate_buffer, nullptr); + std::memcpy(buf->data(), data, len); // Free the original buffer. 
deallocator(data, len, deallocator_arg); } else { - buf->data_ = data; - buf->deallocator_ = deallocator; - buf->deallocator_arg_ = deallocator_arg; + buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg); } + TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf}; size_t elem_size = TF_DataTypeSize(dtype); if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) { @@ -477,9 +480,9 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { CHECK_EQ(nelems, 0); static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), "64-bit int types should match in size"); - return TF_NewTensor(dtype, reinterpret_cast(dims.data()), - shape.dims(), reinterpret_cast(&empty), 0, - [](void*, size_t, void*) {}, nullptr); + return TF_NewTensor( + dtype, reinterpret_cast(dims.data()), shape.dims(), + reinterpret_cast(&empty), 0, [](void*, size_t, void*) {}, nullptr); } // Non-static for testing. @@ -1592,18 +1595,20 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, break; \ } - LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0; - for (int i = 0; i < attr->list().s_size(); - ++i) { metadata.total_size += attr->list().s(i).size(); }); + LIST_CASE( + s, TF_ATTR_STRING, metadata.total_size = 0; + for (int i = 0; i < attr->list().s_size(); + ++i) { metadata.total_size += attr->list().s(i).size(); }); LIST_CASE(i, TF_ATTR_INT); LIST_CASE(f, TF_ATTR_FLOAT); LIST_CASE(b, TF_ATTR_BOOL); LIST_CASE(type, TF_ATTR_TYPE); - LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0; - for (int i = 0; i < attr->list().shape_size(); ++i) { - const auto& s = attr->list().shape(i); - metadata.total_size += s.unknown_rank() ? 0 : s.dim_size(); - }); + LIST_CASE( + shape, TF_ATTR_SHAPE, metadata.total_size = 0; + for (int i = 0; i < attr->list().shape_size(); ++i) { + const auto& s = attr->list().shape(i); + metadata.total_size += s.unknown_rank() ? 
0 : s.dim_size(); + }); LIST_CASE(tensor, TF_ATTR_TENSOR); LIST_CASE(tensor, TF_ATTR_FUNC); #undef LIST_CASE @@ -1942,6 +1947,10 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, const char* prefix) { opts->opts.prefix = prefix; } +void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts, + const char* device) { + opts->opts.default_device = device; +} void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts, unsigned char uniquify_names) { @@ -2770,6 +2779,9 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, } string name_str(name, name_len); const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str); + if (api_def == nullptr) { + return nullptr; + } TF_Buffer* ret = TF_NewBuffer(); status->status = MessageToBuffer(*api_def, ret); @@ -2803,4 +2815,71 @@ TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) { } return ret; } + +// TF_Server functions ---------------------------------------------- + +#ifndef __ANDROID__ +TF_Server::TF_Server(std::unique_ptr server) + : target(server->target()), server(std::move(server)) {} +#endif // __ANDROID__ + +TF_Server* TF_NewServer(const void* proto, size_t proto_len, + TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported in Android"); + return nullptr; +#else + tensorflow::ServerDef server_def; + if (!server_def.ParseFromArray(proto, static_cast(proto_len))) { + status->status = InvalidArgument( + "Could not parse provided bytes into a ServerDef protocol buffer"); + return nullptr; + } + + std::unique_ptr out_server; + status->status = tensorflow::NewServer(server_def, &out_server); + if (!status->status.ok()) return nullptr; + + return new TF_Server(std::move(out_server)); +#endif +} + +void TF_ServerStart(TF_Server* server, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + 
"Server functionality is not supported in Android"); +#else + status->status = server->server->Start(); +#endif +} + +void TF_ServerStop(TF_Server* server, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported in Android"); +#else + status->status = server->server->Stop(); +#endif +} + +void TF_ServerJoin(TF_Server* server, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Server functionality is not supported in Android"); +#else + status->status = server->server->Join(); +#endif +} + +const char* TF_ServerTarget(TF_Server* server) { +#ifdef __ANDROID__ + return nullptr; +#else + return server->target.c_str(); +#endif +} + +void TF_DeleteServer(TF_Server* server) { delete server; } + } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 850f6ecd637d768bca99720e0add07680829e17a..c7abba85521fccec07983cd5ab4f94a8368d6181 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -91,7 +91,7 @@ extern "C" { // -------------------------------------------------------------------------- // TF_Version returns a string describing version information of the // TensorFlow library. TensorFlow using semantic versioning. -TF_CAPI_EXPORT extern const char* TF_Version(); +TF_CAPI_EXPORT extern const char* TF_Version(void); // -------------------------------------------------------------------------- // TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. @@ -157,7 +157,7 @@ typedef enum TF_Code { typedef struct TF_Status TF_Status; // Return a new status object. -TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(); +TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void); // Delete a previously created status object. 
TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*); @@ -196,7 +196,7 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len); // Useful for passing *out* a protobuf. -TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(); +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(void); TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*); @@ -305,7 +305,7 @@ TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len); typedef struct TF_SessionOptions TF_SessionOptions; // Return a new options object. -TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(); +TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(void); // Set the target in TF_SessionOptions.options. // target can be empty, a single entry, or a comma separated list of entries. @@ -338,7 +338,7 @@ TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*); typedef struct TF_Graph TF_Graph; // Return a new graph object. -TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(); +TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(void); // Destroy an options object. Graph will be deleted once no more // TFSession's are referencing it. @@ -890,7 +890,8 @@ TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph, // TF_GraphImportGraphDef. typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; -TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions(); +TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions( + void); TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_ImportGraphDefOptions* opts); @@ -900,6 +901,12 @@ TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( TF_ImportGraphDefOptions* opts, const char* prefix); +// Set the execution device for nodes in `graph_def`. +// Only applies to nodes where a device was not already explicitly specified. +// `device` is copied and has no lifetime requirements. 
+TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetDefaultDevice( + TF_ImportGraphDefOptions* opts, const char* device); + // Set whether to uniquify imported operation names. If true, imported operation // names will be modified if their name already exists in the graph. If false, // conflicting names will be treated as an error. Note that this option has no @@ -1605,7 +1612,7 @@ TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); // // The data in the buffer will be the serialized OpList proto for ops registered // in this address space. -TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(); +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(void); // TF_ApiDefMap encapsulates a collection of API definitions for an operation. // @@ -1662,6 +1669,47 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( const char* name, TF_Status* status); +// -------------------------------------------------------------------------- +// In-process TensorFlow server functionality, for use in distributed training. +// A Server instance encapsulates a set of devices and a Session target that +// can participate in distributed training. A server belongs to a cluster +// (specified by a ClusterSpec), and corresponds to a particular task in a +// named job. The server can communicate with any other server in the same +// cluster. + +// In-process TensorFlow server. +typedef struct TF_Server TF_Server; + +// Creates a new in-process TensorFlow server configured using a serialized +// ServerDef protocol buffer provided via `proto` and `proto_len`. +// +// The server will not serve any requests until TF_ServerStart is invoked. +// The server will stop serving requests once TF_ServerStop or +// TF_DeleteServer is invoked. +TF_CAPI_EXPORT extern TF_Server* TF_NewServer(const void* proto, + size_t proto_len, + TF_Status* status); + +// Starts an in-process TensorFlow server. 
+TF_CAPI_EXPORT extern void TF_ServerStart(TF_Server* server, TF_Status* status); + +// Stops an in-process TensorFlow server. +TF_CAPI_EXPORT extern void TF_ServerStop(TF_Server* server, TF_Status* status); + +// Blocks until the server has been successfully stopped (via TF_ServerStop or +// TF_DeleteServer). +TF_CAPI_EXPORT extern void TF_ServerJoin(TF_Server* server, TF_Status* status); + +// Returns the target string that can be provided to TF_SetTarget() to connect +// a TF_Session to `server`. +// +// The returned string is valid only until TF_DeleteServer is invoked. +TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server); + +// Destroy an in-process TensorFlow server, frees memory. If server is running +// it will be stopped and joined. +TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index d4b78138e93624a7e41e917f8210281b500661bc..38e29aa74a90f4e85d1369b6928a5a58c531b2da 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -15,12 +15,18 @@ limitations under the License. 
#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" @@ -50,8 +56,8 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { // These XLA flags are needed to trigger XLA properly from C (more generally // non-Python) clients. If this API is called again with `enable` set to // false, it is safe to keep these flag values as is. - tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = - tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + tensorflow::MarkForCompilationPassFlags* flags = + tensorflow::GetMarkForCompilationPassFlags(); flags->tf_xla_cpu_global_jit = true; flags->tf_xla_min_cluster_size = 1; } else { @@ -70,8 +76,8 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, // These XLA flags are needed to trigger XLA properly from C (more generally // non-Python) clients. If this API is called again with `enable` set to // false, it is safe to keep these flag values as is. 
- tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = - tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + tensorflow::MarkForCompilationPassFlags* flags = + tensorflow::GetMarkForCompilationPassFlags(); flags->tf_xla_cpu_global_jit = true; flags->tf_xla_min_cluster_size = 1; } else { @@ -6524,7 +6530,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/cycle_length" + name: "ExperimentalParallelInterleaveDataset/cycle_length" op: "Const" attr { key: "dtype" @@ -6545,7 +6551,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/block_length" + name: "ExperimentalParallelInterleaveDataset/block_length" op: "Const" attr { key: "dtype" @@ -6566,7 +6572,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/sloppy" + name: "ExperimentalParallelInterleaveDataset/sloppy" op: "Const" attr { key: "dtype" @@ -6587,7 +6593,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/buffer_output_elements" + name: "ExperimentalParallelInterleaveDataset/buffer_output_elements" op: "Const" attr { key: "dtype" @@ -6608,7 +6614,7 @@ library { } } node_def { - name: "ParallelInterleaveDataset/prefetch_input_elements" + name: "ExperimentalParallelInterleaveDataset/prefetch_input_elements" op: "Const" attr { key: "dtype" @@ -6629,14 +6635,14 @@ library { } } node_def { - name: "ParallelInterleaveDataset" - op: "ParallelInterleaveDataset" + name: "ExperimentalParallelInterleaveDataset" + op: "ExperimentalParallelInterleaveDataset" input: "RepeatDataset:handle:0" - input: "ParallelInterleaveDataset/cycle_length:output:0" - input: "ParallelInterleaveDataset/block_length:output:0" - input: "ParallelInterleaveDataset/sloppy:output:0" - input: "ParallelInterleaveDataset/buffer_output_elements:output:0" - input: "ParallelInterleaveDataset/prefetch_input_elements:output:0" + input: "ExperimentalParallelInterleaveDataset/cycle_length:output:0" + input: "ExperimentalParallelInterleaveDataset/block_length:output:0" + input: 
"ExperimentalParallelInterleaveDataset/sloppy:output:0" + input: "ExperimentalParallelInterleaveDataset/buffer_output_elements:output:0" + input: "ExperimentalParallelInterleaveDataset/prefetch_input_elements:output:0" attr { key: "Targuments" value { @@ -6736,7 +6742,7 @@ library { node_def { name: "ShuffleDataset_2" op: "ShuffleDataset" - input: "ParallelInterleaveDataset:handle:0" + input: "ExperimentalParallelInterleaveDataset:handle:0" input: "ShuffleDataset_2/buffer_size_1:output:0" input: "ShuffleDataset_2/seed_2:output:0" input: "ShuffleDataset_2/seed2_2:output:0" @@ -8738,7 +8744,145 @@ void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) { TF_DeleteStatus(status); } -TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, - const char* errMsg) { +struct TFE_ExecuteOpNotification { + TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {} + tensorflow::Notification n; + std::unique_ptr thread; + std::unique_ptr status; +}; + +TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op, + TFE_TensorHandle** retvals, + int* num_retvals, + TF_Status* status) { + TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification; + + n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread( + tensorflow::ThreadOptions(), "ExecuteOpThread", + [op, retvals, num_retvals, n]() { + TFE_Execute(op, retvals, num_retvals, n->status.get()); + n->n.Notify(); + })); + + return n; +} + +void TFE_ExecuteOpNotificationWaitAndDelete( + TFE_ExecuteOpNotification* notification, TF_Status* status) { + if (notification == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "Passed in notification is a nullptr."); + + return; + } + if (notification->thread == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "Passed in notification didn't start a thread correctly. Cleaning up " + "this notification. 
Please re-execute the operation to get a new " + "notification."); + + delete notification; + return; + } + + notification->n.WaitForNotification(); + + status->status = notification->status->status; + + delete notification; +} + +void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) { status->status = tensorflow::errors::Internal(errMsg); } + +// This builder is used in the eager API to build a NodeDef. +struct TF_AttrBuilder : public tensorflow::AttrBuilder { + using tensorflow::AttrBuilder::AttrBuilder; + // The string buffers to make sure that any `attr_name` we pass into + // `builder->Set()` will outlive the subsequent + // `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`. + std::set attr_names; +}; + +TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) { + return new TF_AttrBuilder(op_name); +} + +void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; } + +void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name, + TF_DataType value) { + auto iter = builder->attr_names.insert(attr_name).first; + builder->Set((*iter).c_str(), static_cast(value)); +} + +void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name, + const TF_DataType* values, int num_values) { + auto iter = builder->attr_names.insert(attr_name).first; + builder->Set( + (*iter).c_str(), + tensorflow::gtl::ArraySlice( + reinterpret_cast(values), num_values)); +} + +void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder, + const char* device_type, + TF_Status* status) { + status->status = tensorflow::FindKernelDef( + tensorflow::DeviceType(device_type), builder->BuildNodeDef(), + /* def = */ nullptr, /* kernel_class_name = */ nullptr); +} + +const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index, + TF_Status* status) { + const tensorflow::OpDef* op_def = nullptr; + status->status = + tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def); + if (!status->status.ok()) 
return nullptr; + + if (input_index >= op_def->input_arg_size() || input_index < 0) { + status->status = tensorflow::errors::InvalidArgument( + input_index, " out of range for ", op_name); + return nullptr; + } + + const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index]; + + if (input_arg.number_attr().empty()) { + status->status = tensorflow::errors::NotFound( + op_name, " does not have number_attr() defined."); + return nullptr; + } + + // The returned string is owned by OpRegistry, so liveness is not a concern. + return input_arg.number_attr().c_str(); +} + +int TF_OpIsStateful(const char* op_type, TF_Status* status) { + const tensorflow::OpRegistrationData* op_reg_data; + status->status = + tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data); + if (!status->status.ok()) { + return 0; + } + return op_reg_data->op_def.is_stateful(); +} + +void TF_InitMain(const char* usage, int* argc, char*** argv) { + tensorflow::port::InitMain(usage, argc, argv); +} + +int TF_PickUnusedPortOrDie() { + return tensorflow::internal::PickUnusedPortOrDie(); +} + +TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType dtype_arg, + void* data, size_t len) { + auto dtype = static_cast(dtype_arg); + DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype)); + + tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({})); + std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len); + return new TFE_TensorHandle(tensor, nullptr, nullptr); +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index d98d532e32e891e21f5b7ba360c74c3256fb1947..3e3a485eb763b871b0551414c4ef04746b2ed9a3 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -180,9 +180,72 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString( TFE_TensorHandle* handle); +typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification; + +// 
Allows invoking a kernel asynchronously, and explicitly returns a +// notification that can be waited upon. This always executes the kernel in a +// new thread. +// 1. `retvals` and `num_retvals` can only be consumed after +// `TFE_ExecuteOp` returns successfully. They shouldn't be used +// if the return is unsuccessful +// 2. These new APIs cannot be used together with the TFE context level async +// support. +TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread( + TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, + TF_Status* status); + +// Waits to complete the op execution, and cleans up the notification. +// Errors reported by op execution are set in `status`. +TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete( + TFE_ExecuteOpNotification* notification, TF_Status* status); + TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg); +// TF_NewAttrBuilder() returns an object that you can set attributes on as +// though it were an op. This allows querying properties of that op for +// type-checking purposes like if the op will run on a particular device type. +typedef struct TF_AttrBuilder TF_AttrBuilder; +TF_CAPI_EXPORT extern TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name); +TF_CAPI_EXPORT extern void TF_DeleteAttrBuilder(TF_AttrBuilder* builder); +TF_CAPI_EXPORT extern void TF_AttrBuilderSetType(TF_AttrBuilder* builder, + const char* attr_name, + TF_DataType value); +TF_CAPI_EXPORT extern void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, + const char* attr_name, + const TF_DataType* values, + int num_values); + +// Checks the tensorflow::NodeDef built via the methods above to see if it can +// run on device_type. 
+TF_CAPI_EXPORT extern void TF_AttrBuilderCheckCanRunOnDevice( + TF_AttrBuilder* builder, const char* device_type, TF_Status* status); + +// For argument number input_index, fetch the corresponding number_attr that +// needs to be updated with the argument length of the input list. +// Returns nullptr if there is any problem like op_name is not found, or the +// argument does not support this attribute type. +TF_CAPI_EXPORT extern const char* TF_GetNumberAttrForOpListInput( + const char* op_name, int input_index, TF_Status* status); + +// Returns 1 if the op is stateful, 0 otherwise. The return value is undefined +// if the status is not ok. +TF_CAPI_EXPORT extern int TF_OpIsStateful(const char* op_type, + TF_Status* status); + +// Platform specific initialization routine. Very few platforms actually require +// this to be called. +TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv); + +// Platform-specific implementation to return an unused port. (This should be used +// in tests only.) +TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(void); + +// Fast path method that makes constructing a single scalar tensor require less +// overhead and copies. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromScalar( + TF_DataType dtype, void* scalar, size_t len); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index c6effd39697e0397278770b53e98508074f99862..daa7701b7fe7e8ce757b6504329cf6434ad39778 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/c_test_util.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -162,5 +164,137 @@ protocol: "grpc" TF_DeleteStatus(status); } +TEST(CAPI_EXPERIMENTAL, IsStateful) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + int assign = TF_OpIsStateful("AssignAddVariableOp", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + EXPECT_EQ(assign, 1); + int id = TF_OpIsStateful("Identity", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + EXPECT_EQ(id, 0); +} + +TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + + TFE_Op* matmul_op = MatMulOp(ctx, m, m); + + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + + auto* r = + TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status); + + TFE_ExecuteOpNotificationWaitAndDelete(r, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteOp(matmul_op); + TFE_DeleteTensorHandle(m); + + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx); + 
TF_DeleteStatus(status); +} + +// Perform a send/recv test. Recv blocks, so they need to be executed +// asynchronously. +TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + // Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4. + TFE_TensorHandle* m = TestMatrixTensorHandle(); + + // Build a send op. + TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(send_op, m, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + string tensor_name = "Tensor"; + TFE_OpSetAttrType(send_op, "T", TF_FLOAT); + TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(), + tensor_name.size()); + string send_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(), + send_device.size()); + TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234); + string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(), + recv_device.size()); + TFE_OpSetAttrBool(send_op, "client_terminated", true); + + // Build a recv op. 
+ TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT); + TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(), + tensor_name.size()); + TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(), + send_device.size()); + TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234); + TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(), + recv_device.size()); + TFE_OpSetAttrBool(recv_op, "client_terminated", true); + + TFE_TensorHandle* send_retvals; + int send_num_retvals = 0; + auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals, + &send_num_retvals, status); + + TFE_TensorHandle* recv_retvals[1] = {nullptr}; + int recv_num_retvals = 1; + auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0], + &recv_num_retvals, status); + + TFE_ExecuteOpNotificationWaitAndDelete(send_result, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(1, product[0]); + EXPECT_EQ(2, product[1]); + EXPECT_EQ(3, product[2]); + EXPECT_EQ(4, product[3]); + + TFE_DeleteOp(send_op); + TFE_DeleteOp(recv_op); + TFE_DeleteTensorHandle(m); + + TFE_DeleteTensorHandle(recv_retvals[0]); + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index f68f8a3e90a971b5e4a024feaf26ba498afc48da..28b9f8df9c873ee394eb6a241dd9ac06ba6c8796 100644 --- 
a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -392,26 +392,26 @@ Status ProcessInputs( EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { input_tensors->reserve(ninputs); for (int i = 0; i < ninputs; ++i) { - const Node& node = inputs[i].oper->node; + Node* node = &inputs[i].oper->node; int idx = inputs[i].index; TF_RETURN_WITH_CONTEXT_IF_ERROR( - fn_body->graph.IsValidOutputTensor(&node, idx), + fn_body->graph.IsValidOutputTensor(node, idx), "Encountered while processing input ", i, " into function '", fn_name, "'"); - TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx), + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx), "Encountered while processing input ", i, " into function '", fn_name, "'"); - input_tensors->emplace_back(&node, idx); + input_tensors->emplace_back(node, idx); - const auto& iter = input_nodes->find(&node); + const auto& iter = input_nodes->find(node); if (iter == input_nodes->end()) { - input_nodes->insert({&node, {idx}}); + input_nodes->insert({node, {idx}}); } else { auto& indices = iter->second; if (std::find(indices.begin(), indices.end(), idx) != indices.end()) { - return InvalidArgument("TF_Output ", node.name(), ":", idx, + return InvalidArgument("TF_Output ", node->name(), ":", idx, " appears more than once in the input list"); } indices.push_back(idx); @@ -428,16 +428,16 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name, EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { output_tensors->reserve(noutputs); for (int i = 0; i < noutputs; ++i) { - const Node& node = outputs[i].oper->node; + Node* node = &outputs[i].oper->node; int idx = outputs[i].index; TF_RETURN_WITH_CONTEXT_IF_ERROR( - fn_body->graph.IsValidOutputTensor(&node, idx), + fn_body->graph.IsValidOutputTensor(node, idx), "Encountered while processing output ", i, " from function '", fn_name, "'"); - TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx), + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, 
idx), "Encountered while creating function '", fn_name, "'"); - output_tensors->emplace_back(&node, idx); + output_tensors->emplace_back(node, idx); } return Status::OK(); } diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 95652a11378d6276b5ba6540a07baa15aa77cc1c..5ba26d3c585350aa510f9970cbfc246a9a108543 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -25,6 +25,7 @@ limitations under the License. #include #ifndef __ANDROID__ +#include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/op_gen_lib.h" #endif #include "tensorflow/core/common_runtime/shape_refiner.h" @@ -179,6 +180,15 @@ struct TF_ApiDefMap { tensorflow::mutex lock; }; +#ifndef __ANDROID__ +struct TF_Server { + TF_Server(std::unique_ptr server); + + const tensorflow::string target; + std::unique_ptr server; +}; +#endif + namespace tensorflow { class TensorCApi { diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 03516c39dc970aa23967107d3a0446da94669465..d5934a10395ae094f65d3bc8b6cd7b94dbd32410 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" @@ -186,23 +187,40 @@ TEST(CAPI, LibraryLoadFunctions) { // tf_cuda_cc_test() bazel rule and remove the next line. if (!GPUDeviceName().empty()) return; - // Load the library. 
- TF_Status* status = TF_NewStatus(); - TF_Library* lib = - TF_LoadLibrary("tensorflow/c/test_op.so", status); - TF_Code code = TF_GetCode(status); - string status_msg(TF_Message(status)); - TF_DeleteStatus(status); - ASSERT_EQ(TF_OK, code) << status_msg; - - // Test op list. - TF_Buffer op_list_buf = TF_GetOpList(lib); - tensorflow::OpList op_list; - EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length)); - ASSERT_EQ(op_list.op_size(), 1); - EXPECT_EQ("TestCApi", op_list.op(0).name()); - - TF_DeleteLibraryHandle(lib); +#if !defined(TENSORFLOW_NO_SHARED_OBJECTS) + { + // Load the library. + TF_Status* status = TF_NewStatus(); + TF_Library* lib = + TF_LoadLibrary("tensorflow/c/test_op1.so", status); + TF_Code code = TF_GetCode(status); + string status_msg(TF_Message(status)); + TF_DeleteStatus(status); + ASSERT_EQ(TF_OK, code) << status_msg; + + // Test op list. + TF_Buffer op_list_buf = TF_GetOpList(lib); + tensorflow::OpList op_list; + EXPECT_TRUE(op_list.ParseFromArray(op_list_buf.data, op_list_buf.length)); + ASSERT_EQ(op_list.op_size(), 1); + EXPECT_EQ("TestCApi1", op_list.op(0).name()); + TF_DeleteLibraryHandle(lib); + } +#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS) + { + TF_Buffer* op_list_buffer = TF_GetAllOpList(); + tensorflow::OpList op_list; + op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length); + ASSERT_GE(op_list.op_size(), 1); + typedef tensorflow::protobuf::RepeatedPtrField OpDefs; + const OpDefs& ops = op_list.op(); + bool found = std::find_if(ops.begin(), ops.end(), + [](const tensorflow::OpDef& op_def) { + return op_def.name() == "TestCApi"; + }) != ops.end(); + EXPECT_TRUE(found); + TF_DeleteBuffer(op_list_buffer); + } } void TestEncodeDecode(int line, const std::vector& data) { @@ -2329,15 +2347,9 @@ TEST(TestApiDef, TestCreateApiDef) { // tf_cuda_cc_test() bazel rule and remove the next line. 
if (!GPUDeviceName().empty()) return; + TF_Buffer* op_list_buf = TF_GetAllOpList(); TF_Status* status = TF_NewStatus(); - TF_Library* lib = - TF_LoadLibrary("tensorflow/c/test_op.so", status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteStatus(status); - - TF_Buffer op_list_buf = TF_GetOpList(lib); - status = TF_NewStatus(); - auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status); + auto* api_def_map = TF_NewApiDefMap(op_list_buf, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); @@ -2355,7 +2367,7 @@ TEST(TestApiDef, TestCreateApiDef) { TF_DeleteBuffer(api_def_buf); TF_DeleteApiDefMap(api_def_map); - TF_DeleteLibraryHandle(lib); + TF_DeleteBuffer(op_list_buf); } TEST(TestApiDef, TestCreateApiDefWithOverwrites) { @@ -2363,15 +2375,9 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) { // tf_cuda_cc_test() bazel rule and remove the next line. if (!GPUDeviceName().empty()) return; + TF_Buffer* op_list_buf = TF_GetAllOpList(); TF_Status* status = TF_NewStatus(); - TF_Library* lib = - TF_LoadLibrary("tensorflow/c/test_op.so", status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteStatus(status); - - TF_Buffer op_list_buf = TF_GetOpList(lib); - status = TF_NewStatus(); - auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status); + auto* api_def_map = TF_NewApiDefMap(op_list_buf, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); @@ -2400,7 +2406,7 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) { TF_DeleteBuffer(api_def_buf); TF_DeleteApiDefMap(api_def_map); - TF_DeleteLibraryHandle(lib); + TF_DeleteBuffer(op_list_buf); } class DummyKernel : public tensorflow::OpKernel { diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 3ee31a6a7ac641bbd3fc4c05568b61e433a1d523..c34a84fcfee9b6ba9a7be86ae16e2856a2d343c7 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -50,6 +50,7 @@ 
tf_cuda_library( ], "//conditions:default": [], }) + [ + "@com_google_absl//absl/memory", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", @@ -69,7 +70,7 @@ tf_cuda_library( name = "c_api_internal", hdrs = ["c_api_internal.h"], visibility = [ - "//learning/deepmind/courier:__pkg__", + "//learning/deepmind/courier:__subpackages__", "//tensorflow:internal", ], deps = [ @@ -143,6 +144,7 @@ tf_cuda_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 3554ec0bf3202b54bfc38d67e51b89df19832302..027d752f420238da867cb9d8c116640e1730caaa 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -21,9 +21,11 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/core/platform/host_info.h" #ifdef TENSORFLOW_EAGER_USE_XLA #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #endif // TENSORFLOW_EAGER_USE_XLA @@ -79,7 +81,7 @@ tensorflow::Status GetAllRemoteDevices( const std::vector& remote_workers, tensorflow::WorkerCacheInterface* worker_cache, std::unique_ptr* device_mgr) { - std::vector remote_devices; + std::vector> remote_devices; tensorflow::Status status; // TODO(nareshmodi) do this in parallel instead of serially. 
for (const string& remote_worker : remote_workers) { @@ -92,7 +94,7 @@ tensorflow::Status GetAllRemoteDevices( status = s; if (s.ok()) { for (tensorflow::Device* d : *devices) { - remote_devices.push_back(d); + remote_devices.emplace_back(d); } } n.Notify(); @@ -100,7 +102,7 @@ tensorflow::Status GetAllRemoteDevices( n.WaitForNotification(); } std::unique_ptr remote_device_mgr( - new tensorflow::DeviceMgr(remote_devices)); + new tensorflow::DeviceMgr(std::move(remote_devices))); TF_RETURN_IF_ERROR(status); @@ -261,13 +263,13 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { - std::vector devices; + std::vector> devices; status->status = tensorflow::DeviceFactory::AddDevices( opts->session_options.options, "/job:localhost/replica:0/task:0", &devices); if (!status->status.ok()) return nullptr; std::unique_ptr device_mgr( - new tensorflow::DeviceMgr(devices)); + new tensorflow::DeviceMgr(std::move(devices))); tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); @@ -404,8 +406,19 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { "The passed in handle is a nullptr"); return nullptr; } - tensorflow::Device* d = nullptr; - status->status = h->handle->OpDevice(&d); + tensorflow::Device* d = h->handle->op_device(); + return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" + : d->name().c_str(); +} + +const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h, + TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } + tensorflow::Device* d = h->handle->device(); return (d == nullptr) ? 
"/job:localhost/replica:0/task:0/device:CPU:0" : d->name().c_str(); } @@ -459,13 +472,20 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { const char* name = op_or_function_name; // Shorthand const tensorflow::AttrTypeMap* types; - status->status = tensorflow::AttrTypeMapForOp(name, &types); - if (status->status.ok()) return new TFE_Op(ctx, name, types); - if (TF_GetCode(status) == TF_NOT_FOUND) { - if (ctx->context.FindFunctionByName(name)) { - status->status = tensorflow::Status::OK(); - return new TFE_Op(ctx, name, nullptr); + bool is_function = false; + status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function); + if (status->status.ok()) { + if (is_function && !ctx->context.FindFunctionByName(name)) { + status->status = tensorflow::errors::NotFound( + "'", name, + "' is neither a type of a primitive operation nor a name " + "of a function registered in binary running on ", + tensorflow::port::Hostname(), + ". Make sure the operation or function is " + "registered in the binary running in this process."); + return nullptr; } + return new TFE_Op(ctx, name, is_function, types); } return nullptr; } @@ -498,12 +518,6 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status) { TF_AttrType ret; - if (op->operation.is_function()) { - status->status = tensorflow::errors::Unimplemented( - "TODO(apassos): Support for attributes for TensorFlow functions is not " - "ready yet."); - return TF_ATTR_INT; // The compiler requires that we return something. 
- } status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(), attr_name, &ret, is_list); return ret; diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index b2454d872207e26feb3764671474a5d87c01f84d..f80ae5a6d02d4d613c95cf8486e0fc0aeed3affc 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -48,7 +48,7 @@ extern "C" { typedef struct TFE_ContextOptions TFE_ContextOptions; // Return a new options object. -TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(); +TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(void); // Set the config in TF_ContextOptions.options. // config should be a serialized tensorflow.ConfigProto proto. @@ -169,10 +169,33 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status); + +// Returns the device of the operation that produced `h`. +// If `h` was produced by a copy, returns the destination device of +// the copy. Note that returned device name is not always the device +// holding the tensor handle's memory. If you want the latter, use +// TFE_TensorHandleBackingDeviceName. +// This function will block till the operation that produces `h` has completed. +// +// Device on which the kernel of the operation that produced `h` ran. +// +// If `h` was produced by a copy, returns the destination device of +// the copy. +// +// Note that returned device name is not always the device that owns the memory +// that backs the tensor handle. For the latter see +// TFE_TensorHandleBackingDeviceName. +// // This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); +// Returns the name of the device in whose memory `h` resides. +// +// This function will block till the operation that produces `h` has completed. 
+TF_CAPI_EXPORT extern const char* TFE_TensorHandleBackingDeviceName( + TFE_TensorHandle* h, TF_Status* status); + // Return a pointer to a new TFE_TensorHandle that shares the underlying tensor // with `h`. On success, `status` is set to OK. On failure, `status` reflects // the error and a nullptr is returned. diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index 5006b76f1981d068e99a2c081115ebb3a66d8c7f..52b0824552855860dfb138f3ac9a5d3afa7dc965 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -57,13 +57,9 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( return nullptr; } - tensorflow::Device* device; - status->status = handle->handle->Device(&device); - if (!status->status.ok()) { - return nullptr; - } - #ifdef TENSORFLOW_EAGER_USE_XLA + tensorflow::Device* device = handle->handle->device(); + // If tensor resides on an XLA device, use XLA device's PaddedShapeFn. tensorflow::XlaDevice* xla_device = dynamic_cast(device); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 104d52430cf7aa14d4d2a335a1b96e667f21ce87..67bc1bcd24605f8363d6a7c8d5d6a0836a42fc82 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -79,10 +79,6 @@ struct TFE_TensorHandle { tensorflow::Device* op_device) : handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {} - TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype, - tensorflow::EagerContext* ctx) - : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {} - TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} tensorflow::TensorHandle* handle; @@ -97,10 +93,9 @@ struct TFE_TensorDebugInfo { }; struct TFE_Op { - // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a - // primitive operation. 
- TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) - : operation(&ctx->context, op, t) {} + TFE_Op(TFE_Context* ctx, const char* op, bool is_function, + const tensorflow::AttrTypeMap* t) + : operation(&ctx->context, op, is_function, t) {} tensorflow::EagerOperation operation; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 55331022b9dbd0696928fa44430f340f371432ac..6b39b79ee82f9c7baaf856e573a42b7da65691e5 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include +#include "absl/strings/match.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" @@ -589,9 +590,22 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) { TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); const int num_devices = TF_DeviceListCount(devices); + bool has_gpu0 = false; + bool has_gpu1 = false; + for (int i = 0; i < num_devices; ++i) { + const char* dev = TF_DeviceListName(devices, i, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + string device_name(dev); + if (device_name.find("GPU:0") != string::npos) { + has_gpu0 = true; + } + if (device_name.find("GPU:1") != string::npos) { + has_gpu1 = true; + } + } const char* kCPUDevice = "CPU:0"; - if (num_devices < 3) { + if (!has_gpu0 || !has_gpu1) { TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); @@ -781,6 +795,14 @@ TEST(CAPI, TensorHandleNullptr) { TF_SetStatus(status.get(), TF_OK, ""); + device_name = TFE_TensorHandleBackingDeviceName(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(device_name, nullptr); + ASSERT_EQ("The passed in handle 
is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + int num_dims = TFE_TensorHandleNumDims(h, status.get()); ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); ASSERT_EQ(num_dims, -1); @@ -796,6 +818,62 @@ TEST(CAPI, TensorHandleNullptr) { string(TF_Message(status.get()))); } +TEST(CAPI, TensorHandleDevices) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name; + const char* backing_device_name = + TFE_TensorHandleBackingDeviceName(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0")) + << backing_device_name; + + // Disable the test if no GPU is present. 
+ string gpu_device_name; + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { + TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( + hcpu, ctx, gpu_device_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_Op* shape_op = ShapeOp(ctx, hgpu); + TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + // .device of shape is GPU since the op is executed on GPU + device_name = TFE_TensorHandleDeviceName(retvals[0], status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(device_name, "GPU:0")) << device_name; + + // .backing_device of shape is CPU since the tensor is backed by CPU + backing_device_name = + TFE_TensorHandleBackingDeviceName(retvals[0], status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0")) + << backing_device_name; + + TFE_DeleteOp(shape_op); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(hgpu); + } + + TFE_DeleteTensorHandle(hcpu); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContext(ctx); +} + void Execute_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index 5607c9dcb0bbec72b2f86def3dd4e6590d73197b..bd38127d50c171af801dd1b937acefdba491b4a6 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -99,8 +99,19 @@ TFE_Op* MatMulOp(TFE_Context* ctx, 
TFE_TensorHandle* a, TFE_TensorHandle* b) { TFE_OpAddInput(op, b, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); - TFE_OpSetAttrBool(op, "transpose_a", 0); - TFE_OpSetAttrBool(op, "transpose_b", 0); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); + + return op; +} + +TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "Shape", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, a, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); return op; diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index 474cae67c89249af3a62707f0db00ba458ca8f31..75ef9459e93b4f2ed471c423a34565594efc1714 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -37,6 +37,9 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(); // Return a matmul op multiplying `a` by `b`. TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); +// Return a shape op fetching the shape of `a`. +TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a); + // Return an 1-D INT32 tensor containing a single value 1. TFE_TensorHandle* TestAxisTensorHandle(); diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 5ba55a203ff70cc64c07e96b5a869a1f11c9334e..5c11f51e8749de84547ae873f5f55ebd42bc4b3d 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -141,8 +141,9 @@ class GradientTape { // null. The result is populated with one tensor per target element. 
Status ComputeGradient( const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice source_tensor_id, + const gtl::ArraySlice target_tensor_ids, + const gtl::ArraySlice source_tensor_ids, + const gtl::FlatMap sources_that_are_targets, gtl::ArraySlice output_gradients, std::vector* result); @@ -396,6 +397,7 @@ template Status InitialGradients( const VSpace& vspace, gtl::ArraySlice target_tensor_ids, + gtl::FlatMap sources_that_are_targets, gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, const OpTape& op_tape, gtl::FlatMap>* result) { @@ -425,8 +427,13 @@ Status InitialGradients( "none of operations outputs match expected tensor"); } } else { - // No record of the target tensor found on the tape, so no gradient - // needs to be computed from it. Do nothing. + // This target tensor was not generated by any operation recorded on + // the tape, so no gradient needs to be computed from it unless this + // target is also a source. + auto source_tensor = sources_that_are_targets.find(id); + if (source_tensor != sources_that_are_targets.end()) { + (*result)[id].push_back(vspace.Ones(source_tensor->second)); + } } } else { (*result)[id].push_back(output_gradients[i]); @@ -467,8 +474,9 @@ constexpr int kMinAggregateBytes = 128 * 1024 * 1024; template Status GradientTape::ComputeGradient( const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice source_tensor_ids, + const gtl::ArraySlice target_tensor_ids, + const gtl::ArraySlice source_tensor_ids, + const gtl::FlatMap sources_that_are_targets, gtl::ArraySlice output_gradients, std::vector* result) { gtl::FlatSet sources_set(source_tensor_ids.begin(), @@ -478,7 +486,8 @@ Status GradientTape::ComputeGradient( std::vector op_stack = InitialStack(state.op_tape, state.op_missing_tensor); gtl::FlatMap> gradients; - Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, + Status s = InitialGradients(vspace, target_tensor_ids, + sources_that_are_targets, 
output_gradients, tensor_tape_, state.op_tape, &gradients); auto cleanup = [this, &state]() { if (!persistent_) { diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc new file mode 100644 index 0000000000000000000000000000000000000000..07b9e8b940c55caf62ae0b81b884bf313d335459 --- /dev/null +++ b/tensorflow/c/env.cc @@ -0,0 +1,161 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/env.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +struct TF_StringStream { + std::vector<::tensorflow::string>* list; + size_t position; +}; + +void TF_CreateDir(const char* dirname, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->CreateDir(dirname)); +} + +void TF_DeleteDir(const char* dirname, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteDir(dirname)); +} + +void TF_DeleteRecursively(const char* dirname, int64_t* undeleted_file_count, + int64_t* undeleted_dir_count, TF_Status* status) { + ::tensorflow::int64 f, d; + + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, 
::tensorflow::Env::Default()->DeleteRecursively(dirname, &f, &d)); + *undeleted_file_count = f; + *undeleted_dir_count = d; +} + +void TF_FileStat(const char* filename, TF_FileStatistics* stats, + TF_Status* status) { + ::tensorflow::FileStatistics cc_stats; + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Status s = + ::tensorflow::Env::Default()->Stat(filename, &cc_stats); + ::tensorflow::Set_TF_Status_from_Status(status, s); + if (s.ok()) { + stats->length = cc_stats.length; + stats->mtime_nsec = cc_stats.mtime_nsec; + stats->is_directory = cc_stats.is_directory; + } +} + +void TF_NewWritableFile(const char* filename, TF_WritableFileHandle** handle, + TF_Status* status) { + std::unique_ptr<::tensorflow::WritableFile> f; + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Status s = + ::tensorflow::Env::Default()->NewWritableFile(filename, &f); + ::tensorflow::Set_TF_Status_from_Status(status, s); + + if (s.ok()) { + *handle = reinterpret_cast<TF_WritableFileHandle*>(f.release()); + } +} + +void TF_CloseWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Close()); + delete cc_file; +} + +void TF_SyncWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Sync()); +} + +void TF_FlushWritableFile(TF_WritableFileHandle* handle, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, cc_file->Flush()); +} + +void TF_AppendWritableFile(TF_WritableFileHandle* handle, const char* data, + size_t length, TF_Status* status) { + auto* cc_file = reinterpret_cast<::tensorflow::WritableFile*>(handle); + TF_SetStatus(status, 
TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, cc_file->Append(::tensorflow::StringPiece{data, length})); +} + +void TF_DeleteFile(const char* filename, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->DeleteFile(filename)); +} + +bool TF_StringStreamNext(TF_StringStream* list, const char** result) { + if (list->position >= list->list->size()) { + *result = nullptr; + return false; + } + + *result = list->list->at(list->position++).c_str(); + return true; +} + +void TF_StringStreamDone(TF_StringStream* list) { + delete list->list; + delete list; +} +TF_StringStream* TF_GetChildren(const char* dirname, TF_Status* status) { + auto* children = new std::vector<::tensorflow::string>; + + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->GetChildren(dirname, children)); + + auto* list = new TF_StringStream; + list->list = children; + list->position = 0; + return list; +} + +TF_StringStream* TF_GetLocalTempDirectories() { + auto* tmpdirs = new std::vector<::tensorflow::string>; + + ::tensorflow::Env::Default()->GetLocalTempDirectories(tmpdirs); + + auto* list = new TF_StringStream; + list->list = tmpdirs; + list->position = 0; + return list; +} + +TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void) { + return ::tensorflow::Env::Default()->NowNanos(); +} + +// Returns the number of microseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void) { + return ::tensorflow::Env::Default()->NowMicros(); +} + +// Returns the number of seconds since the Unix epoch. 
+TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void) { + return ::tensorflow::Env::Default()->NowSeconds(); +} diff --git a/tensorflow/c/env.h b/tensorflow/c/env.h new file mode 100644 index 0000000000000000000000000000000000000000..9d27c5da37735042c7476b591e57486dbde33152 --- /dev/null +++ b/tensorflow/c/env.h @@ -0,0 +1,157 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_ENV_H_ +#define TENSORFLOW_C_ENV_H_ + +#include "tensorflow/c/c_api.h" + +// -------------------------------------------------------------------------- +// C API for tensorflow::Env. + +struct TF_WritableFileHandle; +struct TF_StringStream; + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TF_FileStatistics { + // The length of the file in bytes. + int64_t length; + // The last modified time in nanoseconds. + int64_t mtime_nsec; + // Whether the name refers to a directory. + bool is_directory; +} TF_FileStatistics; + +// Creates the specified directory. Typical status codes are: +// * TF_OK - successfully created the directory +// * TF_ALREADY_EXISTS - directory already exists +// * TF_PERMISSION_DENIED - dirname is not writable +TF_CAPI_EXPORT extern void TF_CreateDir(const char* dirname, TF_Status* status); + +// Deletes the specified directory. 
Typical status codes are: +// * TF_OK - successfully deleted the directory +// * TF_FAILED_PRECONDITION - the directory is not empty +TF_CAPI_EXPORT extern void TF_DeleteDir(const char* dirname, TF_Status* status); + +// Deletes the specified directory and all subdirectories and files underneath +// it. This is accomplished by traversing the directory tree rooted at dirname +// and deleting entries as they are encountered. +// +// If dirname itself is not readable or does not exist, *undeleted_dir_count is +// set to 1, *undeleted_file_count is set to 0 and an appropriate status (e.g. +// TF_NOT_FOUND) is returned. +// +// If dirname and all its descendants were successfully deleted, TF_OK is +// returned and both error counters are set to zero. +// +// Otherwise, while traversing the tree, undeleted_file_count and +// undeleted_dir_count are updated if an entry of the corresponding type could +// not be deleted. The returned error status represents the reason that any one +// of these entries could not be deleted. +// +// Typical status codes: +// * TF_OK - dirname exists and we were able to delete everything underneath +// * TF_NOT_FOUND - dirname doesn't exist +// * TF_PERMISSION_DENIED - dirname or some descendant is not writable +// * TF_UNIMPLEMENTED - some underlying functions (like Delete) are not +// implemented +TF_CAPI_EXPORT extern void TF_DeleteRecursively(const char* dirname, + int64_t* undeleted_file_count, + int64_t* undeleted_dir_count, + TF_Status* status); + +// Obtains statistics for the given path. If status is TF_OK, *stats is +// updated, otherwise it is not touched. +TF_CAPI_EXPORT extern void TF_FileStat(const char* filename, + TF_FileStatistics* stats, + TF_Status* status); + +// Creates or truncates the given filename and returns a handle to be used for +// appending data to the file. If status is TF_OK, *handle is updated and the +// caller is responsible for freeing it (see TF_CloseWritableFile). 
+TF_CAPI_EXPORT extern void TF_NewWritableFile(const char* filename, + TF_WritableFileHandle** handle, + TF_Status* status); + +// Closes the given handle and frees its memory. If there was a problem closing +// the file, it is indicated by status. Memory is freed in any case. +TF_CAPI_EXPORT extern void TF_CloseWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Syncs content of the handle to the filesystem. Blocks waiting for the +// filesystem to indicate that the content has been persisted. +TF_CAPI_EXPORT extern void TF_SyncWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Flush local buffers to the filesystem. If the process terminates after a +// successful flush, the contents may still be persisted, since the underlying +// filesystem may eventually flush the contents. If the OS or machine crashes +// after a successful flush, the contents may or may not be persisted, depending +// on the implementation. +TF_CAPI_EXPORT extern void TF_FlushWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Appends the given bytes to the file. Any failure to do so is indicated in +// status. +TF_CAPI_EXPORT extern void TF_AppendWritableFile(TF_WritableFileHandle* handle, + const char* data, + size_t length, + TF_Status* status); + +// Deletes the named file and indicates whether successful in *status. +TF_CAPI_EXPORT extern void TF_DeleteFile(const char* filename, + TF_Status* status); + +// Retrieves the next item from the given TF_StringStream and places a pointer +// to it in *result. If no more items are in the list, *result is set to NULL +// and false is returned. +// +// Ownership of the items retrieved with this function remains with the library. +// Item pointers are invalidated after a call to TF_StringStreamDone. +TF_CAPI_EXPORT extern bool TF_StringStreamNext(TF_StringStream* list, + const char** result); + +// Frees the resources associated with given string list. 
All pointers returned +// by TF_StringStreamNext are invalid after this call. +TF_CAPI_EXPORT extern void TF_StringStreamDone(TF_StringStream* list); + +// Retrieves the list of children of the given directory. You can iterate +// through the list with TF_StringStreamNext. The caller is responsible for +// freeing the list (see TF_StringStreamDone). +TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename, + TF_Status* status); + +// Retrieves a list of directory names on the local machine that may be used for +// temporary storage. You can iterate through the list with TF_StringStreamNext. +// The caller is responsible for freeing the list (see TF_StringStreamDone). +TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void); + +// Returns the number of nanoseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void); + +// Returns the number of microseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void); + +// Returns the number of seconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void); + +#ifdef __cplusplus +} +#endif + +#endif // TENSORFLOW_C_ENV_H_ diff --git a/tensorflow/c/env_test.cc b/tensorflow/c/env_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e2206c6befd2167346c64032940d6e8c631e4a3e --- /dev/null +++ b/tensorflow/c/env_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/env.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) + +TEST(TestEnv, TestDirHandling) { + TF_StringStream* tempdirs = TF_GetLocalTempDirectories(); + const char* tempdir; + bool found = false; + while (TF_StringStreamNext(tempdirs, &tempdir)) { + found = true; + + TF_Status* s = TF_NewStatus(); + + ::tensorflow::string dirpath = + ::tensorflow::io::JoinPath(tempdir, "somedir"); + TF_CreateDir(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_CreateDir failed for " << dirpath << ": " + << TF_Message(s); + + ::tensorflow::string filepath = + ::tensorflow::io::JoinPath(dirpath, "somefile.txt"); + TF_WritableFileHandle* handle; + TF_NewWritableFile(filepath.c_str(), &handle, s); + ASSERT_TF_OK(s) << "NewWritableFile failed for " << filepath << ": " + << TF_Message(s); + + const char* data = "Hello, world!\n"; + TF_AppendWritableFile(handle, data, strlen(data), s); + ASSERT_TF_OK(s) << "TF_AppendWritableFile failed to append data to file at " + << filepath << ": " << TF_Message(s); + + TF_CloseWritableFile(handle, s); + ASSERT_TF_OK(s) << "TF_CloseWritableFile failed to close handle to " + << filepath << ": " << TF_Message(s); + + TF_StringStream* children = TF_GetChildren(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_GetChildren failed for " << dirpath; + const char* childpath; + ASSERT_TRUE(TF_StringStreamNext(children, &childpath)); + ASSERT_EQ(::tensorflow::string(childpath), "somefile.txt"); + // There should only be one file in this directory. 
+ ASSERT_FALSE(TF_StringStreamNext(children, &childpath)); + ASSERT_EQ(childpath, nullptr); + TF_StringStreamDone(children); + + TF_FileStatistics stats; + TF_FileStat(filepath.c_str(), &stats, s); + ASSERT_EQ(stats.length, strlen(data)); + ASSERT_FALSE(stats.is_directory); + ASSERT_GT(stats.mtime_nsec, 0); + + // Trying to delete a non-empty directory should fail. + TF_DeleteDir(dirpath.c_str(), s); + ASSERT_NE(TF_OK, TF_GetCode(s)) + << "TF_DeleteDir unexpectedly succeeded with a non-empty directory " + << dirpath; + + TF_DeleteFile(filepath.c_str(), s); + ASSERT_TF_OK(s) << "TF_DeleteFile failed for " << filepath << ": " + << TF_Message(s); + + // Now deleting the directory should work. + TF_DeleteDir(dirpath.c_str(), s); + ASSERT_TF_OK(s) << "TF_DeleteDir failed for " << dirpath << ": " + << TF_Message(s); + + TF_DeleteStatus(s); + break; + } + + ASSERT_TRUE(found) << "expected at least one temp dir"; + + TF_StringStreamDone(tempdirs); +} + +TEST(TestEnv, TestTimeFunctions) { + ASSERT_GE(TF_NowSeconds(), 946684800); // Midnight Jan 1, 2000 + ASSERT_GE(TF_NowMicros(), 946684800 * 1e6); + ASSERT_GE(TF_NowNanos(), 946684800 * 1e9); +} diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc new file mode 100644 index 0000000000000000000000000000000000000000..2a4eaecb6cf2740a522b1e849d1306ebde6c4577 --- /dev/null +++ b/tensorflow/c/kernels.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/kernels.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +// This file forms the basis of a stable ABI for third-party kernel +// implementations. It is crucial that changes to this file are made cautiously +// and with a focus on maintaining both source and binary compatibility. + +struct TF_KernelBuilder { + ::tensorflow::KernelDefBuilder* cc_builder; + + void* (*create_function)(TF_OpKernelConstruction*); + void (*compute_function)(void*, TF_OpKernelContext*); + void (*delete_function)(void*); +}; + +TF_KernelBuilder* TF_NewKernelBuilder( + const char* op_name, const char* device_name, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_func)(void*, TF_OpKernelContext*), + void (*delete_func)(void*)) { + TF_KernelBuilder* result = new TF_KernelBuilder; + result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name); + result->cc_builder->Device(device_name); + result->create_function = create_func; + result->compute_function = compute_func; + result->delete_function = delete_func; + return result; +} + +void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) { + DCHECK_NE(builder, nullptr); + delete builder->cc_builder; + delete builder; +} + +namespace tensorflow { +namespace { + +// An OpKernel whose methods delegate to C function pointers. 
+class COpKernel : public OpKernel { + public: + explicit COpKernel(OpKernelConstruction* ctx, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_func)(void*, TF_OpKernelContext*), + void (*delete_func)(void*)) + : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) { + if (create_func != nullptr) { + c_kernel_ = + (*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx)); + } else { + c_kernel_ = nullptr; + } + } + + void Compute(OpKernelContext* ctx) override { + (*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx)); + } + + ~COpKernel() override { + if (delete_func_ != nullptr) { + (*delete_func_)(c_kernel_); + } + } + + private: + void (*compute_func_)(void*, TF_OpKernelContext* context); + void (*delete_func_)(void*); + void* c_kernel_; +}; + +// A KernelFactory that returns COpKernel instances. +class KernelBuilderFactory + : public ::tensorflow::kernel_factory::OpKernelFactory { + public: + explicit KernelBuilderFactory(TF_KernelBuilder* builder) + : builder_(builder) {} + ::tensorflow::OpKernel* Create( + ::tensorflow::OpKernelConstruction* context) override { + return new ::tensorflow::COpKernel(context, builder_->create_function, + builder_->compute_function, + builder_->delete_function); + } + ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); } + + private: + TF_KernelBuilder* builder_; +}; +} // namespace +} // namespace tensorflow + +void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder, + TF_Status* status) { + using tensorflow::register_kernel::Name; + + tensorflow::kernel_factory::OpKernelRegistrar( + builder->cc_builder->Build(), name, + absl::make_unique<tensorflow::KernelBuilderFactory>(builder)); + + TF_SetStatus(status, TF_OK, ""); +} + +int TF_NumInputs(TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return cc_ctx->num_inputs(); +} + +int TF_NumOutputs(TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + return 
cc_ctx->num_outputs(); +} + +void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + if (i < 0 || i >= cc_ctx->num_inputs()) { + TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range"); + return; + } + const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i)); + TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status); + if (TF_GetCode(status) == TF_OK) { + *tensor = result; + } +} + +void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + if (i < 0 || i >= cc_ctx->num_outputs()) { + TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range"); + return; + } + ::tensorflow::Tensor cc_tensor; + ::tensorflow::Status s = ::tensorflow::TF_TensorToTensor(tensor, &cc_tensor); + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status(status, s); + if (s.ok()) { + cc_ctx->set_output(i, cc_tensor); + } +} diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..1a91aa184f11ac8e45b38a1d106c7b445747a7c1 --- /dev/null +++ b/tensorflow/c/kernels.h @@ -0,0 +1,118 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_KERNELS_H_ +#define TENSORFLOW_C_KERNELS_H_ + +#include "tensorflow/c/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// -------------------------------------------------------------------------- +// C API for TensorFlow Kernels. +// +// This API allows developers to register custom kernel implementations for +// TensorFlow. +// +// See c_api.h header comments for a discussion about API conventions. +// +// Users wishing to extend TensorFlow with new kernels will call +// `TF_NewKernelBuilder`. The resulting kernel builder can be registered with +// `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided +// kernels when necessary. + +struct TF_KernelBuilder; +struct TF_OpKernelConstruction; +struct TF_OpKernelContext; + +// Allocates a new kernel builder and returns a pointer to it. +// +// If non-null, TensorFlow will call create_func when it needs to instantiate +// the kernel. The pointer returned by create_func will be passed to +// compute_func and delete_func, thereby functioning as a "this" pointer for +// referring to kernel instances. +// +// The TF_OpKernelConstruction pointer passed to create_func is owned by +// TensorFlow and will be deleted once create_func returns. It must not be used +// after this. +// +// When TensorFlow needs to perform a computation with this kernel, it will +// call compute_func. This function will receive the pointer returned by +// create_func (or null if no create_func was provided), along with the inputs +// to the computation. +// +// The TF_OpKernelContext pointer received by compute_func is owned by +// TensorFlow and will be deleted once compute_func returns. It must not be used +// after this. +// +// Finally, when TensorFlow no longer needs the kernel, it will call +// delete_func if one is provided. 
This function will receive the pointer +// returned in `create_func` or nullptr if no `create_func` was provided. +// +// The caller should pass the result of this function to +// TF_RegisterKernelBuilder, which will take ownership of the pointer. If, for +// some reason, the kernel builder will not be registered, the caller should +// delete it with TF_DeleteKernelBuilder. +TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewKernelBuilder( + const char* op_name, const char* device_name, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_func)(void*, TF_OpKernelContext*), + void (*delete_func)(void*)); + +// Register the given kernel builder with the TensorFlow runtime. If +// registration fails, the given status will be populated. +// +// This call takes ownership of the `builder` pointer. +TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name, + TF_KernelBuilder* builder, + TF_Status* status); + +// Deletes the given TF_KernelBuilder. This should be called only if the kernel +// builder is not registered with TensorFlow via TF_RegisterKernelBuilder. +TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder); + +// -------------------------------------------------------------------------- +// OpKernelContext routines + +// TF_NumInputs returns the number of inputs available in ctx. +TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx); + +// TF_NumOutputs returns the number of outputs to be placed in *ctx by the +// kernel. +TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx); + +// Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is +// populated and its ownership is passed to the caller. In any other case, +// *tensor is not modified. +// +// If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE. 
+TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i, + TF_Tensor** tensor, TF_Status* status); + +// Sets the ith output of ctx to tensor. If TF_GetCode(status) is anything but +// TF_OK, ctx is left unmodified. +// +// If i < 0 or i >= TF_NumOutputs(ctx), *status is set to TF_OUT_OF_RANGE. +TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i, + const TF_Tensor* tensor, + TF_Status* status); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_KERNELS_H_ diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e659ee3c3d258a626ccf03a782ec031b5a703a48 --- /dev/null +++ b/tensorflow/c/kernels_test.cc @@ -0,0 +1,203 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/kernels.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/node_def.pb_text.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +struct MyCustomKernel { + bool created; + bool compute_called; +}; + +static bool delete_called = false; + +static void* MyCreateFunc(TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + return s; +} + +static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) { + struct MyCustomKernel* s = static_cast(kernel); + s->compute_called = true; +} + +static void MyDeleteFunc(void* kernel) { + struct MyCustomKernel* s = static_cast(kernel); + EXPECT_TRUE(s->created); + EXPECT_TRUE(s->compute_called); + delete_called = true; + delete s; +} + +namespace tensorflow { + +static std::unique_ptr GetFakeKernel(const char* device_name, + const char* op_name, + Status* status) { + NodeDef def; + def.set_op(op_name); + def.set_device(device_name); + def.add_input("input1"); + def.add_input("input2"); + return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1, + status); +} + +// Tests registration of a single C kernel and checks that calls through the +// C/C++ boundary are being made. 
+TEST(TestKernel, TestRegisterKernelBuilder) { + const char* kernel_name = "SomeKernelName"; + const char* op_name = "FooOp"; + const char* device_name = "FakeDeviceName1"; + + REGISTER_OP(op_name) + .Input("input1: double") + .Input("input2: uint8") + .Output("output1: uint8"); + + TF_KernelBuilder* builder = TF_NewKernelBuilder( + op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc); + + { + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder(kernel_name, builder, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + KernelList list; + list.ParseFromArray(buf->data, buf->length); + ASSERT_EQ(1, list.kernel_size()); + ASSERT_EQ(device_name, list.kernel(0).device_type()); + TF_DeleteBuffer(buf); + TF_DeleteStatus(status); + } + + { + Status status; + std::unique_ptr kernel = + GetFakeKernel(device_name, op_name, &status); + TF_EXPECT_OK(status); + ASSERT_NE(nullptr, kernel.get()); + kernel->Compute(nullptr); + } + + ASSERT_TRUE(delete_called); +} + +class DummyDevice : public DeviceBase { + public: + DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {} + bool RequiresRecordingAccessedTensors() const override { return save_; } + Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { + return cpu_allocator(); + } + + private: + bool save_; +}; + +TEST(TestKernel, TestInputAndOutputCount) { + const char* kernel_name = "InputOutputCounterKernel"; + const char* op_name = "BarOp"; + const char* device_name = "FakeDeviceName2"; + + REGISTER_OP(op_name) + .Input("input1: double") + .Input("input2: uint8") + .Output("output1: uint8"); + + static int num_inputs = 0; + static int num_outputs = 0; + + // A kernel whose Compute function has a side-effect of updating num_inputs + // and num_outputs. Various functions on TF_OpKernelContext are also + // exercised. 
+ auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + num_inputs = TF_NumInputs(ctx); + num_outputs = TF_NumOutputs(ctx); + + TF_Tensor* input = nullptr; + TF_Status* s = TF_NewStatus(); + TF_GetInput(ctx, 0, &input, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << "Failed to get input: " << TF_Message(s); + EXPECT_EQ(123, *static_cast(TF_TensorData(input))); + TF_GetInput(ctx, -1, &input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + TF_GetInput(ctx, 3, &input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + + // Copy the input tensor to output. + TF_SetOutput(ctx, 0, input, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + + TF_SetOutput(ctx, 24, input, s); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)); + + TF_DeleteStatus(s); + if (input != nullptr) { + TF_DeleteTensor(input); + } + }; + + TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr, + my_compute_func, nullptr); + + { + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder(kernel_name, builder, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteStatus(status); + } + + { + OpKernelContext::Params p; + DummyDevice dummy_device(nullptr, false); + p.device = &dummy_device; + + Tensor t(tensorflow::uint8(123)); + + gtl::InlinedVector inputs; + // Simulate 2 inputs + inputs.emplace_back(&t); + inputs.emplace_back(); + p.inputs = &inputs; + + Status status; + std::unique_ptr kernel = + GetFakeKernel(device_name, op_name, &status); + TF_EXPECT_OK(status); + ASSERT_NE(nullptr, kernel.get()); + + p.op_kernel = kernel.get(); + OpKernelContext ctx(&p); + kernel->Compute(&ctx); + + ASSERT_EQ(2, num_inputs); + ASSERT_EQ(1, num_outputs); + ASSERT_EQ(123, ctx.mutable_output(0)->scalar()()); + } +} + +} // namespace tensorflow diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 247236b760dd8c07bbb08426100b6a4d34296d2e..98d8393332269ae349cf8aa5c0b612c6f17172e6 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -160,4 +160,17 
@@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); } +void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, + TF_Status* status) { + mutex_lock l(graph->mu); + status->status = graph->graph.AddWhileInputHack(&new_src.oper->node, + new_src.index, &dst->node); + if (status->status.ok()) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + RecordMutation(graph, *dst, "adding input tensor"); + } +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index 5cce84020bc68d912d259f51512341eb5f464a2c..44779ca656165dd65590cb5e9ea3ccf71165ed63 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -34,6 +34,7 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); +// Updates 'dst' to consume 'new_src'. void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status); @@ -65,6 +66,13 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output); // because I couldn't get SWIG to work otherwise. void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, size_t proto_len, TF_Status* status); + +// This method is used to add a new input edge to 'dst', which must be a While +// op. The While op's "T" attribute must have already been updated to include +// the new edge. This is used to construct tf.while_loop gradients. 
+void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, + TF_Status* status); + } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/c/test_op1.cc b/tensorflow/c/test_op1.cc new file mode 100644 index 0000000000000000000000000000000000000000..b22cc9aef2b344282f45340ff12ee849935a26f9 --- /dev/null +++ b/tensorflow/c/test_op1.cc @@ -0,0 +1,23 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +REGISTER_OP("TestCApi1").Doc(R"doc(Used to test C API)doc"); + +} // namespace tensorflow diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index b587e63227708427e7fae47f8f4a7b524d963ed9..a09becc49b10d2c58f98fbcc11df5190f794c1d4 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -170,6 +170,7 @@ cc_library_with_android_deps( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -411,6 +412,7 @@ tf_cc_test( srcs = ["gradients/nn_grad_test.cc"], deps = [ ":cc_ops", + ":cc_ops_internal", ":grad_op_registry", ":grad_testutil", ":gradient_checker", @@ -453,11 +455,33 @@ tf_cc_test( ], ) +# Generates separate libraries for array_ops and math_ops to reduce the dependency count of targets that 
depend on only these tf_gen_op_wrappers_cc( - name = "cc_ops", + name = "math_ops", + api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], + op_lib_names = [ + "math_ops", + ], + pkg = "//tensorflow/core", +) + +tf_gen_op_wrappers_cc( + name = "array_ops", api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], op_lib_names = [ "array_ops", + ], + pkg = "//tensorflow/core", +) + +tf_gen_op_wrappers_cc( + name = "cc_ops", + api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], + deps_internal = [ + ":array_ops_internal", + ":math_ops_internal", + ], + op_lib_names = [ "audio_ops", "candidate_sampling_ops", "control_flow_ops", @@ -465,10 +489,10 @@ tf_gen_op_wrappers_cc( "image_ops", "io_ops", "linalg_ops", + "list_ops", "logging_ops", "lookup_ops", "manip_ops", - "math_ops", "nn_ops", "no_op", "parsing_ops", @@ -480,10 +504,23 @@ tf_gen_op_wrappers_cc( "user_ops", ], other_hdrs = [ + "ops/array_ops.h", "ops/const_op.h", + "ops/math_ops.h", "ops/standard_ops.h", ], + other_hdrs_internal = [ + "ops/array_ops_internal.h", + "ops/math_ops_internal.h", + ], pkg = "//tensorflow/core", + deps = [ + ":array_ops", + ":const_op", + ":math_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + ], ) tf_cc_test( diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 6abc9e268e3ac97379954a34017ddffa010db67f..81785b2d89b3d36b46992b7ae376b5175a806027 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -95,6 +95,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name, kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), assigned_device_(other.impl()->assigned_device_), + xla_cluster_(other.impl()->xla_cluster_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -112,6 +113,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name, 
kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), assigned_device_(other.impl()->assigned_device_), + xla_cluster_(other.impl()->xla_cluster_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -135,6 +137,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps, kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), assigned_device_(other.impl()->assigned_device_), + xla_cluster_(other.impl()->xla_cluster_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -167,6 +170,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope, kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), assigned_device_(other.impl()->assigned_device_), + xla_cluster_(other.impl()->xla_cluster_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -183,6 +187,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError) kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), assigned_device_(other.impl()->assigned_device_), + xla_cluster_(other.impl()->xla_cluster_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -200,6 +205,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel, kernel_label_(kernel_label), device_(other.impl()->device_), assigned_device_(other.impl()->assigned_device_), + xla_cluster_(other.impl()->xla_cluster_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -217,6 +223,7 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate, kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), 
assigned_device_(other.impl()->assigned_device_), + xla_cluster_(other.impl()->xla_cluster_), colocation_constraints_( clear_colocations ? std::unordered_set() @@ -237,6 +244,25 @@ Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice, kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), assigned_device_(assigned_device), + xla_cluster_(other.impl()->xla_cluster_), + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} + +Scope::Impl::Impl(const Scope& other, Tags::XlaCluster, + const string& xla_cluster) + : graph_(other.impl()->graph_), + status_(other.impl()->status_), + name_map_(other.impl()->name_map_), + refiner_(other.impl()->refiner_), + scope_used_(other.impl()->scope_used_), + control_deps_(other.impl()->control_deps_), + name_(other.impl()->name_), + op_name_(other.impl()->op_name_), + exit_on_error_(other.impl()->exit_on_error_), + kernel_label_(other.impl()->kernel_label_), + device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), + xla_cluster_(xla_cluster), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -326,6 +352,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const { if (!impl()->assigned_device_.empty()) { builder->AssignedDevice(impl()->assigned_device_); } + if (!impl()->xla_cluster_.empty()) { + builder->XlaCluster(impl()->xla_cluster_); + } } string Scope::Impl::GetUniqueName(const string& prefix, @@ -388,7 +417,7 @@ Scope Scope::NewSubScope(const string& child_scope_name) const { false /* copy_names */)); } -Scope Scope::WithOpName(const string& op_name) const { +Scope Scope::WithOpNameImpl(const string& op_name) const { if (impl()->single_use_scope()) { UpdateStatus(errors::InvalidArgument("Cannot set op name ", op_name, " on this scope")); @@ -425,6 +454,10 @@ Scope Scope::WithAssignedDevice(const string& 
assigned_device) const { return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device)); } +Scope Scope::WithXlaCluster(const string& xla_cluster) const { + return Scope(new Impl(*this, Impl::Tags::XlaCluster(), xla_cluster)); +} + Scope Scope::ColocateWith(const Operation& op) const { return Scope(new Impl(*this, Impl::Tags::Colocate(), op, /* clear_colocations */ false)); diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index e307d8989b6647dfac8d2691ed2171c86b7f3a7c..0a75f23725c143e6b22ee6dffae1428ed8209fe8 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -69,8 +70,9 @@ struct CompositeOpScopes; /// // W will be named "linear/W" /// auto W = Variable(linear.WithOpName("W"), /// {2, 2}, DT_FLOAT); -/// // b will be named "linear/b" -/// auto b = Variable(linear.WithOpName("b"), +/// // b will be named "linear/b_3" +/// int idx = 3; +/// auto b = Variable(linear.WithOpName("b_", idx), /// {2}, DT_FLOAT); /// auto x = Const(linear, {...}); // name: "linear/Const" /// auto m = MatMul(linear, x, W); // name: "linear/MatMul" @@ -113,8 +115,11 @@ class Scope { Scope NewSubScope(const string& child_scope_name) const; /// Return a new scope. All ops created within the returned scope will have - /// names of the form `name/op_name[_suffix]`. - Scope WithOpName(const string& op_name) const; + /// names of the form `name/StrCat(fragments...)[_suffix]` + template + Scope WithOpName(Ty... fragments) const { + return WithOpNameImpl(absl::StrCat(fragments...)); + } /// Return a new scope. 
All ops created within the returned scope will have as /// control dependencies the union of operations in the control_deps vector @@ -137,6 +142,10 @@ class Scope { /// their assigned device set to `assigned_device`. Scope WithAssignedDevice(const string& assigned_device) const; + /// Returns a new scope. All ops created within the returned scope will have + /// their _XlaCluster attribute set to `xla_cluster`. + Scope WithXlaCluster(const string& xla_cluster) const; + /// Return a new scope. All ops created within the returned scope will be /// co-located on the device where op is placed. /// NOTE: This function is intended to be use internal libraries only for @@ -227,6 +236,8 @@ class Scope { // END_SKIP_DOXYGEN private: + Scope WithOpNameImpl(const string& op_name) const; + friend class InternalScope; std::unique_ptr impl_; explicit Scope(Impl*); diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 514e02e84146b6d95147d83182e5d9a07509cfa1..5db7eab2b819c2c5d8fc358953d4607848f1cba5 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -61,6 +61,7 @@ class Scope::Impl { enum class KernelLabel; enum class Colocate; enum class AssignedDevice; + enum class XlaCluster; }; Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner, @@ -78,6 +79,7 @@ class Scope::Impl { Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op, bool clear_colocations); Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device); + Impl(const Scope& other, Tags::XlaCluster, const string& xla_cluster); std::unordered_set GetColocationConstraints( const Operation& colocate_with_op) const; @@ -112,6 +114,7 @@ class Scope::Impl { const string kernel_label_ = ""; const string device_ = ""; const string assigned_device_ = ""; + const string xla_cluster_ = ""; const std::unordered_set colocation_constraints_; // If true, 
Scope::DoShapeInference() always returns Status:OK(). diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 588e96cb196189780037f66266484962ba0385e4..2a32a2ed6f7862a29f4ce3d1aba5fdbc86adc670 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -143,6 +143,33 @@ Status Relu6GradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper); +Status LeakyReluGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + float alpha; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); + internal::LeakyReluGrad::Attrs attrs; + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0), + attrs.Alpha(alpha)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper); + +Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + float alpha; + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha)); + internal::LeakyReluGrad::Attrs attrs; + auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1), + attrs.Alpha(alpha)); + grad_outputs->push_back(dx); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("LeakyReluGrad", LeakyReluGradGradHelper); + Status EluGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index aa72cf7ba2a958f54d50b59f0edaefb27edf0e86..f5a09e09dcda3e06c71d44d5fa5a1b121a9ade58 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/cc/framework/gradient_checker.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/gradients/grad_testutil.h" +#include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -160,6 +161,32 @@ TEST_F(NNGradTest, Relu6Grad) { RunTest(x, x_init_value, y, shape); } +TEST_F(NNGradTest, LeakyReluGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = ops::internal::LeakyRelu(scope_, x); + // Avoid input values where Leaky ReLU gradient is not well defined (around + // zero). + Tensor x_init_value = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + RunTest(x, x_init_value, y, shape); +} + +TEST_F(NNGradTest, LeakyReluGradGrad) { + TensorShape shape({5, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + // Avoid input values where Leaky ReLU gradient is not well defined (around + // zero). 
+ Tensor x_init_value = test::AsTensor( + {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, {5, 2}); + Tensor features = test::AsTensor( + {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f}, + {5, 2}); + auto y = ops::internal::LeakyReluGrad(scope_, x, features); + RunTest(x, x_init_value, y, shape); +} + TEST_F(NNGradTest, EluGrad) { TensorShape shape({5, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 3d3895c8fa82c3c0e2974228e9cad767d0e00df4..52345a376cc29ee47ccb9888c9bb26292468b5a9 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -133,5 +133,6 @@ filegroup( "testdata/half_plus_two_pbtxt/**", "testdata/half_plus_two_main_op/**", "testdata/half_plus_two/**", + "testdata/half_plus_two_v2/**", ]), ) diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h index 645a3f101d1ae7dda88ec4ca622c694dc5a7a919..6f00dc324bd7054b28de2c35023581e1666bfa01 100644 --- a/tensorflow/cc/saved_model/constants.h +++ b/tensorflow/cc/saved_model/constants.h @@ -33,10 +33,10 @@ constexpr char kSavedModelFilenamePb[] = "saved_model.pb"; /// SavedModel text format proto filename. constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt"; -/// SavedModel legacy init op key. +/// SavedModel legacy init op collection key. Used in v1 SavedModels. constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op"; -/// SavedModel main op key. +/// SavedModel main op collection key. Used in v1 SavedModels. constexpr char kSavedModelMainOpKey[] = "saved_model_main_op"; /// Directory in which to save the SavedModel variables. @@ -45,6 +45,11 @@ constexpr char kSavedModelVariablesDirectory[] = "variables"; /// SavedModel variables filename. constexpr char kSavedModelVariablesFilename[] = "variables"; +/// SavedModel SignatureDef keys for the initialization and train ops. 
Used in +/// V2 SavedModels. +constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op"; +constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op"; + } // namespace tensorflow #endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index c6abe2f41b9b5ec2faee6f65b429ff606f8ac08e..85d3dd01fa51b3c3ba6fcbf5faac03f1ff5630e2 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -122,34 +122,54 @@ Status RunOnce(const RunOptions& run_options, return run_status; } -bool HasMainOp(const MetaGraphDef& meta_graph_def) { +// RunInitOp will return OK if the initialization op was run successfully. +// An empty init_op_name indicates that there are no init ops to run. +Status RunInitOp(const RunOptions& run_options, const string& export_dir, + const MetaGraphDef& meta_graph_def, + const std::vector& asset_file_defs, + Session* session, const string& init_op_name) { + if (!init_op_name.empty()) { + LOG(INFO) << "Running initialization op on SavedModel bundle."; + std::vector> inputs; + AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); + RunMetadata run_metadata; + return RunOnce(run_options, inputs, {}, {init_op_name}, + nullptr /* outputs */, &run_metadata, session); + } + return Status::OK(); +} + +// A SavedModel may store the name of the initialization op to run in the +// in the SignatureDef (v2) or a collection (v1). If an init_op collection +// exists, then the collection must contain exactly one op. 
+Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def, + string* init_op_name) { + const auto& sig_def_map = meta_graph_def.signature_def(); + const auto& init_op_sig_it = + meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey); + if (init_op_sig_it != sig_def_map.end()) { + *init_op_name = init_op_sig_it->second.outputs() + .find(kSavedModelInitOpSignatureKey) + ->second.name(); + return Status::OK(); + } + const auto& collection_def_map = meta_graph_def.collection_def(); + string init_op_collection_key; if (collection_def_map.find(kSavedModelMainOpKey) != collection_def_map.end()) { - return true; + init_op_collection_key = kSavedModelMainOpKey; + } else { + init_op_collection_key = kSavedModelLegacyInitOpKey; } - return false; -} -Status RunMainOp(const RunOptions& run_options, const string& export_dir, - const MetaGraphDef& meta_graph_def, - const std::vector& asset_file_defs, - Session* session, const string& main_op_key) { - LOG(INFO) << "Running MainOp with key " << main_op_key - << " on SavedModel bundle."; - const auto& collection_def_map = meta_graph_def.collection_def(); - const auto main_op_it = collection_def_map.find(main_op_key); - if (main_op_it != collection_def_map.end()) { - if (main_op_it->second.node_list().value_size() != 1) { + const auto init_op_it = collection_def_map.find(init_op_collection_key); + if (init_op_it != collection_def_map.end()) { + if (init_op_it->second.node_list().value_size() != 1) { return errors::FailedPrecondition( strings::StrCat("Expected exactly one main op in : ", export_dir)); } - std::vector> inputs; - AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); - RunMetadata run_metadata; - const StringPiece main_op_name = main_op_it->second.node_list().value(0); - return RunOnce(run_options, inputs, {}, {string(main_op_name)}, - nullptr /* outputs */, &run_metadata, session); + *init_op_name = init_op_it->second.node_list().value(0); } return Status::OK(); } @@ -193,6 
+213,15 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, std::vector* asset_file_defs) { + // With SavedModel v2, we write asset file def into metagraph instead of + // collection, so read from metagraph first. + if (meta_graph_def.asset_file_def_size() > 0) { + for (const auto& asset : meta_graph_def.asset_file_def()) { + asset_file_defs->push_back(asset); + } + return Status::OK(); + } + // Fall back to read from collection to be backward compatible with v1. const auto& collection_def_map = meta_graph_def.collection_def(); const auto assets_it = collection_def_map.find(kSavedModelAssetsKey); if (assets_it == collection_def_map.end()) { @@ -227,15 +256,12 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, bundle->meta_graph_def.saver_def().restore_op_name(), bundle->meta_graph_def.saver_def().filename_tensor_name(), asset_file_defs, bundle->session.get())); - if (HasMainOp(bundle->meta_graph_def)) { - TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir, - bundle->meta_graph_def, asset_file_defs, - bundle->session.get(), kSavedModelMainOpKey)); - } else { - TF_RETURN_IF_ERROR(RunMainOp( - run_options, export_dir, bundle->meta_graph_def, asset_file_defs, - bundle->session.get(), kSavedModelLegacyInitOpKey)); - } + string init_op_name; + TF_RETURN_IF_ERROR( + GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name)); + TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def, + asset_file_defs, bundle->session.get(), + init_op_name)); return Status::OK(); } diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 72b8bc18710b0ee77cb01ed3ad0c2abb5183efb2..597e42bb65ab5536664089f7e65ec52d77fc8f23 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -36,6 +36,8 @@ constexpr char kTestDataMainOp[] = 
"cc/saved_model/testdata/half_plus_two_main_op/00000123"; constexpr char kTestDataSharded[] = "cc/saved_model/testdata/half_plus_two/00000123"; +constexpr char kTestDataInitOpV2[] = + "cc/saved_model/testdata/half_plus_two_v2/00000123"; class LoaderTest : public ::testing::Test { protected: @@ -227,5 +229,17 @@ TEST_F(LoaderTest, MaybeSavedModelDirectory) { EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir)); } +TEST_F(LoaderTest, SavedModelInitOpV2Format) { + SavedModelBundle bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataInitOpV2); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + CheckSavedModelBundle(export_dir, bundle); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt new file mode 100644 index 0000000000000000000000000000000000000000..f9ff036688007836524129e23f5cf82edd1e8910 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt @@ -0,0 +1 @@ +asset-file-contents \ No newline at end of file diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb new file mode 100644 index 0000000000000000000000000000000000000000..a10bbf8fb6bca0fcee6414b2927d2f706de85ebc Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb differ diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001 new file mode 100644 index 
0000000000000000000000000000000000000000..15b75d6ef6bffc336d138d923badb3928b8c4c13 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index new file mode 100644 index 0000000000000000000000000000000000000000..7ec9fb4fe2dd21d0a6c324aecd7658fc37cf2326 Binary files /dev/null and b/tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index differ diff --git a/tensorflow/compat_template_v1.__init__.py b/tensorflow/compat_template_v1.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7df80ec01245a7fe820c79d5879458c4cd0a93cb --- /dev/null +++ b/tensorflow/compat_template_v1.__init__.py @@ -0,0 +1,34 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Bring in all of the public TensorFlow interface into this module.""" + +from __future__ import absolute_import as _absolute_import +from __future__ import division as _division +from __future__ import print_function as _print_function + +import os as _os + +# pylint: disable=g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import + +from tensorflow.python.tools import component_api_helper as _component_api_helper +_component_api_helper.package_hook( + parent_package_str=__name__, + child_package_str=('tensorflow_estimator.python.estimator.api.estimator')) + +# API IMPORTS PLACEHOLDER + +from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +app.flags = flags # pylint: disable=undefined-variable diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 6c29f09cde7ee17c11cb44ce48d8e9128daae4d0..16151e77737429f4fbf690fc34b12a70bacebdc4 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -93,7 +93,7 @@ cc_library( ":tfcompile_lib", "//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_util", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index b17bc658fa06b9feb7edb292bd89ef31e6309169..ab1c1be344e2257721507543bc7647d4ff4becb2 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -164,7 +164,8 @@ string RewriteWithName(const string& name, string code, } // Generate methods for args (inputs). 
-Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, +Status GenArgMethods(const tf2xla::Config& config, + const xla::ProgramShapeProto& ps, const CompileResult& compile_result, string* methods) { size_t num_args = ps.parameters_size(); if (config.feed_size() != num_args) { @@ -174,9 +175,10 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, } for (int i = 0; i < num_args; ++i) { std::vector> rewrites; - TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites)); + TF_RETURN_IF_ERROR( + AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites)); const string code = R"( - void set_arg{{NAME}}_data(void* data) { + void set_arg{{NAME}}_data(const void* data) { set_arg_data({{I}}, data); } {{TYPE}}* arg{{NAME}}_data() { @@ -204,7 +206,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, // Generate methods for results (outputs). Status GenResultMethods(const tf2xla::Config& config, - const xla::ProgramShape& ps, string* methods) { + const xla::ProgramShapeProto& ps, string* methods) { if (ps.result().element_type() != xla::TUPLE) { // The XlaCompiler we use to build the xla computation always generates a // tuple result, and we rely on this to simplify code generation. 
@@ -217,8 +219,8 @@ Status GenResultMethods(const tf2xla::Config& config, } for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) { std::vector> rewrites; - TF_RETURN_IF_ERROR( - AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites)); + TF_RETURN_IF_ERROR(AddRewritesForShape( + i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites)); string code = R"( {{TYPE}}* result{{NAME}}_data() { return static_cast<{{TYPE}}*>(result_data({{I}})); @@ -336,7 +338,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, ExtractEntryParamBufferInfos(buffer_infos); std::vector buffer_infos_for_temps = ExtractTempBufferInfos(buffer_infos); - const xla::ProgramShape& ps = compile_result.program_shape; + const xla::ProgramShapeProto& ps = compile_result.program_shape; string methods_arg, methods_result; TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); @@ -548,8 +550,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static const char** StaticResultNames() {{RESULT_NAMES_CODE}} // Shape of the args and results. 
- static const xla::ProgramShape* StaticProgramShape() { - static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}}; + static const xla::ProgramShapeProto* StaticProgramShape() { + static const xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}}; return kShape; } @@ -587,7 +589,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{METHODS_RESULT}}\n", methods_result}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, - {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, + {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))}, {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", metadata_result.program_shape_access_shim}, {"{{RESULT_INDEX}}", absl::StrCat(result_index)}, @@ -615,11 +617,11 @@ static string CreateUniqueIdentifier(const CodegenOpts& opts, Status GenerateMetadata(const CodegenOpts& opts, const CompileResult& compile_result, MetadataResult* metadata_result) { - std::unique_ptr program_shape; + std::unique_ptr program_shape; if (opts.gen_program_shape) { program_shape = - absl::make_unique(compile_result.program_shape); + absl::make_unique(compile_result.program_shape); // The parameter names are currently meaningless, and redundant with the // rest of our metadata, so clear them out to avoid confusion and save @@ -631,8 +633,8 @@ Status GenerateMetadata(const CodegenOpts& opts, // a shim that evaluates to nullptr, which is what we want. 
ProtobufToEmbed program_shape_protobuf{ - CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape", - program_shape.get()}; + CreateUniqueIdentifier(opts, "ProgramShapeProto"), + "xla::ProgramShapeProto", program_shape.get()}; ProtobufToEmbed hlo_profile_printer_data_protobuf{ CreateUniqueIdentifier(opts, "HloProfilePrinterData"), diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 90410c46a8e36e44454f1219ad76d0fb0937070d..9485e86b10e225a3c9c12eafd9905bdf7c15c9fa 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -57,7 +57,7 @@ struct MetadataResult { std::vector header_variable_decls; // program_shape_access_shim is a C++ expression that constructs the - // xla::ProgramShape instance for the CompileResult passed to + // xla::ProgramShapeProto instance for the CompileResult passed to // GenerateMetadata. string program_shape_access_shim; diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index bb288d23000527be74f01630d20bbf82e50007ce..c1788ca32a1d099284eeb870f9513891051fd29e 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -181,13 +181,15 @@ TEST(CodegenTest, Golden) { BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1), BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)}, 5, {})); - compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( - { - xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), - xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), - }, - xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})); + compile_result.program_shape = + xla::ShapeUtil::MakeProgramShape( + { + xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), + xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), + }, + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})) + .ToProto(); compile_result.entry_point = "entry_point"; 
compile_result.pointer_size = 8; diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index e4d8a02877c75fa72c5747650ab9c7ac229955b3..968afad65ed6d4b5510687df484b7ce6743f6a85 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -22,7 +22,7 @@ extern "C" void entry_point( void* result, const xla::ExecutableRunOptions* run_options, const void** args, void** temps, tensorflow::int64* profile_counters); -extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[]; +extern "C" char __tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[]; namespace foo { @@ -114,7 +114,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // with dim indices specifying which value. No bounds checking is performed // on dim indices. - void set_arg0_data(void* data) { + void set_arg0_data(const void* data) { set_arg_data(0, data); } float* arg0_data() { @@ -132,7 +132,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { arg_data(0)))[dim0][dim1]; } - void set_arg_myfeed_data(void* data) { + void set_arg_myfeed_data(const void* data) { set_arg_data(0, data); } float* arg_myfeed_data() { @@ -150,7 +150,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { arg_data(0)))[dim0][dim1]; } - void set_arg1_data(void* data) { + void set_arg1_data(const void* data) { set_arg_data(1, data); } tensorflow::int64* arg1_data() { @@ -253,10 +253,10 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { } // Shape of the args and results. 
- static const xla::ProgramShape* StaticProgramShape() { - static const xla::ProgramShape* kShape = []() { - xla::ProgramShape* proto = new xla::ProgramShape; - proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[0], 52); + static const xla::ProgramShapeProto* StaticProgramShape() { + static const xla::ProgramShapeProto* kShape = []() { + xla::ProgramShapeProto* proto = new xla::ProgramShapeProto; + proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52); return proto; }(); return kShape; diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden index eb001c5d45bdfefc76629d7303d89f5480432235..ce8e5ec8c96a2c3696f14b8eea206d648182ecb5 100644 Binary files a/tensorflow/compiler/aot/codegen_test_o.golden and b/tensorflow/compiler/aot/codegen_test_o.golden differ diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 2b5f97b34cd928d32eb220536342c715d91d45bb..9fc223bdc7c0e207ce2005cb86250aa77e709df8 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -56,17 +56,23 @@ Status CompileXla(xla::CompileOnlyClient* client, return errors::Unknown("Couldn't get XLA program shape: ", pshape_or.status().error_message()); } - compile_result->program_shape = *pshape_or.ValueOrDie(); - xla::ProgramShape* pshape = &compile_result->program_shape; - std::vector arg_layouts; - arg_layouts.reserve(pshape->parameters_size()); + compile_result->program_shape = pshape_or.ValueOrDie()->ToProto(); + xla::ProgramShapeProto* pshape = &compile_result->program_shape; + + // AotXlaComputationInstance::argument_layouts is a vector of Shape + // pointers. Accumulate the Shape objects themselves in a separate vector + // while building the vector of pointers. 
+ std::vector arg_layout_ptrs(pshape->parameters_size()); + std::vector arg_layouts(pshape->parameters_size()); for (int i = 0; i < pshape->parameters_size(); ++i) { - arg_layouts.push_back(pshape->mutable_parameters(i)); + arg_layouts[i] = xla::Shape(*pshape->mutable_parameters(i)); + arg_layout_ptrs[i] = &arg_layouts[i]; } xla::CompileOnlyClient::AotXlaComputationInstance instance; instance.computation = &computation; - instance.argument_layouts = std::move(arg_layouts); - instance.result_layout = &pshape->result(); + instance.argument_layouts = std::move(arg_layout_ptrs); + xla::Shape result_shape(pshape->result()); + instance.result_layout = &result_shape; xla::StatusOr>> aot_or = client->CompileAheadOfTime({instance}, aot_opts); if (!aot_or.ok()) { diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index e03c5b1aa77c1262ed903aae3072ef65f34d80a2..ee7bb26fabd2d897b85b62f38778ecbfe2238eb6 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -33,9 +33,9 @@ namespace tfcompile { struct CompileResult { // Contains object file and meta-info. std::unique_ptr aot; - xla::ProgramShape program_shape; // Static shape of args and results. - string entry_point; // Name of generated function. - int pointer_size = 0; // Size of a pointer in bytes. + xla::ProgramShapeProto program_shape; // Static shape of args and results. + string entry_point; // Name of generated function. + int pointer_size = 0; // Size of a pointer in bytes. 
}; // CompileGraph compiles the graph_def into an object file containing a function diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index f10852c7850f61bfd8b99fa9f1648202d182085e..4dd79e5882d7da61be029735ef2b165908c599f9 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -526,13 +526,15 @@ TEST(TFCompileTest, ProgramShape) { // muladd has the program shape defined. MatMulAndAddComp muladd; - const xla::ProgramShape* muladd_shape = muladd.ProgramShape(); + const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape(); ASSERT_TRUE(muladd_shape != nullptr); ASSERT_EQ(muladd_shape->parameters_size(), 2); - EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2)); - EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(1), f32_2x2)); + EXPECT_TRUE( + ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(0)), f32_2x2)); + EXPECT_TRUE( + ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(1)), f32_2x2)); - const xla::Shape& muladd_result = muladd_shape->result(); + const xla::Shape muladd_result(muladd_shape->result()); ASSERT_EQ(muladd_result.element_type(), xla::TUPLE); ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2); const xla::Shape& muladd_result0 = diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 859c84bb91657422b830255b0217f8946d351458..2dc3e8c9113b37bf9d575ad66783f4ab49478af4 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -390,6 +390,7 @@ def target_llvm_triple(): "//tensorflow:android_arm": "armv7-none-android", "//tensorflow:android_arm64": "aarch64-none-android", "//tensorflow:android_x86": "i686-none-android", + "//tensorflow:ios": "arm64-none-ios", "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", "//tensorflow:darwin": "x86_64-none-darwin", "//conditions:default": "x86_64-pc-linux", 
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index b95b063348c5cdfdcaed635ba527e9f0bfd6092d..d548de8c44285f6d21dd778db464a31e1b19645b 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" @@ -103,7 +103,7 @@ Status Main(const MainFlags& flags) { return errors::InvalidArgument("Must specify --cpp_class"); } codegen_opts.gen_hlo_profile_printer_data = - xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); + xla::GetDebugOptionsFromFlags().xla_hlo_profile(); TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name, &codegen_opts.namespaces)); @@ -132,7 +132,7 @@ int main(int argc, char** argv) { std::vector flag_list; AppendMainFlags(&flag_list, &flags); - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); tensorflow::string usage = tensorflow::tfcompile::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 661b444a42eefadf52739d84483e8e26c07fadf5..15dcbb2641eca031e82db9aa58dee6a14ab0a2cc 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -21,10 +21,8 @@ package( ) load("//tensorflow:tensorflow.bzl", "cc_header_only_library") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") 
-load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") @@ -39,7 +37,7 @@ cc_library( ":xla_cpu_device", ":xla_cpu_jit", "//tensorflow/compiler/plugin", - ] + if_cuda_is_configured([ + ] + if_cuda([ ":xla_gpu_device", ":xla_gpu_jit", ]), @@ -52,6 +50,8 @@ cc_library( deps = [ ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", ], @@ -65,6 +65,7 @@ cc_library( ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/xla/service:gpu_plugin", ]), alwayslink = 1, @@ -75,15 +76,17 @@ cc_library( srcs = ["xla_cpu_device.cc"], visibility = [":friends"], deps = [ + ":create_xla_launch_op", # buildcleaner: keep + ":flags", ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_ops", - "//tensorflow/compiler/jit/legacy_flags:xla_device_flags", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) @@ -93,6 +96,7 @@ cc_library( srcs = ["xla_gpu_device.cc"], visibility = [":friends"], deps = [ + ":create_xla_launch_op", # buildcleaner: keep ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_ops", @@ -101,6 +105,8 @@ cc_library( "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + 
"@com_google_absl//absl/strings", ], alwayslink = 1, ) @@ -116,7 +122,7 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep - "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) @@ -188,11 +194,13 @@ cc_library( "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:shape_ops", + "//tensorflow/core/kernels:stack", "//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels/data:generator_dataset_op", "//tensorflow/core/kernels/data:iterator_ops", "//tensorflow/core/kernels/data:prefetch_dataset_op", "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", ], ) @@ -205,6 +213,18 @@ cc_library( # Internal targets below this point. +cc_library( + name = "flags", + srcs = ["flags.cc"], + hdrs = ["flags.h"], + visibility = [":friends"], + deps = [ + "//tensorflow/compiler/xla:parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + cc_library( name = "common", srcs = [ @@ -237,6 +257,8 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", ], ) @@ -249,6 +271,8 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/core:core_cpu", @@ -259,6 +283,22 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", "@com_google_absl//absl/container:flat_hash_map", + 
"@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "xla_compilation_cache_test", + srcs = [ + "xla_compilation_cache_test.cc", + ], + deps = [ + ":xla_compilation_cache", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/core:test", + "//tensorflow/core:test_main", ], ) @@ -324,7 +364,6 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -360,6 +399,83 @@ tf_cc_test( ], ) +cc_library( + name = "shape_inference", + srcs = ["shape_inference.cc"], + hdrs = ["shape_inference.h"], + deps = [ + ":shape_inference_helpers", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + ":shape_inference", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "shape_inference_test", + srcs = ["shape_inference_test.cc"], + deps = [ + ":shape_inference", + ":test_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:ops", + "//tensorflow/core:framework", + "//tensorflow/core:ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/kernels:constant_op", + ], +) + +cc_library( + name = "encapsulate_util", + srcs = ["encapsulate_util.cc"], + hdrs = ["encapsulate_util.h"], + deps = [ + ":shape_inference", + "//tensorflow/compiler/tf2xla:tf2xla_util", + "//tensorflow/core:framework", + 
"//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "encapsulate_util_test", + srcs = ["encapsulate_util_test.cc"], + deps = [ + ":encapsulate_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core:framework", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "compilation_passes", srcs = [ @@ -368,6 +484,8 @@ cc_library( "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", "encapsulate_xla_computations_pass.cc", + "extract_outside_compilation_pass.cc", + "increase_dynamism_for_auto_jit_pass.cc", "mark_for_compilation_pass.cc", "mark_for_compilation_pass_test_helper.cc", "partially_decluster_pass.cc", @@ -377,12 +495,16 @@ cc_library( "deadness_analysis.h", "encapsulate_subgraphs_pass.h", "encapsulate_xla_computations_pass.h", + "extract_outside_compilation_pass.h", + "increase_dynamism_for_auto_jit_pass.h", "mark_for_compilation_pass.h", "mark_for_compilation_pass_test_helper.h", "partially_decluster_pass.h", ], deps = [ ":common", + ":encapsulate_util", + ":flags", ":shape_inference_helpers", ":union_find", ":xla_cluster_util", @@ -390,12 +512,13 @@ cc_library( "//tensorflow/cc:ops", "//tensorflow/cc:scope_internal", "//tensorflow/compiler/jit/graphcycles", - "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", + "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", @@ -409,8 +532,10 @@ 
cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -435,25 +560,6 @@ cc_library( hdrs = ["union_find.h"], ) -cc_library( - name = "producer_consumer_queue", - hdrs = ["producer_consumer_queue.h"], - deps = ["//tensorflow/core:lib"], -) - -tf_cc_test( - name = "producer_consumer_queue_test", - size = "small", - srcs = ["producer_consumer_queue_test.cc"], - deps = [ - ":producer_consumer_queue", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - tf_cc_test( name = "deadness_analysis_test", size = "small", @@ -491,30 +597,39 @@ tf_cc_test( "build_xla_ops_pass_test.cc", "encapsulate_subgraphs_pass_test.cc", "encapsulate_xla_computations_pass_test.cc", + "extract_outside_compilation_pass_test.cc", + "increase_dynamism_for_auto_jit_pass_test.cc", "mark_for_compilation_pass_test.cc", "partially_decluster_pass_test.cc", ], deps = [ ":common", ":compilation_passes", + ":encapsulate_util", ":node_matchers", ":xla_cluster_util", + ":xla_cpu_device", ":xla_gpu_device", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", + "//tensorflow/cc:scope", "//tensorflow/cc:sendrecv_ops", "//tensorflow/compiler/jit/kernels:xla_ops", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:test_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", + "//tensorflow/compiler/tf2xla/cc:xla_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", 
"//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", @@ -551,31 +666,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "xla_launch_util_test", - size = "small", - srcs = ["xla_launch_util_test.cc"], - deps = [ - ":common", - ":xla_compilation_cache", - ":xla_launch_util", - ":xla_tensor", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:gpu_runtime", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core/kernels:variable_ops", - ], -) - cc_library( name = "xla_fusion_optimizer", srcs = ["xla_fusion_optimizer.cc"], @@ -621,6 +711,7 @@ cc_library( deps = [ "//tensorflow/cc:ops", "//tensorflow/compiler/xla:test", + "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/algorithm:container", @@ -636,6 +727,7 @@ tf_cc_test( deps = [ ":node_matchers", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:ops", "//tensorflow/core:ops", "//tensorflow/core:test_main", @@ -648,7 +740,10 @@ tf_custom_op_py_library( visibility = [ ":friends", ], - deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"], + deps = [ + "//tensorflow/compiler/jit/ops:xla_ops_grad", + "//tensorflow/compiler/jit/ops:xla_ops_wrapper_py", + ], ) # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. 
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 5974696b7751d69eb27141173fdab14313925ee9..9f4042630edaec1b9519b6434d859a48372e8b15 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -15,10 +15,16 @@ limitations under the License. #include "tensorflow/compiler/jit/build_xla_ops_pass.h" #include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope_internal.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/control_flow_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -48,6 +54,88 @@ void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) { } } +// Returns a data value that is dead iff `control` is dead. +Output ControlToData(const Scope& scope, Node* control) { + Output data = ops::Const(scope.WithOpName("ctrl_as_data"), + Tensor(DT_BOOL, TensorShape({0}))); + scope.graph()->AddControlEdge(control, data.node()); + return Output(data.node()); +} + +// Returns an operation that can be control-depended on that is dead iff `data` +// is dead. +Operation DataToControl(const Scope& scope, Output data) { + return Operation( + ops::Identity(scope.WithOpName("data_as_ctrl"), data).node()); +} + +// Replaces each outgoing edge from `old_node` with a merge node that merges in +// the corresponding output from `new_node`. 
+void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node) { + if (!s.status().ok()) { + return; + } + + std::vector merged_outputs(old_node->num_outputs(), Output(nullptr)); + + std::vector data_edges; + absl::c_copy_if(old_node->out_edges(), std::back_inserter(data_edges), + [](const Edge* e) { return !e->IsControlEdge(); }); + + for (const Edge* e : data_edges) { + int oidx = e->src_output(); + Output merged_output = merged_outputs[oidx]; + if (merged_output.node() == nullptr) { + ops::Merge merge_op(s.WithOpName(absl::StrCat("merge_oidx_", oidx)), + {Output(old_node, oidx), Output(new_node, oidx)}); + merged_output = merged_outputs[oidx] = merge_op.output; + } + + Node* dst = e->dst(); + int dst_idx = e->dst_input(); + + s.graph()->RemoveEdge(e); + s.graph()->AddEdge(merged_output.node(), merged_output.index(), dst, + dst_idx); + } +} + +// Replaces each control successor of `old_node` to execute whenever either +// `old_node` or `new_node` is executed. +void MergeOutgoingControlEdges(const Scope& s, Node* old_node, Node* new_node) { + if (!s.status().ok()) { + return; + } + + std::vector ctrl_edges; + absl::c_copy_if(old_node->out_edges(), std::back_inserter(ctrl_edges), + [](const Edge* e) { return e->IsControlEdge(); }); + + if (ctrl_edges.empty()) { + return; + } + + // We can't merge control edges directly so we instead first "convert" them to + // normal values that can be merged, merge the values and then "convert" the + // merged value back into control. + // + // NB! We need to copy out the outgoing control edges before constructing + // old_ctrl_as_data otherwise the control edge from old_node to the constant + // in ControlToData will be present in ctrl_edges. 
+ + Output old_ctrl_as_data = ControlToData(s, old_node); + Output new_ctrl_as_data = ControlToData(s, new_node); + + ops::Merge ctrl_merge_as_data(s.WithOpName("ctrl_merge"), + {old_ctrl_as_data, new_ctrl_as_data}); + Operation ctrl_merge = DataToControl(s, ctrl_merge_as_data.output); + + for (const Edge* e : ctrl_edges) { + s.graph()->AddControlEdge(ctrl_merge.node(), e->dst()); + s.graph()->RemoveControlEdge(e); + } +} + struct XlaClusterInfo { std::vector constant_inputs; std::vector non_constant_inputs; @@ -107,7 +195,39 @@ Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) { return Status::OK(); } -Status ReplaceNodeWithXlaCompileAndXlaRun(Graph* g, Node* n) { +void RemoveAllIncomingControlEdges(Graph* g, Node* n) { + std::vector incoming_ctrl_edges; + absl::c_copy_if(n->in_edges(), std::back_inserter(incoming_ctrl_edges), + [](const Edge* e) { return e->IsControlEdge(); }); + for (const Edge* e : incoming_ctrl_edges) { + g->RemoveControlEdge(e); + } +} + +// Returns true (into `result`) if `node` must be compiled. 
+Status NodeRequiresCompilation(Node* n, bool* result) { + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n->assigned_device_name(), &device_type)); + const XlaOpRegistry::DeviceRegistration* registration = nullptr; + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + return errors::Internal("Could not find compilation device ", + device_type.type()); + } + *result = registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways; + return Status::OK(); +} + +Status ReplaceNodeWithXlaCompileAndXlaRun( + const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled, + Graph* g, Node* n) { + bool requires_compilation; + TF_RETURN_IF_ERROR(NodeRequiresCompilation(n, &requires_compilation)); + if (!lazy_compilation_enabled) { + requires_compilation = true; + } + Status status; Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr) .NewSubScope(n->name()) @@ -121,18 +241,63 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(Graph* g, Node* n) { /*constants=*/cluster_info.constant_inputs, /*args=*/cluster_info.non_constant_inputs, /*resources=*/cluster_info.resource_inputs, + /*must_compile=*/requires_compilation, cluster_info.function); TF_RETURN_IF_ERROR( CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node())); - std::vector xla_run_args = cluster_info.non_constant_inputs; - absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args)); - ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args, - xla_compile.key, n->output_types()); + if (requires_compilation) { + // "Strict" compilation: every _XlaCompile invocation must compile the + // cluster. 
+ std::vector xla_run_args = cluster_info.non_constant_inputs; + absl::c_copy(cluster_info.resource_inputs, + std::back_inserter(xla_run_args)); + ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args, + xla_compile.key, n->output_types()); + + MoveOutgoingEdges(g, /*old_node=*/n, + /*new_node=*/xla_run.operation.node()); + g->RemoveNode(n); + } else { + // "Lazy" compilation: an _XlaCompile invocation may decide not to compile + // the cluster based on profitability heuristics. - MoveOutgoingEdges(g, /*old_node=*/n, - /*new_node=*/xla_run.operation.node()); - g->RemoveNode(n); + // We generate the following graph: + // + // (use_tf_call, use_xla_run) = + // Switch(pred=xla_compile.compilation_successful, + // value=xla_compile.key) + // + // tf_call_outputs = cluster_N(..., ^use_tf_call) + // xla_run_outputs = _XlaRun(..., key=use_xla_run) + // outputs = Merge(tf_call_outputs, xla_run_outputs). + ops::Switch s(root.WithOpName("predicated_compilation_key"), + xla_compile.key, xla_compile.compilation_successful); + Output predicated_compilation_key = s.output_true; + Output inverse_predicated_compilation_key = s.output_false; + + std::vector xla_run_args = cluster_info.non_constant_inputs; + absl::c_copy(cluster_info.resource_inputs, + std::back_inserter(xla_run_args)); + ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args, + predicated_compilation_key, n->output_types()); + + MergeOutgoingControlEdges(root, /*old_node=*/n, + /*new_node=*/xla_run.operation.node()); + + MergeOutgoingDataEdges(root, /*old_node=*/n, + /*new_node=*/xla_run.operation.node()); + + TF_RETURN_IF_ERROR(root.status()); + + // We already have a TensorFlow function call into the cluster -- the + // original node we set out to rewrite. We just wire in the correct control + // deps and we're done. 
+ RemoveAllIncomingControlEdges(g, n); + g->AddControlEdge( + DataToControl(root, inverse_predicated_compilation_key).node(), n); + n->ClearAttr(kXlaCompiledKernelAttr); + } return Status::OK(); } @@ -141,22 +306,34 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(Graph* g, Node* n) { Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { Graph* graph = options.graph->get(); - for (Node* n : graph->op_nodes()) { - // In all cases, only try to compile computational nodes. - if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) { - continue; - } + // Copy out the nodes we want to rewrite to avoid modifying the graph while we + // iterate on graph->op_nodes(). + std::vector xla_compiled_kernels; + absl::c_copy_if(graph->op_nodes(), std::back_inserter(xla_compiled_kernels), + [](const Node* n) { + if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) { + return false; + } - // Only compile nodes that are marked for compilation by the - // compilation-marking pass (via 'attr_name'). - if (IsXlaCompiledKernel(*n)) { - TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(graph, n)); - } + // Only compile nodes that are marked for compilation by the + // compilation-marking pass (via 'attr_name'). + return IsXlaCompiledKernel(*n); + }); + + bool lazy_compilation_enabled = + enable_lazy_compilation_ + ? 
*enable_lazy_compilation_ + : GetBuildXlaOpsPassFlags().tf_xla_enable_lazy_compilation; + + for (Node* n : xla_compiled_kernels) { + TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun( + *options.flib_def, lazy_compilation_enabled, graph, n)); } if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def); } + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h index 1dd38fa95186dfbe458166caa23a131fbe3c9510..58f7c4b3a0d1472f602e8234f9f08c23dfe78a34 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.h +++ b/tensorflow/compiler/jit/build_xla_ops_pass.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ #define TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ +#include "absl/types/optional.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/lib/core/status.h" @@ -25,7 +26,17 @@ namespace tensorflow { // executes (using XLA) TF function calls marked with "_XlaCompiledKernel". class BuildXlaOpsPass : public GraphOptimizationPass { public: + // If enable_lazy_compilation is not nullopt then *enable_lazy_compilation + // overrides --tf_xla_enable_lazy_compilation flag in deciding whether lazy + // compilation is enabled. 
+ explicit BuildXlaOpsPass( + absl::optional enable_lazy_compilation = absl::nullopt) + : enable_lazy_compilation_(enable_lazy_compilation) {} + Status Run(const GraphOptimizationPassOptions& options) override; + + private: + absl::optional enable_lazy_compilation_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index 9d56db7b6bc12938b2de9df02b97ff0ca6a42e54..48a23a4c1711ac88a329723c46559112d5a39dbd 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -22,18 +22,38 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/node_matchers.h" +#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" namespace tensorflow { namespace { +class BuildXlaOpsTest : public ::testing::Test { + protected: + void SetUp() override { + // This is needed to register the XLA_* devices. 
+ CHECK(DeviceFactory::AddDevices( + SessionOptions(), "/job:localhost/replica:0/task:0", &devices_) + .ok()); + } + + private: + std::vector> devices_; +}; + using ::tensorflow::testing::FindNodeByName; +using ::tensorflow::testing::matchers::Attr; using ::tensorflow::testing::matchers::CtrlDeps; +using ::tensorflow::testing::matchers::Inputs; using ::tensorflow::testing::matchers::NodeWith; using ::tensorflow::testing::matchers::Op; +using ::tensorflow::testing::matchers::Out; +using ::testing::_; Status BuildXlaOps(const Scope& s, std::unique_ptr* result) { auto graph = absl::make_unique(OpRegistry::Global()); @@ -42,15 +62,18 @@ Status BuildXlaOps(const Scope& s, std::unique_ptr* result) { // Assign all nodes to the CPU device. static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; for (Node* n : graph->nodes()) { - if (n->assigned_device_name().empty()) { + if (n->requested_device().empty()) { n->set_assigned_device_name(kCpuDevice); + } else { + n->set_assigned_device_name(n->requested_device()); } } GraphOptimizationPassOptions opt_options; opt_options.graph = &graph; - BuildXlaOpsPass pass; + BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true); TF_RETURN_IF_ERROR(pass.Run(opt_options)); + VLOG(3) << graph->ToGraphDefDebug().DebugString(); *result = std::move(graph); return Status::OK(); } @@ -76,16 +99,19 @@ Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, result); } -Node* MakeWrite(const Scope& scope, const string& id) { - Output var_handle = - ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); - Output value_to_write = - ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f); - ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle, - value_to_write); +Node* MakeWrite(const Scope& scope, Output value_to_write, const string& id) { + Output var_handle = ops::VarHandleOp(scope.WithOpName("Var_" + id), DT_FLOAT, + TensorShape({})); + ops::AssignVariableOp 
assign_op(scope.WithOpName("Assignee_" + id), + var_handle, value_to_write); return assign_op.operation.node(); } +Node* MakeWrite(const Scope& scope, const string& id) { + return MakeWrite( + scope, ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f), id); +} + FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { FunctionDefLibrary flib_def; FunctionDef func = FunctionDefHelper::Create( @@ -97,14 +123,16 @@ FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { return flib_def; } -TEST(BuildXlaOps, ControlDepsPreserved) { - Scope root = Scope::NewRootScope().ExitOnError(); +TEST_F(BuildXlaOpsTest, ControlDepsPreserved) { + const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + Scope root = Scope::NewRootScope().WithDevice(kXlaDeviceName).ExitOnError(); FunctionDefLibrary flib_def = CreateFunctionDefLibWithConstFunction("cluster_0"); TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); Node* call; TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + call->set_requested_device(kXlaDeviceName); Node* write_op = MakeWrite(root, "write"); root.graph()->AddControlEdge(call, write_op); @@ -116,15 +144,17 @@ TEST(BuildXlaOps, ControlDepsPreserved) { EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun"))))); } -TEST(BuildXlaOps, CleanFailureOnBogusAttr) { +TEST_F(BuildXlaOpsTest, CleanFailureOnBogusAttr) { Scope root = Scope::NewRootScope().ExitOnError(); FunctionDefLibrary flib_def = CreateFunctionDefLibWithConstFunction("cluster_0"); TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + Node* call; TF_ASSERT_OK( MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", 100, 100, &call)); + Node* write_op = MakeWrite(root, "write"); root.graph()->AddControlEdge(call, write_op); @@ -134,5 +164,65 @@ TEST(BuildXlaOps, CleanFailureOnBogusAttr) { EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT); } +TEST_F(BuildXlaOpsTest, OnNonXlaDevice) { + 
Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* call; + TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + TF_ASSERT_OK(root.DoShapeInference(call)); + + Node* write_op = MakeWrite(root, Output(call), "write_result"); + + auto xla_compile = NodeWith(Op("_XlaCompile"), Attr("must_compile", false)); + auto predicated_compilation_key = + NodeWith(Op("Switch"), Inputs(Out(0, xla_compile), Out(1, xla_compile))); + auto xla_run = + NodeWith(Op("_XlaRun"), Inputs(Out(1, predicated_compilation_key))); + auto tf_call = + NodeWith(Op("cluster_0"), + CtrlDeps(NodeWith(Op("Identity"), + Inputs(Out(0, predicated_compilation_key))))); + auto merge = NodeWith(Op("Merge"), Inputs(Out(tf_call), Out(xla_run))); + auto assign_var = NodeWith(Op("AssignVariableOp"), Inputs(_, Out(merge))); + + std::unique_ptr graph; + TF_ASSERT_OK(BuildXlaOps(root, &graph)); + + Node* write_op_new = FindNodeByName(graph.get(), write_op->name()); + ASSERT_NE(write_op_new, nullptr); + EXPECT_THAT(write_op_new, assign_var); +} + +TEST_F(BuildXlaOpsTest, OnXlaDevice) { + const char* kXlaDeviceName = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + Scope root = Scope::NewRootScope().WithDevice(kXlaDeviceName).ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* call; + TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + call->set_requested_device(kXlaDeviceName); + TF_ASSERT_OK(root.DoShapeInference(call)); + + Node* write_op = MakeWrite(root, Output(call), "write_result"); + + std::unique_ptr graph; + TF_ASSERT_OK(BuildXlaOps(root, &graph)); + + auto xla_op = + NodeWith(Op("_XlaRun"), Inputs(Out(NodeWith(Op("_XlaCompile"))))); + auto assign_var = + 
NodeWith(Op("AssignVariableOp"), Inputs(Out(NodeWith()), Out(xla_op))); + + Node* write_op_new = FindNodeByName(graph.get(), write_op->name()); + ASSERT_NE(write_op_new, nullptr); + EXPECT_THAT(write_op_new, assign_var); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc index 73866607621cd745f6e640a14405daebf0dd9985..0f872a480f4d4843217f1df3452c4dc62531264e 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc @@ -59,8 +59,9 @@ class CreateXlaLaunchOpTest : public ::testing::Test { SessionOptions options; auto* device_count = options.config.mutable_device_count(); device_count->insert({"CPU", 1}); + std::vector> devices; TF_CHECK_OK(DeviceFactory::AddDevices( - options, "/job:localhost/replica:0/task:0", &devices_)); + options, "/job:localhost/replica:0/task:0", &devices)); FunctionDefLibrary proto; for (const auto& fdef : flib) { @@ -69,7 +70,7 @@ class CreateXlaLaunchOpTest : public ::testing::Test { lib_def_ = absl::make_unique( OpRegistry::Global(), proto); OptimizerOptions opts; - device_mgr_ = absl::make_unique(devices_); + device_mgr_ = absl::make_unique(std::move(devices)); pflr_ = absl::make_unique( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); @@ -77,7 +78,6 @@ class CreateXlaLaunchOpTest : public ::testing::Test { } FunctionLibraryRuntime* flr_; - std::vector devices_; std::unique_ptr device_mgr_; std::unique_ptr lib_def_; std::unique_ptr pflr_; diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index b7ae7fbeb3912882368dc828e8d6fcd50735b04e..0562838f628c66b1eb03af9d2a5139c01dca31c5 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -525,7 +525,6 @@ Predicate* 
PredicateFactory::MakeAndOrImpl( op->GetOperands().begin(), op->GetOperands().end()); } else { - std::vector sub_ops_intersection; common_inner_operands.clear(); absl::c_copy_if(op->GetOperands(), std::back_inserter(common_inner_operands), @@ -696,8 +695,8 @@ Status CreateMultipleNextIterationInputsError(Node* merge) { } } return errors::InvalidArgument( - "Multiple NextIteration inputs to merge node ", SummarizeNode(*merge), - ": \n", absl::StrJoin(backedges, "\n"), + "Multiple NextIteration inputs to merge node ", + FormatNodeForError(*merge), ": \n", absl::StrJoin(backedges, "\n"), "\nMerge nodes can have at most one incoming NextIteration edge."); } diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 617e31488c7daeb714c0ff7056b786e4eaf7873f..8a73101c184e6190921fd7729742922bd96f4bcf 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -127,7 +127,8 @@ InductionVarInfo CreateInductionVariable(const Scope& root, Output loop_cond = ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); - ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), + latch.output_false); Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"), latch.output_true, increment_by); Output next_iteration = @@ -191,7 +192,8 @@ DependentInductionVar CreateDependentLoopInvariantValue( value, frame_name); ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value}); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); - ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), + latch.output_false); Output next_iteration = ops::NextIteration( root.WithOpName(prefix + 
"/next_iteration"), latch.output_true); CHECK(root.graph() diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index da27f837e88fc3f57f865211929ec9cb1a1af779..f478832781cb1dc045d9163d4a6f5e5f64a8a705 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -1109,6 +1109,9 @@ Status Encapsulator::Subgraph::BuildFunctionDef( function_def_name_ = name; FunctionDef fdef; + // Verify that the graph has well-formed control flow structure. + std::vector dummy; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &dummy)); TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); if (VLOG_IS_ON(1)) { @@ -1119,8 +1122,11 @@ Status Encapsulator::Subgraph::BuildFunctionDef( fdef); } - if (!reuse_existing_functions || library->Find(name) == nullptr) { + const FunctionDef* original_fdef = library->Find(name); + if (!reuse_existing_functions || original_fdef == nullptr) { TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + } else if (!FunctionDefsEqual(*original_fdef, fdef)) { + TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef)); } return Status::OK(); } @@ -1531,9 +1537,6 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { for (auto& entry : subgraphs_) { Subgraph& subgraph = entry.second; FixupSourceAndSinkEdges(subgraph.GetGraph()); - // Verify that the graph has well-formed control flow structure. 
- std::vector dummy; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(subgraph.GetGraph(), &dummy)); } if (VLOG_IS_ON(1)) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 49958093b8dcf35e8adcdfd2f7dfce8558d5db6f..de89be9a3555960dabe7bacd17226c15ae888ae6 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -16,16 +16,20 @@ limitations under the License. #include #include -#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/encapsulate_util.h" +#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/equal_graph_def.h" @@ -406,8 +410,8 @@ Node* KeyPlaceholderShape(const GraphDefBuilder::Options& opts) { Node* KeyPlaceholder(const string& call_node, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - NodeBuilder node_builder(opts.GetNameForOp("Placeholder"), "Placeholder", - opts.op_registry()); + NodeBuilder node_builder(absl::StrCat(call_node, "_key_placeholder"), + "Placeholder", opts.op_registry()); TensorShapeProto shape; shape.add_dim()->set_size(2); return opts.WithAttr("shape", shape) @@ -494,7 +498,8 @@ Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& 
opts) { return opts.FinalizeBuilder(&node_builder); } -Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { +Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, + const std::vector& encapsulated_functions) { Status s; // Convert the GraphDef to a Graph std::unique_ptr lib_def( @@ -505,11 +510,39 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { s = ConvertGraphDefToGraph(options, *graphdef, graph.get()); if (!s.ok()) return s; + s = PerformStaticShapeInferenceBeforeEncapsulation( + graph.get(), "_encapsulate", "_outside"); + if (!s.ok()) return s; + + s = PreprocessForEncapsulation(graph.get(), "_encapsulate", "_outside"); + if (!s.ok()) return s; + std::unique_ptr graph_out; - s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph, - /*rewrite_subgraph_fn=*/{}, - /*reuse_existing_functions=*/false, - &graph_out, lib_def.get()); + s = EncapsulateSubgraphsInFunctions( + "_encapsulate", /*outside_compilation_attribute=*/"", *graph, + /*rewrite_subgraph_fn=*/{}, + /*reuse_existing_functions=*/false, &graph_out, lib_def.get()); + if (!s.ok()) return s; + + std::unordered_map clusters; + for (const auto& func : encapsulated_functions) { + Node* xla_computation_node; + for (Node* n : graph_out->nodes()) { + if (n->name() == func) { + xla_computation_node = n; + } + } + if (!xla_computation_node) { + return errors::Internal("Cannot find node ", func); + } + NameAttrList func_name_attrs; + func_name_attrs.set_name(func); + clusters.emplace(func, + XlaClusterInfo{func, func_name_attrs, xla_computation_node, + std::map{}}); + } + s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters, + graph_out.get(), lib_def.get()); if (!s.ok()) return s; GraphDef graphdef_out; @@ -520,6 +553,11 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { return s; } +Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { + std::vector encapsulated_functions; + return 
Encapsulate(graphdef, library, encapsulated_functions); +} + // If there are no marked nodes, funcification should be a no-op. TEST(EncapsulateSubgraphsTest, NoFunctions) { GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); @@ -703,7 +741,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { FunctionLibraryDefinition library(OpRegistry::Global(), {}); std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_cluster", "_outside", graph_before_encapsulation, + "_cluster", "", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, /*reuse_existing_functions=*/false, &graph, &library)); @@ -755,7 +793,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { FunctionLibraryDefinition library(OpRegistry::Global(), {}); int guaranteed_consts = 0; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_encapsulate", "_outside", graph_before, + "_encapsulate", "", graph_before, /*rewrite_subgraph_fn=*/ [&guaranteed_consts](const std::vector& arg_source_tensors, std::unique_ptr* graph_ptr, @@ -800,7 +838,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { FunctionLibraryDefinition library(OpRegistry::Global(), {}); int guaranteed_consts = 0; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_encapsulate", "_outside", graph_before, + "_encapsulate", "", graph_before, /*rewrite_subgraph_fn=*/ [&guaranteed_consts](const std::vector& arg_source_tensors, std::unique_ptr* graph_ptr, @@ -854,15 +892,15 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + std::vector encapsulated_functions{"F1"}; + TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; GraphDef graphdef_expected; { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); - Node* key_constant = - KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); + Node* key_constant = KeyPlaceholder("F1", shape.opts()); Node* 
recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, shape.opts()); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), @@ -877,7 +915,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}}, @@ -899,7 +937,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, - {{"f_0_retval", "F:o:0"}}); + {{"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -975,15 +1013,15 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + std::vector encapsulated_functions{"F1"}; + TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; GraphDef graphdef_expected; { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); - Node* key_constant = - KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0")); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, shape1.opts()); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), @@ -998,8 +1036,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); - Node* key_constant = - KeyPlaceholderShape(shape2.opts().WithName("KnownShape/_0")); + Node* key_constant = KeyPlaceholder("F1", shape2.opts()); Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT, DT_FLOAT}, shape2.opts()); Node* e = Binary(ops::NodeOut(recv1, 0), 
ops::NodeOut(recv1, 1), @@ -1020,7 +1057,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { } *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}}, @@ -1037,14 +1074,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"F:o:0", "D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", - absl::Span({"outside_compilation_O1_host_compute"})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O2"}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, - {"F", "outside_compilation_O1_host_compute"}}, + {"F"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, @@ -1058,7 +1094,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, - {{"i_0_retval", "I:o:0"}}); + {{"i_0_retval_retval", "I:o:0"}}); { std::unique_ptr lib_def( @@ -1149,33 +1185,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + std::vector encapsulated_functions{"F1", "F2"}; + TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; GraphDef graphdef_expected; - { - GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); - Node* key_constant = - KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape.opts()); - Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), - shape.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - 
.WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); - TF_EXPECT_OK( - AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); - } - TensorShapeProto shape_proto_expected; shape_proto_expected.add_dim()->set_size(2); *library_expected.add_function() = FunctionDefHelper::Create( "F1", {"a_0_arg:float", "b_0_arg:float"}, - {"f_0_retval:float", "d_0_retval:float"}, {}, + {"f_0_retval_retval:float", "d_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1191,19 +1212,19 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", absl::Span({})}, + {"shape_inference_graph", ""}, + {"shapes", + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, - {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}}); + {{"d_0_retval_retval", "D:o:0"}, {"f_0_retval_retval", "F:o:0"}}); *library_expected.add_function() = FunctionDefHelper::Create( - "F2", {"e_0_arg:float", "f_0_arg:float"}, - {"g_0_retval:float", "i_0_retval:float"}, {}, + "F2", {"f_0_arg:float", "bridge_e_g_0_arg:float"}, + {"i_0_retval_retval:float", "g_0_retval_retval:float"}, {}, { - {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}}, + {{"G"}, "BinaryTest", {"bridge_e_g_0_arg", "f_0_arg"}}, {{"I"}, "BinaryTest", {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}}, @@ -1219,7 +1240,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); + {{"i_0_retval_retval", "I:o:0"}, {"g_0_retval_retval", "G:o:0"}}); { std::unique_ptr lib_def( @@ -1265,11 +1286,11 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { 
b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}), "F2"); NodeBuilder node_builder2("F2", "F2", lib_def.get()); - node_builder2.Input(e).Input(call1); + node_builder2.Input(call1).Input(e); Node* call2 = b2.opts() .WithControlInputs({s2, e, call1}) .FinalizeBuilder(&node_builder2); - Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J")); + Binary(ops::NodeOut(call2, 1), call2, b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1312,44 +1333,16 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + std::vector encapsulated_functions{"F1", "F2"}; + TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; GraphDef graphdef_expected; - - { - GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); - Node* key_constant = - KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", - {DT_FLOAT, DT_FLOAT}, shape.opts()); - Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), - shape.opts() - .WithName("E") - .WithAttr("_encapsulate", "F1") - .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); - TF_EXPECT_OK( - AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); - } - - { - GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); - Node* key_constant = - KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); - Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F2", "O1", - {DT_FLOAT}, shape.opts()); - Node* h = Unary(recv, shape.opts() - .WithName("H") - .WithAttr("_encapsulate", "F2") - .WithAttr("_outside", "O1")); - SendFromHost(ops::NodeOut(key_constant, 0), "F2", "O1", {h}, shape.opts()); - TF_EXPECT_OK( - AddGraphDefToFunctionLibrary(shape, "F2_O1", &library_expected)); - } + 
TensorShapeProto shape_proto_expected; + shape_proto_expected.add_dim()->set_size(2); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1365,16 +1358,16 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", absl::Span({})}, + {"shape_inference_graph", ""}, + {"shapes", + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, - {{"f_0_retval", "F:o:0"}}); + {{"f_0_retval_retval", "F:o:0"}}); *library_expected.add_function() = FunctionDefHelper::Create( - "F2", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {}, + "F2", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval_retval:float"}, {}, { {{"G"}, "BinaryTest", {"a_0_arg", "b_0_arg"}}, {{"I"}, @@ -1387,12 +1380,12 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, - {"shape_inference_graph", - "_outside_compilation_shape_inference_F2_O1"}, - {"shapes", absl::Span({})}, + {"shape_inference_graph", ""}, + {"shapes", + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"i_0_retval", "I:o:0"}}); + {{"i_0_retval_retval", "I:o:0"}}); { std::unique_ptr lib_def( @@ -1439,9 +1432,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { "F2"); NodeBuilder node_builder2("F2", "F2", lib_def.get()); node_builder2.Input(a).Input(b); - Node* call2 = b2.opts() - .WithControlInputs({s2, call1}) - .FinalizeBuilder(&node_builder2); + 
Node* call2 = + b2.opts().WithControlInputs({s2}).FinalizeBuilder(&node_builder2); Binary(call1, call2, b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1473,7 +1465,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + std::vector encapsulated_functions{"F1"}; + TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; GraphDef graphdef_expected; @@ -1482,7 +1475,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { shape_proto_expected.add_dim()->set_size(2); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1501,7 +1494,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"f_0_retval", "F:o:0"}}); + {{"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -1557,7 +1550,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + std::vector encapsulated_functions{"F1"}; + TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; GraphDef graphdef_expected; @@ -1566,7 +1560,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { shape_proto_expected.add_dim()->set_size(2); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", 
"C:o:0"}}, @@ -1586,7 +1580,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, - {{"f_0_retval", "F:o:0"}}); + {{"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -1644,13 +1638,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + std::vector encapsulated_functions{"F1"}; + TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; GraphDef graphdef_expected; *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1666,7 +1661,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"f_0_retval", "F:o:0"}}); + {{"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -1721,13 +1716,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + std::vector encapsulated_functions{"F1"}; + TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; GraphDef graphdef_expected; *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1747,7 +1743,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"f_0_retval", "F:o:0"}}); + 
{{"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -1811,15 +1807,15 @@ TEST(EncapsulateSubgraphsTest, TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + std::vector encapsulated_functions{"F1"}; + TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; GraphDef graphdef_expected; { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); - Node* key_constant = - KeyPlaceholderShape(shape2.opts().WithName("KnownShape/_0")); + Node* key_constant = KeyPlaceholder("F1", shape2.opts()); Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT}, shape2.opts()); Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts() @@ -1832,7 +1828,7 @@ TEST(EncapsulateSubgraphsTest, } *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1852,7 +1848,7 @@ TEST(EncapsulateSubgraphsTest, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}}, }, - {{"h_0_retval", "H:o:0"}}); + {{"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -1920,15 +1916,15 @@ TEST(EncapsulateSubgraphsTest, TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + std::vector encapsulated_functions{"F1"}; + TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; GraphDef graphdef_expected; { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); - Node* key_constant = - KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0")); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, shape1.opts()); Node* e = Unary(ops::NodeOut(recv2, 0), 
shape1.opts() @@ -1941,7 +1937,7 @@ TEST(EncapsulateSubgraphsTest, } *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, @@ -1961,7 +1957,7 @@ TEST(EncapsulateSubgraphsTest, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, - {{"h_0_retval", "H:o:0"}}); + {{"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -2034,15 +2030,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + std::vector encapsulated_functions{"F1"}; + TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; GraphDef graphdef_expected; { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); - Node* key_constant = - KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0")); + Node* key_constant = KeyPlaceholder("F1", shape1.opts()); Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, shape1.opts()); Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() @@ -2055,7 +2051,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { } *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval_retval:float"}, {}, {{{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}}, @@ -2076,28 +2072,24 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({})}, - {"ancestors", - 
absl::Span({"outside_compilation_O1_host_compute"})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", ""}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, - {"outside_compilation_O1_host_compute"}}, + {}}, {{"outside_compilation_O3_host_compute"}, "XlaHostCompute", {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({})}, - {"ancestors", - absl::Span({"outside_compilation_O1_host_compute", - "outside_compilation_O2_host_compute"})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O3"}, {"shape_inference_graph", ""}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O3"}}, - {"outside_compilation_O1_host_compute", - "outside_compilation_O2_host_compute"}}}, - {{"h_0_retval", "H:o:0"}}); + {}}}, + {{"h_0_retval_retval", "H:o:0"}}); { std::unique_ptr lib_def( @@ -2169,19 +2161,20 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + std::vector encapsulated_functions{"F1"}; + TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; GraphDef graphdef_expected; *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {}, { {{"C"}, "UnaryTest", {"a_0_arg"}}, {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"D:o:0"}}, }, - {{"f_0_retval", "F:o:0"}}); + {{"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( @@ -2234,19 +2227,20 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + std::vector encapsulated_functions{"F1"}; + TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); 
FunctionDefLibrary library_expected; GraphDef graphdef_expected; { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); - Node* key_constant = - KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); - Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_1")); + Node* key_constant = KeyPlaceholder("F1", shape.opts()); Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {DT_FLOAT}, shape.opts()); - Node* e = BinaryUnknownShape(known, recv, + Node* a = InputShaped(shape.opts().WithName("A")); + Node* c = Unary(a, shape.opts().WithName("C")); + Node* e = BinaryUnknownShape(c, recv, shape.opts() .WithName("E") .WithAttr("_encapsulate", "F1") @@ -2258,7 +2252,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval:float"}, {}, + "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval_retval:float"}, {}, { {{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}}, {{"F"}, @@ -2279,7 +2273,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, - {{"f_0_retval", "F:o:0"}}); + {{"f_0_retval_retval", "F:o:0"}}); { std::unique_ptr lib_def( diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f4b9c90a4ff0b1166cdb7b5942771b350740ef3 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -0,0 +1,955 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_util.h" +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" + +namespace tensorflow { + +namespace { + +// Returns string attribute value for the node if the attribute is present, +// otherwise returns empty optional value. +absl::optional GetStringAttr(const Node& n, const string& attr_name) { + auto attr = n.attrs().Find(attr_name); + if (!attr) { + return absl::nullopt; + } else { + return attr->s(); + } +} + +// Adds a value to the node's list attribute. +template +Status AppendToListAttr(Node* n, const string& attr_name, const string& value) { + std::vector attr_value; + Status s = GetNodeAttr(n->attrs(), attr_name, &attr_value); + if (!s.ok() && s.code() != error::NOT_FOUND) { + return s; + } + + n->ClearAttr(attr_name); + attr_value.push_back(value); + n->AddAttr(attr_name, attr_value); + return Status::OK(); +} + +// Replaces attribute value. +template +void ReplaceAttr(Node* n, const string& attr_name, const T& value) { + n->ClearAttr(attr_name); + n->AddAttr(attr_name, value); +} + +// Step 1a ~ 1d for PreprocessForEncapsulation(). See comments of +// PreprocessForEncapsulation() for details. 
+Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name, + const string& outside_compilation_attr_name) { + // Gather edges to remove. We should not remove the edge while iterating. + std::vector edges_to_remove; + for (const Edge* e : g->edges()) { + if (!e->IsControlEdge()) { + continue; + } + + auto src_xla_computation = + GetStringAttr(*e->src(), xla_computation_attr_name); + auto dst_xla_computation = + GetStringAttr(*e->dst(), xla_computation_attr_name); + auto src_outside_compilation = + GetStringAttr(*e->src(), outside_compilation_attr_name); + auto dst_outside_compilation = + GetStringAttr(*e->dst(), outside_compilation_attr_name); + + if (!src_xla_computation && !dst_xla_computation) { + continue; + } else if (src_xla_computation && !dst_xla_computation) { + if (src_outside_compilation) { + // Case 1c: outside compilation to host computation control edge. + edges_to_remove.push_back(e); + + TF_RETURN_IF_ERROR(AppendToListAttr( + e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); + } + } else if (!src_xla_computation && dst_xla_computation) { + if (dst_outside_compilation) { + // Case 1c: host computation control to outside compilation edge. + edges_to_remove.push_back(e); + + TF_RETURN_IF_ERROR(AppendToListAttr( + e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); + } + } else { // src_xla_computation && dst_xla_computation + if (*src_xla_computation != *dst_xla_computation) { + if (src_outside_compilation && dst_outside_compilation) { + // Case 1b: outside compilation to outside compilation control edge. + edges_to_remove.push_back(e); + + TF_RETURN_IF_ERROR(AppendToListAttr( + e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); + } else if (src_outside_compilation && !dst_outside_compilation) { + // Case 1a: outside compilation to another XLA computaition control + // edge. 
+ TF_RETURN_IF_ERROR(AppendToListAttr( + e->src(), kXlaConnectedToOtherXlaComputationAttrName, + *dst_xla_computation)); + } else if (!src_outside_compilation && dst_outside_compilation) { + // Case 1a: another XLA computaition to outside compilation control + // edge. + TF_RETURN_IF_ERROR(AppendToListAttr( + e->dst(), kXlaConnectedFromOtherXlaComputationAttrName, + *src_xla_computation)); + } + } + } + } + + for (auto e : edges_to_remove) { + g->RemoveEdge(e); + } + return Status::OK(); +} + +// Step 2 for PreprocessForEncapsulation(). See comments of +// PreprocessForEncapsulation() for details. +Status ProcessXlaToXlaDataEdges(Graph* g, + const string& xla_computation_attr_name, + const string& outside_compilation_attr_name) { + // Gather edges between XLA computations. Notice that we do not store `Edge*` + // directly because we remove some nodes while adding Identity nodes, and + // those Edge pointers might be invalidated. + struct EdgeInfo { + int dst_input, dst_node_id; + }; + std::vector edges; + for (const Edge* e : g->edges()) { + if (e->IsControlEdge()) { + continue; + } + + auto src_xla_computation = + GetStringAttr(*e->src(), xla_computation_attr_name); + auto dst_xla_computation = + GetStringAttr(*e->dst(), xla_computation_attr_name); + auto src_outside_compilation = + GetStringAttr(*e->src(), outside_compilation_attr_name); + auto dst_outside_compilation = + GetStringAttr(*e->dst(), outside_compilation_attr_name); + if (!src_xla_computation || !dst_xla_computation) { + continue; + } + + if (*src_xla_computation != *dst_xla_computation) { + if (src_outside_compilation || dst_outside_compilation) { + edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); + VLOG(4) << "XLA -> XLA edge: " << e->DebugString(); + } + } + } + + // For each XLA -> XLA edge, add an Identity node between src and dst. 
+ for (int i = 0; i < edges.size(); i++) { + Node* dst = g->FindNodeId(edges[i].dst_node_id); + const Edge* e; + TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); + Node* src = e->src(); + int src_output = e->src_output(), dst_input = e->dst_input(); + g->RemoveEdge(e); + + // Create Identity node, and connect it between `src` and `dst`. + string identity_node_name = + absl::StrCat("bridge_", src->name(), "_", dst->name()); + DataType dtype = src->output_type(src_output); + TF_ASSIGN_OR_RETURN(Node * identity_node, + BuildIdentityNode(g, identity_node_name, dtype, src, + /*requested_device=*/absl::nullopt)); + identity_node->AddAttr(kBridgeSourceNodeAttrName, src->name()); + g->AddEdge(src, src_output, identity_node, 0); + g->AddEdge(identity_node, 0, dst, dst_input); + + // Replace `e->dst()` because its input node changed. + NodeDef new_def = dst->def(); + *new_def.mutable_input(dst_input) = identity_node->name(); + TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); + + // Other edge in `edges` might have `e->dst()` as src or dst + // node. Before removing `e->dst()`, replace those edges with corresponding + // edges for `dst_replace_node`. + for (int j = i + 1; j < edges.size(); j++) { + if (edges[j].dst_node_id == edges[i].dst_node_id) { + edges[j].dst_node_id = dst_replace_node->id(); + } + } + } + return Status::OK(); +} + +// Step 3 for PreprocessForEncapsulation(). See comments of +// PreprocessForEncapsulation() for details. +Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( + Graph* g, const string& xla_computation_attr_name, + const string& outside_compilation_attr_name) { + // Gather edges between outside compilation and host computation. Notice that + // we do not store `Edge*` directly because we remove some nodes while adding + // Identity nodes, and those Edge pointers might be invalidated. 
+ struct EdgeInfo { + int dst_input, dst_node_id; + bool is_host_to_outside_compilation; + }; + std::vector edges; + for (const Edge* e : g->edges()) { + if (e->IsControlEdge()) { + continue; + } + + if (e->src()->attrs().Find(xla_computation_attr_name) == nullptr && + e->dst()->attrs().Find(xla_computation_attr_name) != nullptr && + e->dst()->attrs().Find(outside_compilation_attr_name) != nullptr) { + edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(), + /*is_host_to_outside_compilation=*/true}); + VLOG(4) << "Host -> oc edge: " << e->DebugString(); + } else if (e->dst()->attrs().Find(xla_computation_attr_name) == nullptr && + e->src()->attrs().Find(xla_computation_attr_name) != nullptr && + e->src()->attrs().Find(outside_compilation_attr_name) != + nullptr) { + edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id(), + /*is_host_to_outside_compilation=*/false}); + VLOG(4) << "Oc -> host edge: " << e->DebugString(); + } + } + + // Remove the edge from host to outside compilation. Add a placeholder as + // outside compilation node input. + std::map, Node*> placeholders; + for (int i = 0; i < edges.size(); i++) { + Node* dst = g->FindNodeId(edges[i].dst_node_id); + const Edge* e; + TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); + Node* src = e->src(); + int src_output = e->src_output(), dst_input = e->dst_input(); + g->RemoveEdge(e); + + // Find or create placeholder node. + string new_name = + edges[i].is_host_to_outside_compilation + ? 
absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output) + : absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output); + auto placeholder_index = std::make_pair(src->name(), src_output); + auto iter = placeholders.find(placeholder_index); + Node* placeholder_node; + if (iter == placeholders.end()) { + NodeDefBuilder placeholder_builder(new_name, "Placeholder"); + placeholder_builder.Attr("dtype", src->output_type(src_output)); + if (edges[i].is_host_to_outside_compilation) { + placeholder_builder.Attr(kHostToOutsideCompilationOriginalNodeAttrName, + src->name()); + placeholder_builder.Attr(kHostToOutsideCompilationSrcOutputAttrName, + src_output); + // If this placeholder node is in outside compilation, we need to set + // `xla_computation_attr_name` and `outside_compilation_attr_name`. + string xla_computation_attr, outside_compilation_attr; + TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), xla_computation_attr_name, + &xla_computation_attr)); + TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), + outside_compilation_attr_name, + &outside_compilation_attr)); + placeholder_builder.Attr(xla_computation_attr_name, + xla_computation_attr); + placeholder_builder.Attr(outside_compilation_attr_name, + outside_compilation_attr); + } else { + placeholder_builder.Attr(kOutsideCompilationToHostOriginalNodeAttrName, + src->name()); + placeholder_builder.Attr(kOutsideCompilationToHostSrcOutputAttrName, + src_output); + } + NodeDef placeholder_def; + TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def)); + Status s; + placeholder_node = g->AddNode(placeholder_def, &s); + TF_RETURN_IF_ERROR(s); + placeholders[placeholder_index] = placeholder_node; + } else { + placeholder_node = iter->second; + } + g->AddEdge(placeholder_node, 0, dst, dst_input); + + // Replace `e->dst()` because its input node changed. 
+ NodeDef new_def = dst->def(); + *new_def.mutable_input(dst_input) = placeholder_node->name(); + TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); + + // Other edge in `edges` might have `e->dst()` as src or dst + // node. Before removing `e->dst()`, replace those edges with corresponding + // edges for `dst_replace_node`. + for (int j = i + 1; j < edges.size(); j++) { + if (edges[j].dst_node_id == edges[i].dst_node_id) { + edges[j].dst_node_id = dst_replace_node->id(); + } + } + } + return Status::OK(); +} + +// Step 1 for `PostprocessForEncapsulation`. See comments of +// `PostprocessForEncapsulation` for details. +Status RemovePlaceholderBetweenOutsideCompilationAndHostComputation(Graph* g) { + // Gather all outside compilation to host computation nodes. + struct PlaceHolderNodeInfo { + Node* n; + bool is_host_to_oc; + }; + std::vector placeholder_nodes; + for (Node* n : g->nodes()) { + if (n->type_string() == "Placeholder") { + if (HasNodeAttr(n->def(), + kOutsideCompilationToHostOriginalNodeAttrName)) { + placeholder_nodes.push_back({n, false}); + } else if (HasNodeAttr(n->def(), + kHostToOutsideCompilationOriginalNodeAttrName)) { + placeholder_nodes.push_back({n, true}); + } + } + } + + // Remove the placeholder nodes, and reconnect original edge. 
+ auto node_name_index = g->BuildNodeNameIndex(); + for (auto placeholder_iter : placeholder_nodes) { + Node* n = placeholder_iter.n; + + string node_name; + int node_src_output; + if (placeholder_iter.is_host_to_oc) { + TF_RETURN_IF_ERROR( + GetNodeAttr(n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, + &node_name)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), + kHostToOutsideCompilationSrcOutputAttrName, + &node_src_output)); + } else { + TF_RETURN_IF_ERROR( + GetNodeAttr(n->attrs(), kOutsideCompilationToHostOriginalNodeAttrName, + &node_name)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), + kOutsideCompilationToHostSrcOutputAttrName, + &node_src_output)); + } + auto iter = node_name_index.find(node_name); + if (iter == node_name_index.end()) { + return errors::Internal( + "Cannot find original node for oc -> host placeholder node ", + node_name); + } + + // Change all usage node to use the original node instead. + Node* original_node = iter->second; + std::vector control_edges; + std::vector data_edges; + for (auto e : n->out_edges()) { + if (e->IsControlEdge()) { + control_edges.push_back(e); + } else { + data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); + } + } + for (const Edge* e : control_edges) { + g->AddControlEdge(original_node, e->dst()); + g->RemoveEdge(e); + } + for (int i = 0; i < data_edges.size(); i++) { + Node* dst = data_edges[i].dst; + NodeDef new_def = dst->def(); + int dst_input = data_edges[i].dst_input; + *new_def.mutable_input(dst_input) = + absl::StrCat(original_node->name(), ":", node_src_output); + TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); + + const Edge* edge_to_replace = nullptr; + TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); + g->RemoveEdge(edge_to_replace); + g->AddEdge(original_node, node_src_output, replace_node, dst_input); + + // Other edges might have `dst` as dst node. Update those edges with + // `replace_node`. 
+ for (int j = i + 1; j < data_edges.size(); j++) { + if (data_edges[j].dst == dst) { + data_edges[j].dst = replace_node; + } + } + + // Other placeholder node might have `dst` as original node. Update + // `node_name_index` with `replace_node`. + node_name_index[replace_node->name()] = replace_node; + } + + // Remove placeholder node. + g->RemoveNode(n); + } + return Status::OK(); +} + +// Step 2 for `PostprocessForEncapsulation`. See comments of +// `PostprocessForEncapsulation` for details. +Status RemoveIdentityBetweenDifferentXlaComputation(Graph* g) { + // Gather Identity nodes to remove. + std::vector bridge_nodes; + for (Node* n : g->nodes()) { + if (n->type_string() == "Identity" && + HasNodeAttr(n->def(), kBridgeSourceNodeAttrName)) { + bridge_nodes.push_back(n); + } + } + + // Remove the identity nodes, and reconnect the original edge. + for (int i = 0; i < bridge_nodes.size(); i++) { + Node* n = bridge_nodes[i]; + const Edge* src_edge = nullptr; + TF_RETURN_IF_ERROR(n->input_edge(0, &src_edge)); + + // Change all usage node to use the original node instead. 
+ std::vector control_edges; + std::vector data_edges; + for (auto e : n->out_edges()) { + if (e->IsControlEdge()) { + control_edges.push_back(e); + } else { + data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); + } + } + for (const Edge* e : control_edges) { + g->AddControlEdge(src_edge->src(), e->dst()); + g->RemoveEdge(e); + } + for (int j = 0; j < data_edges.size(); j++) { + Node* dst = data_edges[j].dst; + NodeDef new_def = dst->def(); + int dst_input = data_edges[j].dst_input; + *new_def.mutable_input(dst_input) = + absl::StrCat(src_edge->src()->name(), ":", src_edge->src_output()); + TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); + + const Edge* edge_to_replace = nullptr; + TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); + g->RemoveEdge(edge_to_replace); + g->AddEdge(src_edge->src(), src_edge->src_output(), replace_node, + dst_input); + + // Other edges might have `dst` as dst node. Update those edges with + // `replace_node`. + for (int k = j + 1; k < data_edges.size(); k++) { + if (data_edges[k].dst == dst) { + data_edges[k].dst = replace_node; + } + } + + // The node we replaced might be in `bridge_nodes`. If so, update + // `bridge_nodes` to use the replaced node. + for (int k = i + 1; k < bridge_nodes.size(); k++) { + if (bridge_nodes[k] == dst) { + bridge_nodes[k] = replace_node; + } + } + } + + // Remove Identity node. + g->RemoveNode(n); + } + return Status::OK(); +} + +// Step 3 for `PostprocessForEncapsulation`. See comments of +// `PostprocessForEncapsulation` for details. +// We do not need to worry about removed nodes in step 1 and 2; +// `PreprocessForEncapsulation` will not record control dependencies for those +// remvoed nodes in the first place. +Status AddControlDependencies( + Graph* g, const std::unordered_map& cluster_node_names) { + auto node_name_index = g->BuildNodeNameIndex(); + + // Reconnect outside compilation to outside compilation control edge. 
+ for (Node* n : g->nodes()) { + std::vector control_deps; + Status s = + GetNodeAttr(n->attrs(), kXlaControlDependenciesAttrName, &control_deps); + if (!s.ok()) { + if (s.code() != error::NOT_FOUND) { + return s; + } else { + continue; + } + } else { + n->ClearAttr(kXlaControlDependenciesAttrName); + for (const string& control_input : control_deps) { + auto iter = node_name_index.find(control_input); + if (iter == node_name_index.end()) { + return errors::Internal("Cannot find original node for ", + control_input); + } + g->AddControlEdge(iter->second, n); + } + } + } + + // Reconnect outside compilation to XLA computation control edge. + for (Node* n : g->nodes()) { + std::vector control_deps; + Status s = GetNodeAttr( + n->attrs(), kXlaConnectedToOtherXlaComputationAttrName, &control_deps); + if (!s.ok()) { + if (s.code() != error::NOT_FOUND) { + return s; + } else { + continue; + } + } else { + n->ClearAttr(kXlaConnectedToOtherXlaComputationAttrName); + for (const string& control_input : control_deps) { + auto iter = cluster_node_names.find(control_input); + if (iter == cluster_node_names.end()) { + return errors::Internal("Cannot find cluster node for ", + control_input); + } + auto iter2 = node_name_index.find(iter->second); + if (iter2 == node_name_index.end()) { + return errors::Internal("Cannot find cluster node for ", + iter->second); + } + g->AddControlEdge(n, iter2->second); + } + } + } + + // Reconnect XLA computation to outside compilation control edge. 
+ for (Node* n : g->nodes()) { + std::vector control_deps; + Status s = + GetNodeAttr(n->attrs(), kXlaConnectedFromOtherXlaComputationAttrName, + &control_deps); + if (!s.ok()) { + if (s.code() != error::NOT_FOUND) { + return s; + } else { + continue; + } + } else { + n->ClearAttr(kXlaConnectedFromOtherXlaComputationAttrName); + for (const string& control_input : control_deps) { + auto iter = cluster_node_names.find(control_input); + if (iter == cluster_node_names.end()) { + return errors::Internal("Cannot find cluster node for ", + control_input); + } + auto iter2 = node_name_index.find(iter->second); + if (iter2 == node_name_index.end()) { + return errors::Internal("Cannot find cluster node for ", + iter->second); + } + g->AddControlEdge(iter2->second, n); + } + } + } + + return Status::OK(); +} + +// Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of +// `PreprocessEdgesBetweenOutsideCompilations` for details. +Status PreprocessControlEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Gather edges to remove. We should not remove the edge while iterating. + std::vector edges_to_remove; + for (const Edge* e : g->edges()) { + if (!e->IsControlEdge()) { + continue; + } + + auto src_outside_compilation = + GetStringAttr(*e->src(), outside_compilation_attr_name); + auto dst_outside_compilation = + GetStringAttr(*e->dst(), outside_compilation_attr_name); + + if (src_outside_compilation && dst_outside_compilation) { + if (*src_outside_compilation != *dst_outside_compilation) { + // Case 1a: outside compilation to outside compilation control edge. + edges_to_remove.push_back(e); + + TF_RETURN_IF_ERROR(AppendToListAttr( + e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName, + e->src()->name())); + } + } else if (src_outside_compilation && !dst_outside_compilation) { + // Case 1b: outside compilation to its XLA computation control edge. 
+ ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true); + } else if (!src_outside_compilation && dst_outside_compilation) { + // Case 1b: XLA computation to outside compilation in it control edge. + ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true); + } + } + + for (auto e : edges_to_remove) { + g->RemoveEdge(e); + } + return Status::OK(); +} + +// Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of +// `PreprocessEdgesBetweenOutsideCompilations` for details. +Status PreprocessDataEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Gather edges between outside compilation and host computation. Notice that + // we do not store `Edge*` directly because we remove some nodes while adding + // Identity nodes, and those Edge pointers might be invalidated. + struct EdgeInfo { + int dst_input, dst_node_id; + }; + std::vector edges; + for (const Edge* e : g->edges()) { + if (e->IsControlEdge()) { + continue; + } + + auto src_outside_compilation = + GetStringAttr(*e->src(), outside_compilation_attr_name); + auto dst_outside_compilation = + GetStringAttr(*e->dst(), outside_compilation_attr_name); + + if (src_outside_compilation && dst_outside_compilation && + *src_outside_compilation != *dst_outside_compilation) { + edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); + VLOG(4) << "Oc -> oc edge: " << e->DebugString(); + } + } + + // Remove the edge from host to outside compilation. Add a placeholder as + // outside compilation node input. + std::map, Node*> placeholders; + for (int i = 0; i < edges.size(); i++) { + Node* dst = g->FindNodeId(edges[i].dst_node_id); + const Edge* e; + TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); + Node* src = e->src(); + int src_output = e->src_output(), dst_input = e->dst_input(); + g->RemoveEdge(e); + + // Find or create placeholder node. 
+ string new_name = + absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output); + auto placeholder_index = std::make_pair(src->name(), src_output); + auto iter = placeholders.find(placeholder_index); + Node* placeholder_node; + if (iter == placeholders.end()) { + NodeDefBuilder placeholder_builder(new_name, "Placeholder"); + placeholder_builder.Attr("dtype", src->output_type(src_output)); + string outside_compilation_attr; + TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), + outside_compilation_attr_name, + &outside_compilation_attr)); + placeholder_builder.Attr(outside_compilation_attr_name, + outside_compilation_attr); + placeholder_builder.Attr(kOutsideCompilationOriginalNodeAttrName, + src->name()); + placeholder_builder.Attr(kOutsideCompilationSrcOutputAttrName, + src_output); + NodeDef placeholder_def; + TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def)); + Status s; + placeholder_node = g->AddNode(placeholder_def, &s); + TF_RETURN_IF_ERROR(s); + placeholders[placeholder_index] = placeholder_node; + } else { + placeholder_node = iter->second; + } + g->AddEdge(placeholder_node, 0, dst, dst_input); + + // Replace `e->dst()` because its input node changed. + NodeDef new_def = dst->def(); + *new_def.mutable_input(dst_input) = placeholder_node->name(); + TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def)); + + // Other edge in `edges` might have `e->dst()` as src or dst + // node. Before removing `e->dst()`, replace those edges with + // corresponding edges for `dst_replace_node`. + for (int j = i + 1; j < edges.size(); j++) { + if (edges[j].dst_node_id == edges[i].dst_node_id) { + edges[j].dst_node_id = dst_replace_node->id(); + } + } + } + return Status::OK(); +} + +// Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of +// `PostprocessEdgesBetweenOutsideCompilations` for details. 
+Status PostprocessDataEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Gather all outside compilation to outside compilation nodes. + std::vector placeholder_nodes; + for (Node* n : g->nodes()) { + if (n->type_string() == "Placeholder" && + HasNodeAttr(n->def(), kOutsideCompilationOriginalNodeAttrName)) { + placeholder_nodes.push_back(n); + } + } + + // Remove the placeholder nodes, and reconnect original edge. + auto node_name_index = g->BuildNodeNameIndex(); + for (auto n : placeholder_nodes) { + string node_name; + int node_src_output; + TF_RETURN_IF_ERROR(GetNodeAttr( + n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name)); + TF_RETURN_IF_ERROR(GetNodeAttr( + n->attrs(), kOutsideCompilationSrcOutputAttrName, &node_src_output)); + auto iter = node_name_index.find(node_name); + if (iter == node_name_index.end()) { + return errors::Internal( + "Cannot find original node for oc -> host placeholder node ", + node_name); + } + + // Change all usage node to use the original node instead. 
+ Node* original_node = iter->second; + std::vector control_edges; + std::vector data_edges; + for (auto e : n->out_edges()) { + if (e->IsControlEdge()) { + control_edges.push_back(e); + } else { + data_edges.push_back({e->dst(), e->src_output(), e->dst_input()}); + } + } + for (const Edge* e : control_edges) { + g->AddControlEdge(original_node, e->dst()); + g->RemoveEdge(e); + } + for (int i = 0; i < data_edges.size(); i++) { + Node* dst = data_edges[i].dst; + NodeDef new_def = dst->def(); + int dst_input = data_edges[i].dst_input; + *new_def.mutable_input(dst_input) = + absl::StrCat(original_node->name(), ":", node_src_output); + TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def)); + + const Edge* edge_to_replace = nullptr; + TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace)); + g->RemoveEdge(edge_to_replace); + g->AddEdge(original_node, node_src_output, replace_node, dst_input); + + // Other edges might have `dst` as dst node. Update those edges with + // `replace_node`. + for (int j = i + 1; j < data_edges.size(); j++) { + if (data_edges[j].dst == dst) { + data_edges[j].dst = replace_node; + } + } + + // Other placeholder node might have `dst` as original node. Update + // `node_name_index` with `replace_node`. + node_name_index[replace_node->name()] = replace_node; + } + + // Remove placeholder node. + g->RemoveNode(n); + } + return Status::OK(); +} + +// Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of +// `PostprocessEdgesBetweenOutsideCompilations` for details. +Status PostprocessControlEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + auto node_name_index = g->BuildNodeNameIndex(); + + // Reconnect outside compilation to outside compilation control edge. 
+ for (Node* n : g->nodes()) { + std::vector control_deps; + Status s = + GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName, + &control_deps); + if (!s.ok()) { + if (s.code() != error::NOT_FOUND) { + return s; + } else { + continue; + } + } else { + n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName); + for (const string& control_input : control_deps) { + auto iter = node_name_index.find(control_input); + if (iter == node_name_index.end()) { + return errors::Internal("Cannot find original node for ", + control_input); + } + g->AddControlEdge(iter->second, n); + } + } + } + return Status::OK(); +} +} // namespace + +const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes"; + +const char kXlaConnectedToOtherXlaComputationAttrName[] = + "_xla_connected_to_other_xla_computation"; +const char kXlaConnectedFromOtherXlaComputationAttrName[] = + "_xla_connected_from_other_xla_computation"; +const char kXlaControlDependenciesAttrName[] = "_xla_control_dependencies"; +const char kBridgeSourceNodeAttrName[] = "_xla_bridge_src"; +const char kOutsideCompilationToHostOriginalNodeAttrName[] = + "_xla_oc_to_host_node_name"; +const char kOutsideCompilationToHostSrcOutputAttrName[] = + "_xla_oc_to_host_src_output"; +const char kHostToOutsideCompilationOriginalNodeAttrName[] = + "_xla_host_to_oc_node_name"; +const char kHostToOutsideCompilationSrcOutputAttrName[] = + "_xla_host_to_oc_src_output"; +const char kXlaConnectedToXlaComputationAttrName[] = + "_xla_connected_to_xla_computation"; +const char kXlaConnectedFromXlaComputationAttrName[] = + "_xla_connected_from_xla_computation"; +const char kOutsideCompilationOriginalNodeAttrName[] = + "_xla_oc_to_oc_node_name"; +const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output"; +const char kXlaControlDependenciesWithinXlaClusterAttrName[] = + "_xla_control_dependencies_within_xla_cluster"; + +Status PerformStaticShapeInferenceBeforeEncapsulation( + Graph* g, const string& 
xla_computation_attr_name, + const string& outside_compilation_attr_name) { + // Find all outside compilation to XLA computation data edges. + std::unordered_set outside_compilation_send_nodes; + for (auto e : g->edges()) { + if (e->IsControlEdge()) { + continue; + } + + auto src_computation = GetStringAttr(*e->src(), xla_computation_attr_name); + auto dst_computation = GetStringAttr(*e->dst(), xla_computation_attr_name); + if (!src_computation || !dst_computation || + *src_computation != *dst_computation) { + continue; + } + + auto src_outside_compilation = + GetStringAttr(*e->src(), outside_compilation_attr_name); + auto dst_outside_compilation = + GetStringAttr(*e->dst(), outside_compilation_attr_name); + if (src_outside_compilation && !dst_outside_compilation) { + outside_compilation_send_nodes.insert(e->src()); + } + } + + // Perform shape inference. + std::map arg_shapes; + GraphShapeInfo shape_info; + TF_RETURN_IF_ERROR( + InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info)); + + // Add attribute for output shapes. 
+ for (Node* n : outside_compilation_send_nodes) { + auto iter = shape_info.find(n->name()); + if (iter == shape_info.end()) { + continue; + } + + std::vector output_shapes; + std::transform(iter->second.begin(), iter->second.end(), + std::back_inserter(output_shapes), + [](const InferredShape& inferred_shape) { + return inferred_shape.shape; + }); + n->AddAttr(kXlaInferredShapesAttrName, output_shapes); + } + + return Status::OK(); +} + +Status PreprocessForEncapsulation(Graph* g, + const string& xla_computation_attr_name, + const string& outside_compilation_attr_name) { + TF_RETURN_IF_ERROR(ProcessControlEdges(g, xla_computation_attr_name, + outside_compilation_attr_name)); + TF_RETURN_IF_ERROR(ProcessXlaToXlaDataEdges(g, xla_computation_attr_name, + outside_compilation_attr_name)); + TF_RETURN_IF_ERROR(ProcessDataEdgeBetweenOutsideCompilationAndHostComputation( + g, xla_computation_attr_name, outside_compilation_attr_name)); + return Status::OK(); +} + +Status PostprocessForEncapsulation( + Graph* g, const string& xla_computation_attr_name, + const string& outside_compilation_attr_name, + const std::unordered_map& clusters) { + // The `node` pointer in `XlaClusterInfo` might be invalidated in step 1/2, + // but the node name won't change. Record cluster node name for + // `AddControlDependencies`. + std::unordered_map cluster_node_names; + for (const auto& iter : clusters) { + cluster_node_names[iter.first] = iter.second.node->name(); + } + + TF_RETURN_IF_ERROR( + RemovePlaceholderBetweenOutsideCompilationAndHostComputation(g)); + TF_RETURN_IF_ERROR(RemoveIdentityBetweenDifferentXlaComputation(g)); + TF_RETURN_IF_ERROR(AddControlDependencies(g, cluster_node_names)); + return Status::OK(); +} + +Status PreprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + // Remove edges from source node to outside compilation nodes, and edges + // from outside compilation nodes to sink node. 
+ std::vector edges_to_remove; + for (const Edge* e : g->source_node()->out_edges()) { + if (HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) { + edges_to_remove.push_back(e); + } + } + for (const Edge* e : g->sink_node()->in_edges()) { + if (HasNodeAttr(e->src()->def(), outside_compilation_attr_name)) { + edges_to_remove.push_back(e); + } + } + for (auto e : edges_to_remove) { + g->RemoveEdge(e); + } + + TF_RETURN_IF_ERROR(PreprocessControlEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + TF_RETURN_IF_ERROR(PreprocessDataEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + return Status::OK(); +} + +Status PostprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name) { + TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations( + g, outside_compilation_attr_name)); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h new file mode 100644 index 0000000000000000000000000000000000000000..e363bc5754ac395bae262dc67a780a0173efaf5e --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -0,0 +1,210 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This file contains some utility functions for encapsulating XLA computation +// in host graph and encapsulating outside compilation in XLA computation. + +#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Attribute marking output tensor shapes inferred by XLA. Attribute value is +// a list of PartialTensorShape objects. +extern const char kXlaInferredShapesAttrName[]; + +// Infer output shapes for outside compilation nodes which have output data +// edges to XLA computation nodes. These shapes will be used later by XLA +// compiler as output shapes of the outside compilation's XlaHostCompute op. +// XLA computation nodes will be marked by attr `xla_computation_attr_name`; +// outside compilation nodes will be marked by both attr +// `xla_computation_attr_name` and `outside_compilation_attr_name`. +// +// Those outside compilation nodes will be marked with attribute +// `kXlaInferredShapesAttrName`. +// +// We have to perform shape inference before encapsulation because after +// encapsulation, some nodes will be encapsulated into function calls, and shape +// inference does not handle function calls at the moment. +Status PerformStaticShapeInferenceBeforeEncapsulation( + Graph* g, const string& xla_computation_attr_name, + const string& outside_compilation_attr_name); + +// Attribute indicating that some ops in other XLA computations have a control +// dependency on this node. Attribute value will be a list of strings (XLA +// computation names). +extern const char kXlaConnectedToOtherXlaComputationAttrName[]; + +// Attribute indicating that this node has a control dependency on some ops in +// other XLA computations. Attribute value will be a list of strings (XLA +// computation names).
+extern const char kXlaConnectedFromOtherXlaComputationAttrName[]; + +// Attribute indicating that this node has control dependencies on some other +// nodes. Attribute value will be a list of strings (node names). +extern const char kXlaControlDependenciesAttrName[]; + +// Attribute indicating that this is an Identity node added to act as a bridge +// between different XLA computations. Attribute value will be a string (source +// node name). +extern const char kBridgeSourceNodeAttrName[]; + +// Attribute indicating that this is a Placeholder node added to act as a +// temporary input node for a host node, standing in for an outside compilation +// node. Attribute value will be a string (original input node name). +extern const char kOutsideCompilationToHostOriginalNodeAttrName[]; + +// Attribute indicating that this is a Placeholder node added to act as a +// temporary input node for a host node, standing in for an outside compilation +// node. Attribute value will be an int (src_output for original edge). +extern const char kOutsideCompilationToHostSrcOutputAttrName[]; + +// Attribute indicating that some ops in this node's XLA computation have a +// control dependency on this node. Attribute value will always be "true". +extern const char kXlaConnectedToXlaComputationAttrName[]; + +// Attribute indicating that this node has a control dependency on some ops in +// this node's XLA computation. Attribute value will always be "true". +extern const char kXlaConnectedFromXlaComputationAttrName[]; + +// Attribute indicating that this is a Placeholder node added to act as a +// temporary input node for an outside compilation node, standing in for a host +// node. Attribute value will be a string (original input node name). +extern const char kHostToOutsideCompilationOriginalNodeAttrName[]; + +// Attribute indicating that this is a Placeholder node added to act as a +// temporary input node for an outside compilation node, standing in for a host +// node. Attribute value will be an int (src_output for original edge).
+extern const char kHostToOutsideCompilationSrcOutputAttrName[]; + +// Attribute indicating that this is a Placeholder node added to act as a +// temporary input node for an outside compilation node, standing in for a node +// in a different outside compilation. Attribute value will be a string +// (original input node name). +extern const char kOutsideCompilationOriginalNodeAttrName[]; + +// Attribute indicating that this is a Placeholder node added to act as a +// temporary input node for an outside compilation node, standing in for a node +// in a different outside compilation. Attribute value will be an int +// (src_output for original edge). +extern const char kOutsideCompilationSrcOutputAttrName[]; + +// Attribute indicating that this node has control dependencies on some other +// nodes within the same XLA cluster. Attribute value will be a list of strings +// (node names). +extern const char kXlaControlDependenciesWithinXlaClusterAttrName[]; + +// Preprocesses edges between different XLA clusters for encapsulation. It will +// perform the following operations in order: +// +// 1a. For control edges between outside compilation and another XLA +// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName +// = XLA computation node name" to the outside compilation node. +// 1b. For control edges between different outside compilations (in different +// XLA computations), remove the edge and add attr +// "kXlaControlDependenciesAttrName = src node name" to dst node. +// 1c. For control edges between outside compilation and host computation, +// remove the edge and add attr "kXlaControlDependenciesAttrName = src node +// name" to dst node. +// 2. For data edges between different XLA computations, if either src or dst +// is outside compilation, add an Identity node in between the edge. The +// identity node will have attr kBridgeSourceNodeAttrName. +// 3. For data edges between outside compilation and host computation, remove +// the edge and create a Placeholder node as dst node's input.
+Status PreprocessForEncapsulation(Graph* g, + const string& xla_computation_attr_name, + const string& outside_compilation_attr_name); + +// Information for XLA computation. +struct XlaClusterInfo { + // Add an explicitly-defined default constructor for this class. + // + // The compiler may delete the default constructor here because + // host_compute_core is a const member whose type (std::map) doesn't + // necessarily have a user provided constructor -- while libc++ and + // libstdc++ 4.8 provide a user defined default constructor, libstdc++ at + // least >= 7.3 does not. See also c++11 [class.ctor] p5. + // + // TODO(klimek): In c++17 we'll be able to initialize host_compute_core + // without losing aggregate initialization, which allows us to get rid of + // the constructor definitions again. + XlaClusterInfo() {} + XlaClusterInfo(const string& cluster_name, + const NameAttrList& func_name_attrs, Node* node, + const std::map& host_compute_core) + : cluster_name(cluster_name), + func_name_attrs(func_name_attrs), + node(node), + host_compute_core(host_compute_core) {} + // XLA cluster name. It might be different from `func_name`. + const string cluster_name; + // Name and attributes of XLA computation function. + const NameAttrList func_name_attrs; + // The XLA computation node in the graph. + Node* node; + // A mapping from outside compilation cluster name to its device assignment. + const std::map host_compute_core; +}; + +// Postprocesses edges between different XLA clusters for encapsulation. This +// function reverts what `PreprocessForEncapsulation` did. It will perform the +// following operations in order: +// +// 1. Remove Placeholder nodes between outside compilation and host computation +// (created in `PreprocessForEncapsulation` step 3). +// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2. +// 3a. 
Reconnect control edges between outside compilation and another XLA +// computation (marked by `PreprocessForEncapsulation` step 1a). +// 3b. Reconnect control edges between different outside compilations (marked by +// `PreprocessForEncapsulation` step 1b). +// 3c. Reconnect control edges between outside compilation and host computation +// (marked by `PreprocessForEncapsulation` step 1c). +Status PostprocessForEncapsulation( + Graph* g, const string& xla_computation_attr_name, + const string& outside_compilation_attr_name, + const std::unordered_map& clusters); + +// Preprocesses edges within the same XLA cluster. It will perform the following +// operations in order: +// +// 0. Remove edges from source node to outside compilation nodes, and edges +// from outside compilation nodes to sink node. +// 1a. For edges between different outside compilation clusters, remove the edge +// and add attr "kXlaControlDependenciesWithinXlaClusterAttrName = src node +// name" to dst node. +// 1b. For control edges between outside compilation and its XLA computation, +// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the +// outside compilation node. +// 2. For data edges between different outside compilations, remove the edge +// and create a Placeholder node as dst node's input. +Status PreprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name); + +// Postprocesses edges within the same XLA cluster. This function reverts what +// `PreprocessEdgesBetweenOutsideCompilations` did. It will perform the +// following operations in order: +// +// 1. Remove Placeholder nodes between different outside compilations (created +// in `PreprocessEdgesBetweenOutsideCompilations` step 2). +// 2a. Reconnect control edges between different outside compilations (marked by +// `PreprocessEdgesBetweenOutsideCompilations` step 1a). 
+// Notice that control edges marked by +// `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here. +// They are handled in `RewriteOutsideCompilationSubgraphFn`. +Status PostprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3b8b49cb92f3e453883a8e64e12ce3748a5173f6 --- /dev/null +++ b/tensorflow/compiler/jit/encapsulate_util_test.cc @@ -0,0 +1,394 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/encapsulate_util.h" + +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) { + // Build the graph: + // "add" = "const_0" + "const_1" + // "identity" = "add" + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {2}); + Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {2}); + Output add = ops::Add(s.WithOpName("add"), const_0, const_1); + Output identity = ops::Identity(s.WithOpName("identity"), add); + Graph g(OpRegistry::Global()); + TF_CHECK_OK(s.ToGraph(&g)); + + // "add" node is outside compilation node, "identity" node is XLA node. + auto node_index = g.BuildNodeNameIndex(); + Node *add_node = node_index["add"], *identity_node = node_index["identity"]; + add_node->AddAttr("_xla", "cluster"); + add_node->AddAttr("_oc", "cluster"); + identity_node->AddAttr("_xla", "cluster"); + TF_CHECK_OK( + PerformStaticShapeInferenceBeforeEncapsulation(&g, "_xla", "_oc")); + + // Check that only "add" node now has _xla_inferred_shapes attr. 
+ std::vector nodes_with_inferred_shape; + for (Node *n : g.nodes()) { + if (HasNodeAttr(n->def(), kXlaInferredShapesAttrName)) { + nodes_with_inferred_shape.push_back(n); + } + } + EXPECT_EQ(nodes_with_inferred_shape.size(), 1); + EXPECT_EQ(nodes_with_inferred_shape[0], add_node); + std::vector output_shapes; + TF_CHECK_OK(GetNodeAttr(add_node->attrs(), kXlaInferredShapesAttrName, + &output_shapes)); + EXPECT_EQ(output_shapes.size(), 1); + TensorShapeProto shape_proto; + output_shapes[0].AsProto(&shape_proto); + EXPECT_EQ(shape_proto.dim_size(), 1); + EXPECT_EQ(shape_proto.dim(0).size(), 2); +} + +TEST(PreprocessForEncapsulationTest, ControlEdges) { + // Build the graph: + // "const_0" and "const_1" in host computation + // "add" = "const_0" + "const_1" in XLA computation 0 + // "identity0" = "add" in XLA computation 0 & outside compilation 0 + // "identity1" = "identity0" in XLA computation 0 + // "identity2" = "identity1" in host computation + // "identity3" = "identity2" in XLA computation 1 + // "identity4" = "identity3" in XLA computation 1 & outside compilation 1 + // "identity5" = "identity4" in XLA computation 1 + // "identity6" = "identity5" in host computation + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {}); + Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); + Output add = ops::Add(s.WithOpName("add"), const_0, const_1); + Output identity0 = ops::Identity(s.WithOpName("identity0"), add); + Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0); + Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1); + Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2); + Output identity4 = ops::Identity(s.WithOpName("identity4"), identity3); + Output identity5 = ops::Identity(s.WithOpName("identity5"), identity4); + Graph g(OpRegistry::Global()); + TF_CHECK_OK(s.ToGraph(&g)); + auto node_index = g.BuildNodeNameIndex(); + + // Set XLA 
computation/outside compilation attr, and add control edges. + Node *const0_node = node_index["const_0"], *add_node = node_index["add"], + *identity0_node = node_index["identity0"], + *identity1_node = node_index["identity1"], + *identity2_node = node_index["identity2"], + *identity3_node = node_index["identity3"], + *identity4_node = node_index["identity4"], + *identity5_node = node_index["identity5"]; + add_node->AddAttr("_xla", "0"); + identity0_node->AddAttr("_xla", "0"); + identity0_node->AddAttr("_oc", "0"); + identity1_node->AddAttr("_xla", "0"); + identity3_node->AddAttr("_xla", "1"); + identity4_node->AddAttr("_xla", "1"); + identity4_node->AddAttr("_oc", "0"); + identity5_node->AddAttr("_xla", "1"); + // Case 1a: control edges between outside compilation and another XLA + // computation. + g.AddControlEdge(identity0_node, identity3_node); + g.AddControlEdge(identity1_node, identity4_node); + // Case 1b: control edges between different outside compilations. + g.AddControlEdge(identity0_node, identity4_node); + // Case 1c: control edges between outside compilation and host computation. + g.AddControlEdge(const0_node, identity0_node); + g.AddControlEdge(identity0_node, identity2_node); + + TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); + + // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name" + // to the outside compilation node. + std::vector attr; + TF_CHECK_OK(GetNodeAttr(identity0_node->def(), + kXlaConnectedToOtherXlaComputationAttrName, &attr)); + EXPECT_EQ(attr.size(), 1); + EXPECT_EQ(attr[0], "1"); + attr.clear(); + TF_CHECK_OK(GetNodeAttr(identity4_node->def(), + kXlaConnectedFromOtherXlaComputationAttrName, &attr)); + EXPECT_EQ(attr.size(), 1); + EXPECT_EQ(attr[0], "0"); + // Case 1b: add attr "_xla_control_deps = src node name" to dst node. 
+ attr.clear(); + TF_CHECK_OK(GetNodeAttr(identity4_node->def(), + kXlaControlDependenciesAttrName, &attr)); + EXPECT_EQ(attr.size(), 1); + EXPECT_EQ(attr[0], "identity0"); + // Case 1c: add attr "_xla_control_deps = src node name" to dst node. + attr.clear(); + TF_CHECK_OK(GetNodeAttr(identity0_node->def(), + kXlaControlDependenciesAttrName, &attr)); + EXPECT_EQ(attr.size(), 1); + EXPECT_EQ(attr[0], "const_0"); + attr.clear(); + TF_CHECK_OK(GetNodeAttr(identity2_node->def(), + kXlaControlDependenciesAttrName, &attr)); + EXPECT_EQ(attr.size(), 1); + EXPECT_EQ(attr[0], "identity0"); +} + +TEST(PreprocessForEncapsulationTest, DataEdges) { + // Build the graph: + // "const_0" and "const_1" in host computation + // "identityn0" = ("const_0", "const_1") in host computation 0 + // "add0" = "const_0" + "const_1" in XLA computation 0 + // "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0 + // "identity0" = "add1" in XLA computation 0 + // "add2" = "add1" + "identity0" in host computation + // "add3" = "add1" + "add2" in XLA computation 1 + // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0 + // "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 & + // outside compilation 0 + // "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 & + // outside compilation 0 + // "identity1" = "add4" in XLA computation 1 + // "identity2" = "identity1" in host computation + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {}); + Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); + auto identityn0 = + ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1}); + Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1); + Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0); + Output identity0 = ops::Identity(s.WithOpName("identity0"), add1); + Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0); + Output add3 
= ops::Add(s.WithOpName("add3"), add1, add2); + Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2); + Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]); + auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"), + {identityn0[0], identityn0[1]}); + Output identity1 = ops::Identity(s.WithOpName("identity1"), add4); + Output identity2 = ops::Identity(s.WithOpName("identity2"), add4); + Graph g(OpRegistry::Global()); + TF_CHECK_OK(s.ToGraph(&g)); + auto node_index = g.BuildNodeNameIndex(); + + // Set XLA computation/outside compilation attr. + Node *add0_node = node_index["add0"], *add1_node = node_index["add1"], + *identity0_node = node_index["identity0"], + *add3_node = node_index["add3"], *add4_node = node_index["add4"], + *add5_node = node_index["add5"], + *identityn1_node = node_index["identityn_1"], + *identity1_node = node_index["identity1"]; + add0_node->AddAttr("_xla", "0"); + add1_node->AddAttr("_xla", "0"); + add1_node->AddAttr("_oc", "0"); + identity0_node->AddAttr("_xla", "0"); + add3_node->AddAttr("_xla", "1"); + add4_node->AddAttr("_xla", "1"); + add4_node->AddAttr("_oc", "0"); + add5_node->AddAttr("_xla", "1"); + add5_node->AddAttr("_oc", "0"); + identityn1_node->AddAttr("_xla", "1"); + identityn1_node->AddAttr("_oc", "0"); + identity1_node->AddAttr("_xla", "1"); + + TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); + + // Check input nodes for related data edges. + node_index = g.BuildNodeNameIndex(); + // Step 2: add an Identity node between different XLA computations. + Node *bridge_add1_add3 = node_index["bridge_add1_add3"]; + EXPECT_NE(bridge_add1_add3, nullptr); + string str; + TF_CHECK_OK( + GetNodeAttr(bridge_add1_add3->attrs(), kBridgeSourceNodeAttrName, &str)); + EXPECT_EQ(str, "add1"); + Node *bridge_identity0_add4 = node_index["bridge_identity0_add4"]; + EXPECT_NE(bridge_identity0_add4, nullptr); + // Step 3: add placeholder for edges between host computation and outside + // compilation. 
+ EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0"); + Node *add1_oc_to_host_placeholder = + node_index["add1_oc_to_host_placeholder_0"]; + TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), + kOutsideCompilationToHostOriginalNodeAttrName, &str)); + EXPECT_EQ(str, "add1"); + int i; + TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), + kOutsideCompilationToHostSrcOutputAttrName, &i)); + EXPECT_EQ(i, 0); + add4_node = node_index["add4"]; + ASSERT_NE(add4_node, nullptr); + EXPECT_EQ(add4_node->def().input(0), + "bridge_identity0_add4_host_to_oc_placeholder_0"); + Node *identity0_host_to_oc_placeholder = + node_index["bridge_identity0_add4_host_to_oc_placeholder_0"]; + TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), + kHostToOutsideCompilationOriginalNodeAttrName, &str)); + EXPECT_EQ(str, "bridge_identity0_add4"); + TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), + kHostToOutsideCompilationSrcOutputAttrName, &i)); + EXPECT_EQ(i, 0); + + // Check different placeholder nodes are created for different src_output. + Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"], + *placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"]; + EXPECT_NE(placeholder0, nullptr); + EXPECT_NE(placeholder1, nullptr); + // Check we only have 2 placeholder nodes created for "identityn_0". 
+ int placeholder_count = 0; + for (Node *n : g.nodes()) { + if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) { + string attr; + TF_CHECK_OK(GetNodeAttr( + n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr)); + if (attr == "identityn_0") { + ++placeholder_count; + } + } + } + EXPECT_EQ(placeholder_count, 2); +} + +TEST(PostprocessForEncapsulationTest, ControlEdges) { + // Build the graph: + // "const0" + // "identity0" = "const0" (XLA computation 0) + // "identity1" = "identity0" + // "identity2" = "identity1" (XLA computation 1) + // "identity3" = "identity2" + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output const0 = ops::Const(s.WithOpName("const0"), 1, {}); + Output identity0 = ops::Identity(s.WithOpName("identity0"), const0); + Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0); + Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1); + Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2); + Graph g(OpRegistry::Global()); + TF_CHECK_OK(s.ToGraph(&g)); + auto node_index = g.BuildNodeNameIndex(); + + // Set XLA computation/outside compilation attr, and add control edges. + Node *const0_node = node_index["const0"], + *identity0_node = node_index["identity0"], + *identity1_node = node_index["identity1"], + *identity2_node = node_index["identity2"], + *identity3_node = node_index["identity3"]; + identity1_node->AddAttr(kXlaConnectedFromOtherXlaComputationAttrName, + std::vector{"0"}); + identity1_node->AddAttr(kXlaConnectedToOtherXlaComputationAttrName, + std::vector{"1"}); + identity3_node->AddAttr(kXlaControlDependenciesAttrName, + std::vector{"const0", "identity1"}); + + std::unordered_map clusters; + clusters["0"].node = identity0_node; + clusters["1"].node = identity2_node; + TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters)); + + // Case 3a: we have control edge identity0 -> identity1, and identity1 -> + // identity2. 
+ bool edge_identity0_identity1 = false, edge_identity1_identity2 = false; + for (const Edge *e : g.edges()) { + if (!e->IsControlEdge()) { + continue; + } + if (e->src() == identity0_node && e->dst() == identity1_node) { + edge_identity0_identity1 = true; + } else if (e->src() == identity1_node && e->dst() == identity2_node) { + edge_identity1_identity2 = true; + } + } + EXPECT_TRUE(edge_identity0_identity1); + EXPECT_TRUE(edge_identity1_identity2); + // Case 3b: we have control edge const0 -> identity3, and identity1 -> + // identity3. + bool edge_const0_identity3 = false, edge_identity1_identity3 = false; + for (const Edge *e : g.edges()) { + if (!e->IsControlEdge()) { + continue; + } + if (e->src() == const0_node && e->dst() == identity3_node) { + edge_const0_identity3 = true; + } else if (e->src() == identity1_node && e->dst() == identity3_node) { + edge_identity1_identity3 = true; + } + } + EXPECT_TRUE(edge_const0_identity3); + EXPECT_TRUE(edge_identity1_identity3); +} + +TEST(PostprocessForEncapsulationTest, DataEdges) { + // Build the graph: + // "const0" in outside compilation "0" + // "placeholder0" (for "const0") in host computation + // "add0" = "placeholder0" + "placeholder0" in host computation + // "placeholder1" (for "add0") in outside compilation 1 + // "add1" = "placeholder1" + "placeholder1" in outside compilation 1 + // + // "bridge" = "placeholder0" in host computation + // "placeholder2" (for "bridge") in outside compilation 1 + // "add2" = "placeholder2" + "placeholder2" in outside compilation 1 + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output const0 = ops::Const(s.WithOpName("const0"), 1, {}); + Output placeholder0 = + ops::Placeholder(s.WithOpName("placeholder0"), DT_INT32); + Output add0 = ops::Add(s.WithOpName("add0"), placeholder0, placeholder0); + Output placeholder1 = + ops::Placeholder(s.WithOpName("placeholder1"), DT_INT32); + Output add1 = ops::Add(s.WithOpName("add1"), placeholder1, placeholder1); + Output bridge 
= ops::Identity(s.WithOpName("bridge"), placeholder0); + Output placeholder2 = + ops::Placeholder(s.WithOpName("placeholder2"), DT_INT32); + Output add2 = ops::Add(s.WithOpName("add2"), placeholder2, placeholder2); + Graph g(OpRegistry::Global()); + TF_CHECK_OK(s.ToGraph(&g)); + auto node_index = g.BuildNodeNameIndex(); + + // Set related attributes. + Node *placeholder0_node = node_index["placeholder0"]; + placeholder0_node->AddAttr(kOutsideCompilationToHostOriginalNodeAttrName, + "const0"); + placeholder0_node->AddAttr(kOutsideCompilationToHostSrcOutputAttrName, 0); + Node *placeholder1_node = node_index["placeholder1"]; + placeholder1_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName, + "add0"); + placeholder1_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0); + Node *bridge_node = node_index["bridge"]; + bridge_node->AddAttr(kBridgeSourceNodeAttrName, "const0"); + Node *placeholder2_node = node_index["placeholder2"]; + placeholder2_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName, + "bridge"); + placeholder2_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0); + + std::unordered_map clusters; + TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters)); + + // Result graph should be: + // "add0" = "const0" + "const0" + // "add1" = "add0" + "add0" + // "add2" = "const0" + "const0" + node_index = g.BuildNodeNameIndex(); + EXPECT_EQ(node_index.size(), 6); + EXPECT_EQ(node_index["add0"]->def().input(0), "const0:0"); + EXPECT_EQ(node_index["add0"]->def().input(1), "const0:0"); + EXPECT_EQ(node_index["add1"]->def().input(0), "add0:0"); + EXPECT_EQ(node_index["add1"]->def().input(1), "add0:0"); + EXPECT_EQ(node_index["add2"]->def().input(0), "const0:0"); + EXPECT_EQ(node_index["add2"]->def().input(1), "const0:0"); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 
2ce6fa73fc448ca83fa392aa909cb385453eb8b6..d334100aa4a915a87fb05d371e0e3379a7ee05f2 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -195,8 +195,11 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, e->dst()->attrs().Find(kXlaClusterAttr) == nullptr && e->dst()->type_string() != kXlaClusterOutput) { return errors::InvalidArgument( - "Undeclared output of XLA computation. A common cause of this error " - "is variable initializers that depend on the XLA computation. Edge: ", + "Undeclared output of XLA computation. Some common causes of this " + "error are: 1) variable initializers that depend on the XLA " + "computation; 2) gradient computations that depend on the XLA " + "computation, which can be mitigated by moving gradient computations " + "inside XLA computation. Offending edge: ", e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":", e->dst_input()); } diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc index 22531a4acea3f130175c7cb2e03fcb7570926094..192e1c7b32467d80cef6ff61a1c7078f8dea9dfb 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -256,7 +256,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) { TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); - std::unordered_map index = BuildNodeIndex(*graph); + std::unordered_map index = graph->BuildNodeNameIndex(); string function = index.at("launch0")->type_string(); // Tests the outer graph is as expected. @@ -291,7 +291,8 @@ TEST(EncapsulateXlaComputations, Encapsulate) { // function. Encapsulation should be deterministic to avoid recompilation. 
TF_ASSERT_OK( EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def)); - std::unordered_map index_copy = BuildNodeIndex(*graph_copy); + std::unordered_map index_copy = + graph_copy->BuildNodeNameIndex(); string function_copy = index_copy.at("launch0")->type_string(); EXPECT_EQ(function, function_copy); } diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..e3c7e2f89be9b37b51a633dabb099969c181013f --- /dev/null +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -0,0 +1,941 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h" + +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/encapsulate_util.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +namespace { + +// Add a key placeholder node to the graph. The key placeholder node will be +// used as input for XlaRecvAtHost/XlaSendFromHost nodes. +xla::StatusOr AddHostComputeKeyPlaceholder( + const string& xla_cluster_name, Graph* g) { + NodeDef key_def; + NodeDefBuilder builder(absl::StrCat(xla_cluster_name, "_key_placeholder"), + "Placeholder"); + builder.Attr("dtype", DT_STRING); + builder.Attr("shape", PartialTensorShape({2})); + builder.Attr("_host_compute_call_node", xla_cluster_name); + Status s = builder.Finalize(&key_def); + if (!s.ok()) return s; + + Node* n = g->AddNode(key_def, &s); + if (!s.ok()) return s; + return n; +} + +// Returns if the node is a XLA computation key placeholder. +bool IsKeyPlaceholderNode(const Node& n) { + return n.type_string() == "Placeholder" && + absl::EndsWith(n.name(), "_key_placeholder"); +} + +// Returns nodes with given type. 
+std::vector<Node*> GatherNodesWithType(const Graph& g, const string& type) {
+  // Scan every node of `g` and keep those whose type_string() equals `type`.
+  std::vector<Node*> result;
+  for (Node* n : g.nodes()) {
+    if (n->type_string() == type) {
+      result.push_back(n);
+    }
+  }
+  return result;
+}
+
+// Gets data types from `arg_nodes` and fills them into `recv_at_host_dtypes`.
+//
+// Each node's "index" attr selects the slot that receives its "T" attr, so
+// `recv_at_host_dtypes` ends up ordered by index. Returns an Internal error
+// if an index is out of range or if any slot is left as DT_INVALID.
+Status GetArgDataTypes(const std::vector<Node*>& arg_nodes,
+                       std::vector<DataType>* recv_at_host_dtypes) {
+  recv_at_host_dtypes->resize(arg_nodes.size(), DT_INVALID);
+  const int num_args = static_cast<int>(recv_at_host_dtypes->size());
+  for (auto* n : arg_nodes) {
+    int index;
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
+    // Validate before indexing: operator[] is unchecked, so a bad "index"
+    // attr would otherwise write outside the vector.
+    if (index < 0 || index >= num_args) {
+      return errors::Internal("Invalid input index: ", index);
+    }
+    DataType dtype;
+    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype));
+    (*recv_at_host_dtypes)[index] = dtype;
+  }
+  // Any slot still DT_INVALID means no node in `arg_nodes` carried that index.
+  for (int i = 0; i < num_args; i++) {
+    if ((*recv_at_host_dtypes)[i] == DT_INVALID) {
+      return errors::Internal("Cannot get datatype for input ", i);
+    }
+  }
+  return Status::OK();
+}
+
+// Builds XlaRecvAtHost node.
+//
+// The node is named "outside_compilation_<oc_cluster_name>_recv", outputs
+// `recv_at_host_dtypes`, uses the key "host_compute_channel_<oc_cluster_name>"
+// and takes output 0 of `key_placeholder` as its only input. Returns the node
+// added to `g`, or the error from building/adding it.
+xla::StatusOr<Node*> BuildRecvAtHostNode(
+    Graph* g, const string& oc_cluster_name,
+    const std::vector<DataType>& recv_at_host_dtypes, Node* key_placeholder) {
+  NodeDefBuilder recv_at_host_builder(
+      absl::StrCat("outside_compilation_", oc_cluster_name, "_recv"),
+      "_XlaRecvAtHost");
+  NodeDef recv_at_host_def;
+  recv_at_host_builder.Attr("Toutputs", recv_at_host_dtypes);
+  // The correct device_ordinal will be inserted during replication in a
+  // subsequent rewrite.
+  recv_at_host_builder.Attr("device_ordinal", 0);
+  recv_at_host_builder.Attr(
+      "key", absl::StrCat("host_compute_channel_", oc_cluster_name));
+  recv_at_host_builder.Input(key_placeholder->name(), 0, DT_STRING);
+  TF_RETURN_IF_ERROR(recv_at_host_builder.Finalize(&recv_at_host_def));
+  Status s;
+  Node* recv_at_host_node = g->AddNode(recv_at_host_def, &s);
+  TF_RETURN_IF_ERROR(s);
+  return recv_at_host_node;
+}
+
+// Builds XlaRecvAtHost node, and replaces all _Arg nodes with it.
+xla::StatusOr ReplaceArgNodesWithRecvAtHostNode( + Graph* g, const string& oc_cluster_name, + std::vector* recv_at_host_dtypes, Node* key_placeholder) { + // TODO(b/77601805): use out nodes for source node, instead of traversing all + // nodes. + std::vector arg_nodes = GatherNodesWithType(*g, "_Arg"); + TF_RETURN_IF_ERROR(GetArgDataTypes(arg_nodes, recv_at_host_dtypes)); + TF_ASSIGN_OR_RETURN( + Node * recv_at_host_node, + BuildRecvAtHostNode(g, oc_cluster_name, *recv_at_host_dtypes, + key_placeholder)); + for (auto* n : arg_nodes) { + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + // Record out edges and remove `n` before adding those edges to RecvAtHost. + // This is to avoid multiple producers. + std::vector out_edge_info; + for (auto edge : n->out_edges()) { + out_edge_info.push_back( + {edge->dst(), edge->src_output(), edge->dst_input()}); + } + g->RemoveNode(n); + for (const OutEdgeInfo& edge : out_edge_info) { + if (edge.dst_input == Graph::kControlSlot) { + g->AddControlEdge(recv_at_host_node, edge.dst); + } else { + g->AddEdge(recv_at_host_node, index, edge.dst, edge.dst_input); + } + } + + // Rewrite dst nodes because their input changed. + for (int i = 0; i < out_edge_info.size(); i++) { + const OutEdgeInfo edge = out_edge_info[i]; + if (edge.dst_input == Graph::kControlSlot) { + continue; + } + + Node* dst = edge.dst; + NodeDef new_def = dst->def(); + *new_def.mutable_input(edge.dst_input) = + absl::StrCat(recv_at_host_node->name(), ":", index); + TF_ASSIGN_OR_RETURN(Node * dst_replace, ReplaceNode(g, dst, new_def)); + + // Other edges might have `dst` as dst node as well. Update those edges + // with `dst_replace`. 
+ for (int j = i + 1; j < out_edge_info.size(); j++) { + if (out_edge_info[j].dst == dst) { + out_edge_info[j].dst = dst_replace; + } + } + } + } + g->AddEdge(key_placeholder, 0, recv_at_host_node, 0); + return recv_at_host_node; +} + +// Gets data types from `ret_nodes` and fills them into `send_from_host_dtypes`. +Status GetRetDataTypes(const std::vector& ret_nodes, + std::vector* send_from_host_dtypes) { + send_from_host_dtypes->resize(ret_nodes.size(), DT_INVALID); + for (auto* n : ret_nodes) { + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype)); + (*send_from_host_dtypes)[index] = dtype; + } + for (int i = 0; i < send_from_host_dtypes->size(); i++) { + if ((*send_from_host_dtypes)[i] == DT_INVALID) { + return errors::Internal("Cannot get datatype for output ", i); + } + } + return Status::OK(); +} + +// Builds XlaSendFromHost node. +xla::StatusOr BuildSendFromHostNode( + Graph* g, const string& oc_cluster_name, + const std::vector& ret_nodes, + const std::vector& send_from_host_dtypes, Node* key_placeholder) { + NodeDefBuilder send_from_host_builder( + absl::StrCat("outside_compilation_", oc_cluster_name, "_send"), + "_XlaSendFromHost"); + NodeDef send_from_host_def; + send_from_host_builder.Attr("Tinputs", send_from_host_dtypes); + // The correct device_ordinal will be inserted during replication in a + // subsequent rewrite. 
+ send_from_host_builder.Attr("device_ordinal", 0); + send_from_host_builder.Attr( + "key", absl::StrCat("host_compute_channel_", oc_cluster_name)); + std::vector inputs(send_from_host_dtypes.size()); + for (auto* n : ret_nodes) { + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + if (index < 0 || index >= send_from_host_dtypes.size()) { + return errors::Internal("Invalid _Retval index: ", index); + } + for (auto edge : n->in_edges()) { + inputs[index] = + NodeDefBuilder::NodeOut{edge->src()->name(), edge->src_output(), + edge->src()->output_type(edge->src_output())}; + } + } + send_from_host_builder.Input(inputs); + send_from_host_builder.Input(key_placeholder->name(), 0, DT_STRING); + TF_RETURN_IF_ERROR(send_from_host_builder.Finalize(&send_from_host_def)); + Status s; + Node* send_from_host_node = g->AddNode(send_from_host_def, &s); + TF_RETURN_IF_ERROR(s); + return send_from_host_node; +} + +// Builds XlaSendFromHost node, and replaces all _Retval nodes with it. +xla::StatusOr ReplaceRetNodesWithSendFromHostNode( + Graph* g, const string& oc_cluster_name, + std::vector* send_from_host_dtypes, Node* key_placeholder) { + // TODO(b/77601805): use in nodes for sink node, instead of traversing all + // nodes. 
+ std::vector ret_nodes = GatherNodesWithType(*g, "_Retval"); + TF_RETURN_IF_ERROR(GetRetDataTypes(ret_nodes, send_from_host_dtypes)); + TF_ASSIGN_OR_RETURN( + Node * send_from_host_node, + BuildSendFromHostNode(g, oc_cluster_name, ret_nodes, + *send_from_host_dtypes, key_placeholder)); + for (auto* n : ret_nodes) { + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + for (auto edge : n->in_edges()) { + if (edge->src_output() == Graph::kControlSlot) { + g->AddControlEdge(edge->src(), send_from_host_node); + } else { + g->AddEdge(edge->src(), edge->src_output(), send_from_host_node, index); + } + } + g->RemoveNode(n); + } + g->AddEdge(key_placeholder, 0, send_from_host_node, + send_from_host_dtypes->size()); + return send_from_host_node; +} + +// Returns input shapes (excluding key placeholder) for `send_from_host_node` +// if they are all fully defined; absl::nullopt otherwise. +absl::optional> GetInferredInputShapes( + int num_inputs, Node* send_from_host_node) { + std::vector results(num_inputs); + for (int i = 0; i < num_inputs; i++) { + const Edge* e; + if (!send_from_host_node->input_edge(i, &e).ok()) { + return absl::nullopt; + } + + std::vector shapes; + if (!GetNodeAttr(e->src()->attrs(), kXlaInferredShapesAttrName, &shapes) + .ok()) { + return absl::nullopt; + } + + const PartialTensorShape shape = shapes[e->src_output()]; + if (!shape.IsFullyDefined()) { + return absl::nullopt; + } + + results[e->dst_input()] = shape; + } + return results; +} + +// Builds XlaHostCompute NodeDef from the outside compilation call node. +xla::StatusOr BuildXlaHostComputeNodeDef( + const Node* call_node, const std::map& host_compute_core) { + string original_oc_name; + TF_RETURN_IF_ERROR(GetNodeAttr( + call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name)); + NodeDefBuilder host_compute_builder( + absl::StrCat("outside_compilation_", original_oc_name, "_host_compute"), + "XlaHostCompute"); + + // Copy all attributes. 
+ for (auto attr : call_node->attrs()) { + host_compute_builder.Attr(attr.first, attr.second); + } + + // Populate tpu_core assignment. + const auto iter = host_compute_core.find(original_oc_name); + if (iter != host_compute_core.end()) { + int core = iter->second; + host_compute_builder.Attr("tpu_core", core); + } + + // Populate inputs. + std::vector input_dtypes; + TF_RETURN_IF_ERROR(GetNodeAttr(call_node->attrs(), "Tinputs", &input_dtypes)); + std::vector inputs(input_dtypes.size()); + for (auto e : call_node->in_edges()) { + if (e->IsControlEdge()) { + continue; + } + + if (e->dst_input() < 0 || e->dst_input() >= input_dtypes.size()) { + return errors::Internal("Invalid dst_input: ", e->dst_input()); + } + inputs[e->dst_input()] = NodeDefBuilder::NodeOut{ + e->src()->name(), e->src_output(), input_dtypes[e->dst_input()]}; + } + host_compute_builder.Input(inputs); + + NodeDef new_def; + TF_RETURN_IF_ERROR(host_compute_builder.Finalize(&new_def)); + return new_def; +} + +// Replace outside compilation function call node with XlaHostCompute node. +// If the function call node has no input/output edges, we will just remove it +// and not create a XlaHostCompute node. +Status ReplaceOrRemoveOutsideCompilationCallNode( + Graph* g, Node* call_node, const std::map& host_compute_core) { + // If the function call node has no input/output edges, just remove it. + bool has_edge = false; + for (auto e : call_node->in_edges()) { + if (!e->IsControlEdge() || e->src() != g->source_node()) { + has_edge = true; + break; + } + } + for (auto e : call_node->out_edges()) { + if (!e->IsControlEdge() || e->dst() != g->sink_node()) { + has_edge = true; + break; + } + } + if (!has_edge) { + VLOG(4) << "Did not add HostCompute node for " << call_node->DebugString(); + g->RemoveNode(call_node); + return Status::OK(); + } + + // Build XlaHostCompute NodeDef. 
+ TF_ASSIGN_OR_RETURN(NodeDef node_def, + BuildXlaHostComputeNodeDef(call_node, host_compute_core)); + TF_ASSIGN_OR_RETURN(Node * host_compute_node, + ReplaceNode(g, call_node, node_def)); + VLOG(4) << "Added HostCompute node: " << host_compute_node->DebugString(); + + return Status::OK(); +} + +// For an XLA computation, builds host side graph given all outside compilation +// graphs inside it. The host side graph contains: +// 1) a "sequencer" node (we will add control edge between XlaRecvAtHost and +// XlaSendFromHost to this sequencer node, so all outside compilation nodes +// will be executed *before* this sequencer). +// 2) a "key placeholder" node. Later in ExpandHostGraphIntoMainGraph(), we will +// replace this node with compilation result node. +// 3) all outside compilation graphs. +Status ConstructHostGraph( + const string& xla_cluster_name, const string& outside_compilation_attr_name, + const std::vector& outside_compilation_host_graphs, + FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) { + host_graph->reset(new Graph(fld)); + + // Create sequencer node in host graph. + NodeDefBuilder sequencer_builder(absl::StrCat(xla_cluster_name, "_sequencer"), + "NoOp"); + sequencer_builder.Attr("_xla_host_transfer_sequencer", xla_cluster_name); + NodeDef sequencer_def; + TF_RETURN_IF_ERROR(sequencer_builder.Finalize(&sequencer_def)); + Status s; + Node* sequencer = (*host_graph)->AddNode(sequencer_def, &s); + TF_RETURN_IF_ERROR(s); + + // Create key placeholder in host graph. + TF_ASSIGN_OR_RETURN( + Node * key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name, host_graph->get())); + + // For each outside compilation graph, copy them to host graph with the + // following changes: + // a) Use key_placeholder in host graph instead of its own. + // b) Add control edge from RecvAtHost/SendFromHost to sequencer. + // c) Clear node_def.device(), so device placer won't get confused. 
+ for (const string& host_func : outside_compilation_host_graphs) { + VLOG(4) << "Expanding host graph " << host_func; + FunctionBody* host_fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(host_func), AttrSlice(), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &host_fbody)); + std::unique_ptr host_fbody_deleter(host_fbody); + + // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse + // reachable from sink node so all nodes will be copied. + // TODO(b/77601805): consolidate copy graph functions. + FixupSourceAndSinkEdges(host_fbody->graph); + + std::map node_map; + node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node(); + node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node(); + Status s; + ReverseDFS( + *host_fbody->graph, /*enter=*/nullptr, + [&](const Node* n) { + if (!s.ok()) { + return; + } + + Node* copy; + if (node_map.find(n) != node_map.end()) { + // Already copied this node. + copy = node_map.at(n); + } else if (IsKeyPlaceholderNode(*n)) { + // Change a). + copy = key_placeholder; + node_map[n] = copy; + } else { + // Copy the node. + NodeDef copy_def = n->def(); + // Change c). + copy_def.clear_device(); + copy = (*host_graph)->AddNode(copy_def, &s); + if (!s.ok()) { + return; + } + node_map[n] = copy; + } + + // Only handle input edges. Output edges will be added later as + // its output nodes' input edges. + for (auto e : n->in_edges()) { + if (node_map.find(e->src()) == node_map.end()) { + s = errors::Internal("Cannot find node image for ", + e->src()->DebugString()); + return; + } + (*host_graph) + ->AddEdge(node_map[e->src()], e->src_output(), copy, + e->dst_input()); + } + + // Change b). 
+ if (copy->type_string() == "_XlaRecvAtHost" || + copy->type_string() == "_XlaSendFromHost") { + (*host_graph)->AddControlEdge(copy, sequencer); + } + }, + NodeComparatorID()); + if (!s.ok()) { + return s; + } + } + + // sequencer and key_placeholder might be dead nodes. Prune them if necessary. + // - sequencer should be pruned iff it has no input control edges from + // RecvAtHost/SendFromHost. If it has input control edge, we connect it to + // sink node so it won't be pruned. + // - key_placeholder should be pruned iff there's no RecvAtHost/SendFromHost. + // We don't need to do anything special. + if (!sequencer->in_edges().empty()) { + (*host_graph)->AddControlEdge(sequencer, (*host_graph)->sink_node()); + } + PruneForReverseReachability( + host_graph->get(), + std::unordered_set{(*host_graph)->sink_node()}); + + // Postprocess edges between different outside compilations. + TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations( + host_graph->get(), outside_compilation_attr_name)); + + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("extract_outside_compilation_host_graph_for_", + xla_cluster_name), + **host_graph, fld); + } + + return Status::OK(); +} + +// Expand XLA computation's outside compilation host side graph into main graph. +// Add a control edge between sequencer node and the XLA computation node. +Status ExpandHostGraphIntoMainGraph(Graph* main_graph, Graph* host_graph, + Node* xla_computation_node) { + // We use ReverseDFS() to copy nodes. Make sure all nodes are reverse + // reachable from sink node so all nodes will be copied. + // TODO(b/77601805): consolidate copy graph functions. + FixupSourceAndSinkEdges(host_graph); + + // Copy all nodes. 
+ std::map node_map; + node_map[host_graph->source_node()] = main_graph->source_node(); + node_map[host_graph->sink_node()] = main_graph->sink_node(); + Status s = Status::OK(); + auto copy_node_fn = [&](const Node* n) { + if (!s.ok()) { + return; + } + + Node* copy; + if (node_map.find(n) != node_map.end()) { + // Already copied this node. + copy = node_map.at(n); + } else { + // Copy the node. + NodeDef copy_def = n->def(); + copy = main_graph->AddNode(copy_def, &s); + if (!s.ok()) { + return; + } + node_map[n] = copy; + } + + // Only handle input edges. Output edges will be added later as its output + // nodes' input edges. + for (auto e : n->in_edges()) { + if (node_map.find(e->src()) == node_map.end()) { + s = errors::Internal("Cannot find node image for ", + e->src()->DebugString()); + return; + } + main_graph->AddEdge(node_map[e->src()], e->src_output(), copy, + e->dst_input()); + } + + // Add control edge from sequencer to XLA computation node. + if (copy->type_string() == "NoOp" && + HasNodeAttr(copy->def(), "_xla_host_transfer_sequencer")) { + main_graph->AddControlEdge(copy, xla_computation_node); + } + }; + ReverseDFS(*host_graph, /*enter=*/nullptr, copy_node_fn, NodeComparatorID()); + return s; +} + +// Rewrites shape inference graph for outside compilation. +// 1. If the outside compilation is a "top-level" one (not in a function of any +// If/While/etc.), this shape inference graph might have host computation to +// outside compilation placeholder nodes, which will cause shape inference to +// fail. However, those nodes are not in `host_graph` any more (because we +// have executed `PostprocessForEncapsulation`). In this case, we clear the +// graph, and copy SendFromHost with all its predecessors from `host_graph`. +// This case is detected by whether the SendFromHost node exists in +// `host_graph` as well. +// 2. Remove control edges, and prune nodes that are not useful for shape +// inference.
+Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, + Graph* host_graph, + FunctionLibraryDefinition* fld) { + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(shape_inference_graph_name), AttrSlice(), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + Graph* g = fbody->graph; + + // Find SendFromHost node. + Node* send_from_host = nullptr; + for (Node* n : g->nodes()) { + if (n->type_string() == "_XlaSendFromHost") { + send_from_host = n; + break; + } + } + if (!send_from_host) { + return errors::Internal("Shape inference graph ", + shape_inference_graph_name, + " does not have _XlaSendFromHost node."); + } + + // See if the SendFromHost node exists in `host_graph`. + Node* send_from_host_main_graph = nullptr; + for (Node* n : host_graph->nodes()) { + if (n->name() == send_from_host->name()) { + send_from_host_main_graph = n; + break; + } + } + if (send_from_host_main_graph) { + // This is an "top-level" outside compilation. Clear the graph, and copy + // SendFromHost and all its predecessors from `host_graph`. + std::vector nodes; + for (Node* n : g->op_nodes()) { + nodes.push_back(n); + } + for (Node* n : nodes) { + g->RemoveNode(n); + } + + std::map node_map; + node_map[host_graph->source_node()] = g->source_node(); + Status s; + auto copy_node_fn = [&](const Node* n) { + if (!s.ok()) { + return; + } + + if (node_map.find(n) != node_map.end()) { + return; + } + + NodeDef copy_def = n->def(); + Node* copy = g->AddNode(copy_def, &s); + if (!s.ok()) { + return; + } + for (auto e : n->in_edges()) { + if (node_map.find(e->src()) == node_map.end()) { + s = errors::Internal("Cannot find node image for ", + e->src()->DebugString()); + return; + } + g->AddEdge(node_map[e->src()], e->src_output(), copy, e->dst_input()); + } + + node_map[n] = copy; + }; + // TODO(b/77601805): consolidate copy graph functions. 
+ ReverseDFSFrom(*host_graph, + std::vector{send_from_host_main_graph}, + /*enter=*/nullptr, copy_node_fn, NodeComparatorID()); + if (!s.ok()) { + return s; + } + + send_from_host = node_map[send_from_host_main_graph]; + } else { + // This is an outside compilation embedded in If/While/gradient/etc. + // It will be enough for shape inference. Leave `g` unchanged. + } + + // Control edges are not useful for shape inference. Remove them. + for (auto e : g->edges()) { + if (e->IsControlEdge()) { + g->RemoveEdge(e); + } + } + // Nodes that are not reverse reachable from SendFromHost are not useful for + // shape inference. Prune them. + PruneForReverseReachability(g, + std::unordered_set{send_from_host}); + + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile(shape_inference_graph_name, *g, fld); + } + + // Replace original shape inference graph. + FunctionDef fdef_replace; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*g, shape_inference_graph_name, &fdef_replace)); + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(shape_inference_graph_name, fdef_replace)); + + return Status::OK(); +} + +} // namespace + +Status RewriteOutsideCompilationSubgraphFn::operator()( + const std::vector& arg_source_tensors, + std::unique_ptr* graph, std::vector* input_permutation, + std::vector* output_permutation, NodeDef* node_def) { + string old_name = node_def->op(); + string new_name = absl::StrCat(xla_cluster_name_, "_", old_name); + node_def->set_op(new_name); + node_def->set_name(new_name); + + // Later we will run PruneForReverseReachability(), so make sure all original + // nodes are reachable from sink node and won't be removed. + FixupSourceAndSinkEdges(graph->get()); + + // Step 1: create a key placeholder node. + TF_ASSIGN_OR_RETURN( + Node * key_placeholder, + AddHostComputeKeyPlaceholder(xla_cluster_name_, graph->get())); + + // Step 2: build RecvAtHost node, and replace all _Arg nodes with it. 
+ std::vector recv_at_host_dtypes; + TF_ASSIGN_OR_RETURN( + Node * recv_at_host_node, + ReplaceArgNodesWithRecvAtHostNode(graph->get(), new_name, + &recv_at_host_dtypes, key_placeholder)); + + // Step 3: build SendFromHost node, and replace all _Retval nodes with it. + std::vector send_from_host_dtypes; + TF_ASSIGN_OR_RETURN( + Node * send_from_host_node, + ReplaceRetNodesWithSendFromHostNode( + graph->get(), new_name, &send_from_host_dtypes, key_placeholder)); + + // Step 4: add XLA cluster and outside compilation attr. + for (Node* n : (*graph)->nodes()) { + if (IsKeyPlaceholderNode(*n)) { + continue; + } + + n->AddAttr(xla_cluster_attr_name_, xla_cluster_name_); + n->AddAttr(outside_compilation_attr_name_, old_name); + } + + // Check whether we have all input shapes for XlaSendFromHost. If we do, we + // will set `shapes` attr for the call node; otherwise we will save the + // shape inference graph and set `shape_inference_graph` for the call node. + absl::optional> shapes = + GetInferredInputShapes(send_from_host_dtypes.size(), send_from_host_node); + for (Node* n : (*graph)->nodes()) { + n->ClearAttr(kXlaInferredShapesAttrName); + } + + // Step 5: add control edges for originally XLA <-> outside compilation + // control edges. + for (Node* n : (*graph)->nodes()) { + if (HasNodeAttr(n->def(), kXlaConnectedToXlaComputationAttrName)) { + (*graph)->AddControlEdge(n, send_from_host_node); + n->ClearAttr(kXlaConnectedToXlaComputationAttrName); + } + if (HasNodeAttr(n->def(), kXlaConnectedFromXlaComputationAttrName)) { + (*graph)->AddControlEdge(recv_at_host_node, n); + n->ClearAttr(kXlaConnectedFromXlaComputationAttrName); + } + } + + // Step 6: RecvAtHost/SendFromHost/key_placeholder might be dead nodes. Prune + // them if necessary. + // - RecvAtHost should be pruned iff it has no output data/control edges. If + // it has any output edge, it will be reverse reachable from sink node. We + // don't need to do anything special. 
+ // - SendFromHost should be pruned iff it has no input data/control edges. If + // it has input edges other than key_placeholder, we connect it to sink + // node so it won't be pruned. + // - key_placeholder should be pruned iff RecvAtHost/SendFromHost are pruned. + // We don't need to do anything special. + if (send_from_host_node->in_edges().size() > 1) { + (*graph)->AddControlEdge(send_from_host_node, (*graph)->sink_node()); + } + PruneForReverseReachability( + graph->get(), std::unordered_set{(*graph)->sink_node()}); + + // Step 7: add necessary attributes to function call node, so we can replace + // it with HostCompute node later. + AddNodeAttr("_outside_compilation_subgraph", old_name, node_def); + if (shapes) { + AddNodeAttr("shape_inference_graph", "", node_def); + AddNodeAttr("shapes", *shapes, node_def); + } else { + string shape_inference_func_name = + absl::StrCat("_outside_compilation_shape_inference_", new_name); + AddNodeAttr("shape_inference_graph", shape_inference_func_name, node_def); + AddNodeAttr("shapes", std::vector{}, node_def); + } + AddNodeAttr("ancestors", std::vector{}, node_def); + AddNodeAttr("Tinputs", recv_at_host_dtypes, node_def); + AddNodeAttr("Toutputs", send_from_host_dtypes, node_def); + AddNodeAttr("key", absl::StrCat("host_compute_channel_", new_name), node_def); + + return Status::OK(); +} + +Status ExtractOutsideCompilationForFunction( + const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, const string& xla_cluster_name, + const NameAttrList& func_name_attrs, const string& new_func_name, + const std::map& host_compute_core, + FunctionLibraryDefinition* fld, std::unique_ptr* host_graph, + std::vector* shape_inference_graphs, + bool* has_outside_compilation) { + // Early return if function does not have any outside compilation nodes. 
+ const string& func_name = func_name_attrs.name(); + const FunctionDef* fdef = fld->Find(func_name); + if (!fdef) { + return errors::Internal("Cannot find function ", func_name); + } + *has_outside_compilation = false; + for (auto& node_def : fdef->node_def()) { + if (HasNodeAttr(node_def, outside_compilation_attr_name)) { + *has_outside_compilation = true; + break; + } + } + if (!has_outside_compilation) { + return Status::OK(); + } + + // Convert the function to graph. + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fld->Find(func_name), AttrSlice(&func_name_attrs.attr()), fld, + [&](const string& op, const OpDef** sig) { + return fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + + // Preprocess edges between different outside compilations. They will be + // restored in `ConstructHostGraph()`. + TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations( + fbody->graph, outside_compilation_attr_name)); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("extract_outside_compilation_for_func_before_", func_name), + *fbody->graph, fld); + } + + // Encapsulate outside_compilation cluster into function call node. + std::unique_ptr graph_out; + RewriteOutsideCompilationSubgraphFn rewrite_fn( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name); + TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( + outside_compilation_attr_name, "", *fbody->graph, rewrite_fn, + /*reuse_existing_functions=*/true, &graph_out, fld)); + + // Replace outside_compilation function nodes with HostCompute ops. 
+ std::vector outside_compilation_nodes; + std::vector outside_compilation_host_graphs; + for (Node* n : graph_out->nodes()) { + if (HasNodeAttr(n->def(), "_outside_compilation_subgraph")) { + outside_compilation_nodes.push_back(n); + outside_compilation_host_graphs.push_back(n->name()); + + // If we could not infer shapes for XlaSendFromHost inputs statically, we + // will set the "shape_inference_graph" attribute. In that case, copy + // outside compilation subgraph as shape inference graph in `fld`. + string shape_inference_graph; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph", + &shape_inference_graph)); + if (!shape_inference_graph.empty()) { + shape_inference_graphs->push_back(shape_inference_graph); + + const FunctionDef* xla_fdef = fld->Find(n->name()); + if (!xla_fdef) { + return errors::Internal("Cannot find XLA function ", n->name()); + } + FunctionDef shape_inference_fdef = *xla_fdef; + shape_inference_fdef.mutable_signature()->set_name( + shape_inference_graph); + if (fld->Find(shape_inference_graph)) { + TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph, + shape_inference_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); + } + } + } + } + for (Node* n : outside_compilation_nodes) { + TF_RETURN_IF_ERROR(ReplaceOrRemoveOutsideCompilationCallNode( + graph_out.get(), n, host_compute_core)); + } + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("extract_outside_compilation_for_func_after_", func_name), + *graph_out, fld); + } + + // Construct host graph. + if (!outside_compilation_host_graphs.empty()) { + TF_RETURN_IF_ERROR( + ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, + outside_compilation_host_graphs, fld, host_graph)); + } + + // Remove the outside compilation graphs from function library. + for (const string& func : outside_compilation_host_graphs) { + TF_RETURN_IF_ERROR(fld->RemoveFunction(func)); + } + + // Replace original function. 
+ FunctionDef updated_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*graph_out, new_func_name, &updated_fdef)); + if (fld->Find(new_func_name)) { + TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, updated_fdef)); + } else { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(updated_fdef)); + } + + return Status::OK(); +} + +Status ExtractOutsideCompilation( + const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, + const std::unordered_map& clusters, Graph* g, + FunctionLibraryDefinition* fld) { + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("extract_outside_compilation_before", *g, fld); + } + + std::vector shape_inference_graphs; + for (auto& iter : clusters) { + string xla_cluster_name = iter.first; + Node* n = iter.second.node; + auto const& func_name_attrs = iter.second.func_name_attrs; + auto const& host_compute_core = iter.second.host_compute_core; + + bool has_outside_compilation; + std::unique_ptr host_graph; + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + func_name_attrs, func_name_attrs.name(), host_compute_core, fld, + &host_graph, &shape_inference_graphs, &has_outside_compilation)); + if (host_graph) { + TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(g, host_graph.get(), n)); + } + } + + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("extract_outside_compilation_expanded", *g, + fld); + } + + TF_RETURN_IF_ERROR(PostprocessForEncapsulation( + g, xla_cluster_attr_name, outside_compilation_attr_name, clusters)); + + for (auto shape_inference_graph_name : shape_inference_graphs) { + TF_RETURN_IF_ERROR( + RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld)); + } + + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("extract_outside_compilation_after", *g, fld); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.h 
b/tensorflow/compiler/jit/extract_outside_compilation_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..2a4f07cca213d999202024294f5d8f94527059c3
--- /dev/null
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.h
@@ -0,0 +1,107 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/jit/encapsulate_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Rewrite function for outside compilation subgraphs. It will perform the
+// following steps:
+//
+// 1. Add an XLA computation key placeholder node (it will be used as input for
+//    XlaRecvAtHost and XlaSendFromHost);
+// 2. Replace all _Arg nodes with one single XlaRecvAtHost node;
+// 3. Replace all _Retval nodes with one single XlaSendFromHost node;
+// 4. Mark all nodes except key placeholder with attr `xla_cluster_attr_name`
+//    and `outside_compilation_attr_name`;
+// 5. For nodes marked with attr kXlaConnectedToXlaComputationAttrName, add a
+//    control edge from the node to XlaSendFromHost; for nodes marked with attr
+//    kXlaConnectedFromXlaComputationAttrName, add a control edge from
+//    XlaRecvAtHost node to the node;
+// 6. Try pruning XlaRecvAtHost/XlaSendFromHost/key placeholder node;
+// 7. Add necessary attributes to `node_def`, so we can replace it with a
+//    XlaHostCompute node later. If all input shapes for XlaSendFromHost are
+//    known, "shapes" attr will be set to the list of input shapes; otherwise
+//    "shape_inference_graph" attr will be set to shape inference function name.
+class RewriteOutsideCompilationSubgraphFn {
+ public:
+  // Stores copies of the two attr names and the XLA cluster name for use by
+  // operator().
+  RewriteOutsideCompilationSubgraphFn(
+      const string& xla_cluster_attr_name,
+      const string& outside_compilation_attr_name,
+      const string& xla_cluster_name)
+      : xla_cluster_attr_name_(xla_cluster_attr_name),
+        outside_compilation_attr_name_(outside_compilation_attr_name),
+        xla_cluster_name_(xla_cluster_name) {}
+
+  // Applies the steps listed above to `*graph` and adds the resulting
+  // attributes to `*node_def`. The first parameter is unnamed in this
+  // declaration.
+  // NOTE(review): the definition does not read `input_permutation` or
+  // `output_permutation` -- confirm against the caller's contract.
+  Status operator()(const std::vector&,
+                    std::unique_ptr* graph,
+                    std::vector* input_permutation,
+                    std::vector* output_permutation, NodeDef* node_def);
+
+ private:
+  // Attr name under which step 4 records `xla_cluster_name_` on each node.
+  string xla_cluster_attr_name_;
+  // Attr name under which step 4 records the call node's original op name.
+  string outside_compilation_attr_name_;
+  // Name of the XLA cluster this rewriter was constructed for.
+  string xla_cluster_name_;
+};
+
+// For an XLA computation function, replace all outside compilations with
+// XlaHostCompute nodes. Each outside compilation subgraph will be rewritten by
+// `RewriteOutsideCompilationSubgraphFn`, and they will be merged into one
+// single host side graph (`host_graph`).
+//
+// xla_cluster_attr_name and outside_compilation_attr_name: attr name for XLA
+//   computation and outside compilation. Required for
+//   `RewriteOutsideCompilationSubgraphFn`.
+// xla_cluster_name: XLA cluster name for this XLA computation. We need it
+//   because XLA cluster name might be different from `func_name`.
+// func_name_attrs: they will be used to instantiate the XLA computation func.
+// new_func_name: new function name for rewritten XLA computation func.
+// host_compute_core: mapping from outside compilation cluster name to XLA
+//   device assignment.
+// fld: FunctionLibraryDefinition object.
+// host_graph: Graph object to store host side graph for all outside
+//   compilations within this XLA computation func. If there is no outside
+//   compilation, it will be empty.
+// shape_inference_graphs: a list of outside compilation shape inference
+//   function names. These functions need to be rewritten later.
+// has_outside_compilation: a bool indicating whether this function has any
+//   outside compilation nodes.
+Status ExtractOutsideCompilationForFunction(
+    const string& xla_cluster_attr_name,
+    const string& outside_compilation_attr_name, const string& xla_cluster_name,
+    const NameAttrList& func_name_attrs, const string& new_func_name,
+    const std::map& host_compute_core,
+    FunctionLibraryDefinition* fld, std::unique_ptr* host_graph,
+    std::vector* shape_inference_graphs, bool* has_outside_compilation);
+
+// Rewrites XLA computation in `clusters` to replace outside compilation nodes
+// with XlaHostCompute, and moves those outside compilations into `g`. If shapes
+// of outside compilation outputs cannot be determined now, we will store shape
+// inference graph into `fld`.
+Status ExtractOutsideCompilation(
+    const string& xla_cluster_attr_name,
+    const string& outside_compilation_attr_name,
+    const std::unordered_map& clusters, Graph* g,
+    FunctionLibraryDefinition* fld);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bff956100da661b679b4557fce53671e6cef88c5
--- /dev/null
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
@@ -0,0 +1,441 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
+
+#include "absl/strings/match.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/encapsulate_util.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+// Runs the rewrite on a graph with _Arg, Add and _Retval nodes and checks the
+// outcome of steps 1-5 and 7.
+TEST(RewriteOutsideCompilationSubgraphFnTest, Basic) {
+  // Build the graph:
+  //   "add" = "arg0" + "arg0"
+  //   "ret0" = "add"
+  //   "ret1" = "arg1"
+  // "arg2" is declared but feeds nothing.
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_INT32, 0);
+  Output arg1 = ops::_Arg(s.WithOpName("arg1"), DT_FLOAT, 1);
+  Output arg2 = ops::_Arg(s.WithOpName("arg2"), DT_INT32, 2);
+  Output add = ops::Add(s.WithOpName("add"), arg0, arg0);
+  auto ret0 = ops::_Retval(s.WithOpName("ret0"), add, 0);
+  auto ret1 = ops::_Retval(s.WithOpName("ret1"), arg1, 1);
+  std::unique_ptr g(new Graph(OpRegistry::Global()));
+  TF_CHECK_OK(s.ToGraph(g.get()));
+  auto node_name_image = g->BuildNodeNameIndex();
+  Node *add_node = node_name_image["add"];
+  EXPECT_NE(add_node, nullptr);
+  // Mark "add" with both attrs so step 5 has control edges to add.
+  add_node->AddAttr(kXlaConnectedToXlaComputationAttrName, "cluster");
+  add_node->AddAttr(kXlaConnectedFromXlaComputationAttrName, "cluster");
+
+  RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster");
+  std::vector arg_source_tensors;
+  NodeDef call_node_def;
+  call_node_def.set_op("0");
+  TF_CHECK_OK(
+      rewrite_fn(arg_source_tensors, &g, nullptr, nullptr, &call_node_def));
+  node_name_image = g->BuildNodeNameIndex();
+
+  // Verify step 1: add key placeholder node.
+  Node *key_placeholder = node_name_image["cluster_key_placeholder"];
+  EXPECT_NE(key_placeholder, nullptr);
+  // Verify step 2: replace _Arg nodes with XlaRecvAtHost.
+  for (Node *n : g->nodes()) {
+    EXPECT_NE(n->type_string(), "_Arg");
+  }
+  Node *recv_at_host = node_name_image["outside_compilation_cluster_0_recv"];
+  EXPECT_NE(recv_at_host, nullptr);
+  std::vector recv_at_host_dtypes;
+  TF_CHECK_OK(
+      GetNodeAttr(recv_at_host->attrs(), "Toutputs", &recv_at_host_dtypes));
+  EXPECT_EQ(recv_at_host_dtypes.size(), 3);
+  EXPECT_EQ(recv_at_host_dtypes[0], DT_INT32);
+  EXPECT_EQ(recv_at_host_dtypes[1], DT_FLOAT);
+  EXPECT_EQ(recv_at_host_dtypes[2], DT_INT32);
+  // Verify step 3: replace _Retval nodes with XlaSendFromHost.
+  for (Node *n : g->nodes()) {
+    EXPECT_NE(n->type_string(), "_Retval");
+  }
+  Node *send_from_host = node_name_image["outside_compilation_cluster_0_send"];
+  EXPECT_NE(send_from_host, nullptr);
+  std::vector send_from_host_dtypes;
+  TF_CHECK_OK(
+      GetNodeAttr(send_from_host->attrs(), "Tinputs", &send_from_host_dtypes));
+  EXPECT_EQ(send_from_host_dtypes.size(), 2);
+  EXPECT_EQ(send_from_host_dtypes[0], DT_INT32);
+  EXPECT_EQ(send_from_host_dtypes[1], DT_FLOAT);
+  // Verify step 4: nodes marked with XLA cluster and outside compilation attr.
+  add_node = node_name_image["add"];
+  EXPECT_NE(add_node, nullptr);
+  EXPECT_TRUE(HasNodeAttr(add_node->def(), "_xla"));
+  EXPECT_TRUE(HasNodeAttr(add_node->def(), "_oc"));
+  // Verify step 5: control edges added.
+  bool has_control_edge_from_recv_at_host = false;
+  for (auto e : add_node->in_edges()) {
+    if (e->IsControlEdge() && e->src() == recv_at_host) {
+      has_control_edge_from_recv_at_host = true;
+    }
+  }
+  EXPECT_TRUE(has_control_edge_from_recv_at_host);
+  bool has_control_edge_to_send_from_host = false;
+  for (auto e : add_node->out_edges()) {
+    if (e->IsControlEdge() && e->dst() == send_from_host) {
+      has_control_edge_to_send_from_host = true;
+    }
+  }
+  EXPECT_TRUE(has_control_edge_to_send_from_host);
+  // Verify step 7: necessary attrs added to call_node_def.
+  string shape_inference_graph;
+  TF_CHECK_OK(GetNodeAttr(AttrSlice(&call_node_def.attr()),
+                          "shape_inference_graph", &shape_inference_graph));
+  EXPECT_EQ(shape_inference_graph,
+            "_outside_compilation_shape_inference_cluster_0");
+}
+
+// With only an _Arg node, the test expects no SendFromHost node after the
+// rewrite.
+TEST(RewriteOutsideCompilationSubgraphFnTest, NoSendFromHost) {
+  // Build the graph: only 1 node: "arg0"
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output arg0 = ops::_Arg(s.WithOpName("arg0"), DT_INT32, 0);
+  std::unique_ptr g(new Graph(OpRegistry::Global()));
+  TF_CHECK_OK(s.ToGraph(g.get()));
+
+  RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster");
+  std::vector arg_source_tensors;
+  NodeDef call_node_def;
+  call_node_def.set_op("0");
+  TF_CHECK_OK(
+      rewrite_fn(arg_source_tensors, &g, nullptr, nullptr, &call_node_def));
+  auto node_name_image = g->BuildNodeNameIndex();
+
+  // Check key placeholder and RecvAtHost are present, but SendFromHost is not.
+  Node *key_placeholder = node_name_image["cluster_key_placeholder"];
+  EXPECT_NE(key_placeholder, nullptr);
+  Node *recv_at_host = node_name_image["outside_compilation_cluster_0_recv"];
+  EXPECT_NE(recv_at_host, nullptr);
+  Node *send_from_host = node_name_image["outside_compilation_cluster_0_send"];
+  EXPECT_EQ(send_from_host, nullptr);
+}
+
+// With a constant feeding a _Retval and no _Arg, the test expects no
+// RecvAtHost node after the rewrite.
+TEST(RewriteOutsideCompilationSubgraphFnTest, NoRecvAtHost) {
+  // Build the graph:
+  //   "ret" = "const0"
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
+  auto ret = ops::_Retval(s.WithOpName("ret"), const0, 0);
+  std::unique_ptr g(new Graph(OpRegistry::Global()));
+  TF_CHECK_OK(s.ToGraph(g.get()));
+
+  RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster");
+  std::vector arg_source_tensors;
+  NodeDef call_node_def;
+  call_node_def.set_op("0");
+  TF_CHECK_OK(
+      rewrite_fn(arg_source_tensors, &g, nullptr, nullptr, &call_node_def));
+  auto node_name_image = g->BuildNodeNameIndex();
+
+  // Check key placeholder and SendFromHost are present, but RecvAtHost is not.
+  Node *key_placeholder = node_name_image["cluster_key_placeholder"];
+  EXPECT_NE(key_placeholder, nullptr);
+  Node *recv_at_host = node_name_image["outside_compilation_cluster_0_recv"];
+  EXPECT_EQ(recv_at_host, nullptr);
+  Node *send_from_host = node_name_image["outside_compilation_cluster_0_send"];
+  EXPECT_NE(send_from_host, nullptr);
+}
+
+// With neither _Arg nor _Retval nodes, the test expects the key placeholder to
+// be gone as well.
+TEST(RewriteOutsideCompilationSubgraphFnTest, NoKeyPlaceholder) {
+  // Build the graph: only 1 node: "const0"
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
+  std::unique_ptr g(new Graph(OpRegistry::Global()));
+  TF_CHECK_OK(s.ToGraph(g.get()));
+
+  RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster");
+  std::vector arg_source_tensors;
+  NodeDef call_node_def;
+  call_node_def.set_op("0");
+  TF_CHECK_OK(
+      rewrite_fn(arg_source_tensors, &g, nullptr, nullptr, &call_node_def));
+  auto node_name_image = g->BuildNodeNameIndex();
+
+  // Check key placeholder/RecvAtHost/SendFromHost are not present.
+  Node *key_placeholder = node_name_image["cluster_key_placeholder"];
+  EXPECT_EQ(key_placeholder, nullptr);
+  Node *recv_at_host = node_name_image["outside_compilation_cluster_0_recv"];
+  EXPECT_EQ(recv_at_host, nullptr);
+  Node *send_from_host = node_name_image["outside_compilation_cluster_0_send"];
+  EXPECT_EQ(send_from_host, nullptr);
+}
+
+// When the node feeding the _Retval carries the inferred-shapes attr, the test
+// expects a populated "shapes" attr on the call node.
+TEST(RewriteOutsideCompilationSubgraphFnTest, ShapesInferred) {
+  // Build the graph:
+  //   "ret" = "const0"
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
+  auto ret = ops::_Retval(s.WithOpName("ret"), const0, 0);
+  std::unique_ptr g(new Graph(OpRegistry::Global()));
+  TF_CHECK_OK(s.ToGraph(g.get()));
+  auto node_name_image = g->BuildNodeNameIndex();
+  Node *const0_node = node_name_image["const0"];
+  EXPECT_NE(const0_node, nullptr);
+  PartialTensorShape shape({2});
+  const0_node->AddAttr(kXlaInferredShapesAttrName,
+                       std::vector{shape});
+
+  RewriteOutsideCompilationSubgraphFn rewrite_fn("_xla", "_oc", "cluster");
+  std::vector arg_source_tensors;
+  NodeDef call_node_def;
+  call_node_def.set_op("0");
+  TF_CHECK_OK(
+      rewrite_fn(arg_source_tensors, &g, nullptr, nullptr, &call_node_def));
+  node_name_image = g->BuildNodeNameIndex();
+
+  // Check "shapes" attr is available in call_node_def.
+  std::vector shapes;
+  TF_CHECK_OK(GetNodeAttr(AttrSlice(&call_node_def.attr()), "shapes", &shapes));
+  EXPECT_EQ(shapes.size(), 1);
+  EXPECT_EQ(shapes[0].dim_size(), 1);
+}
+
+// Extracts two chained outside compilation clusters and checks both the
+// rewritten function "cluster_rewritten" and the resulting host graph.
+TEST(ExtractOutsideCompilationForFunctionTest, Basic) {
+  // Build the XLA computation func.
+  //   "const0"
+  //   "identity0" = "const0" (outside compilation cluster "0")
+  //   "identity1" = "identity0" (outside compilation cluster "1")
+  //   "identity2" = "identity1"
+  FunctionDefLibrary fdl;
+  {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
+    Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
+    Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
+    Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
+    std::unique_ptr g(new Graph(OpRegistry::Global()));
+    TF_CHECK_OK(s.ToGraph(g.get()));
+    auto node_name_image = g->BuildNodeNameIndex();
+    node_name_image["identity0"]->AddAttr("_oc", "0");
+    node_name_image["identity1"]->AddAttr("_oc", "1");
+    PartialTensorShape shape({2});
+    node_name_image["identity1"]->AddAttr(
+        kXlaInferredShapesAttrName, std::vector{shape});
+
+    FunctionDef *xla_fdef = fdl.add_function();
+    TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
+  }
+  FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
+
+  protobuf::Map attrs;
+  std::map host_compute_core = {{"0", 1}, {"1", 0}};
+  std::unique_ptr host_graph;
+  std::vector shape_inference_graphs;
+  bool has_outside_compilation;
+  NameAttrList name_attrs;
+  name_attrs.set_name("cluster");
+  *name_attrs.mutable_attr() = attrs;
+  TF_CHECK_OK(ExtractOutsideCompilationForFunction(
+      "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten",
+      host_compute_core, &fld, &host_graph, &shape_inference_graphs,
+      &has_outside_compilation));
+
+  // Get rewritten XLA computation function.
+  FunctionBody *fbody = nullptr;
+  TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
+                                      AttrSlice(), &fld,
+                                      [&](const string &op, const OpDef **sig) {
+                                        return fld.LookUpOpDef(op, sig);
+                                      },
+                                      &fbody));
+  std::unique_ptr fbody_deleter(fbody);
+  auto node_name_index = fbody->graph->BuildNodeNameIndex();
+
+  // Check XlaHostCompute nodes.
+  Node *host_compute_0 = node_name_index["outside_compilation_0_host_compute"];
+  EXPECT_NE(host_compute_0, nullptr);
+  Node *host_compute_1 = node_name_index["outside_compilation_1_host_compute"];
+  EXPECT_NE(host_compute_1, nullptr);
+  // Check XlaHostCompute nodes' "tpu_core" attr.
+  int tpu_core;
+  TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "tpu_core", &tpu_core));
+  EXPECT_EQ(tpu_core, 1);
+  TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "tpu_core", &tpu_core));
+  EXPECT_EQ(tpu_core, 0);
+  // Check XlaHostCompute nodes' "shapes" attr. "0" should not have shapes, and
+  // "1" should have shapes.
+  std::vector shapes;
+  TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shapes", &shapes));
+  EXPECT_EQ(shapes.size(), 0);
+  TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shapes", &shapes));
+  EXPECT_EQ(shapes.size(), 1);
+  EXPECT_EQ(shapes[0].dim_size(), 1);
+  // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have
+  // empty values.
+  string shape_inference_graph;
+  TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph",
+                          &shape_inference_graph));
+  EXPECT_EQ(shape_inference_graph, "");
+  TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph",
+                          &shape_inference_graph));
+  EXPECT_EQ(shape_inference_graph, "");
+
+  // Check `shape_inference_graphs`.
+  EXPECT_EQ(shape_inference_graphs.size(), 0);
+
+  // Check `host_graph`: verify we have key placeholder and sequencer.
+  Node *key_placeholder = nullptr, *sequencer = nullptr;
+  for (Node *n : host_graph->nodes()) {
+    if (n->type_string() == "Placeholder" &&
+        absl::EndsWith(n->name(), "_key_placeholder")) {
+      EXPECT_EQ(key_placeholder, nullptr);
+      key_placeholder = n;
+    } else if (HasNodeAttr(n->def(), "_xla_host_transfer_sequencer")) {
+      EXPECT_EQ(sequencer, nullptr);
+      sequencer = n;
+    }
+  }
+  EXPECT_NE(key_placeholder, nullptr);
+  EXPECT_NE(sequencer, nullptr);
+  // Check SendFromHost and RecvAtHost have key placeholder as input, and have
+  // control edge to sequencer.
+  int num_send_from_host = 0, num_recv_at_host = 0;
+  std::vector send_recv_nodes;
+  for (Node *n : host_graph->nodes()) {
+    if (n->type_string() == "_XlaSendFromHost") {
+      num_send_from_host++;
+      send_recv_nodes.push_back(n);
+    } else if (n->type_string() == "_XlaRecvAtHost") {
+      num_recv_at_host++;
+      send_recv_nodes.push_back(n);
+    }
+  }
+  EXPECT_EQ(num_send_from_host, 1);
+  EXPECT_EQ(num_recv_at_host, 1);
+  for (Node *n : send_recv_nodes) {
+    // The key placeholder is expected as the last input.
+    Node *input_node;
+    TF_CHECK_OK(n->input_node(n->num_inputs() - 1, &input_node));
+    EXPECT_EQ(input_node, key_placeholder);
+
+    bool has_control_edge_to_sequencer = false;
+    for (const Edge *e : n->out_edges()) {
+      if (e->IsControlEdge() && e->dst() == sequencer) {
+        has_control_edge_to_sequencer = true;
+        break;
+      }
+    }
+    EXPECT_TRUE(has_control_edge_to_sequencer);
+  }
+}
+
+// A function without any "_oc" node must leave `host_graph` unset.
+TEST(ExtractOutsideCompilationForFunctionTest, NoHostGraph) {
+  // Build the XLA computation func.
+  //   "const0"
+  FunctionDefLibrary fdl;
+  {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
+    std::unique_ptr g(new Graph(OpRegistry::Global()));
+    TF_CHECK_OK(s.ToGraph(g.get()));
+
+    FunctionDef *xla_fdef = fdl.add_function();
+    TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
+  }
+  FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
+
+  protobuf::Map attrs;
+  std::map host_compute_core = {{"0", 1}, {"1", 0}};
+  std::unique_ptr host_graph;
+  std::vector shape_inference_graphs;
+  bool has_outside_compilation;
+  NameAttrList name_attrs;
+  name_attrs.set_name("cluster");
+  *name_attrs.mutable_attr() = attrs;
+  TF_CHECK_OK(ExtractOutsideCompilationForFunction(
+      "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten",
+      host_compute_core, &fld, &host_graph, &shape_inference_graphs,
+      &has_outside_compilation));
+
+  // Check `host_graph` is empty.
+  EXPECT_FALSE(host_graph);
+}
+
+// An outside compilation cluster holding only a constant must leave no
+// XlaHostCompute node in the rewritten function, while "const1" moves to the
+// host graph.
+TEST(ExtractOutsideCompilationForFunctionTest, XlaHostComputeRemoved) {
+  // Build the XLA computation func.
+  //   "const0"
+  //   "const1" (outside compilation cluster "0")
+  FunctionDefLibrary fdl;
+  {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    Output const0 = ops::Const(s.WithOpName("const0"), 1, {2});
+    Output const1 = ops::Const(s.WithOpName("const1"), 1, {2});
+    std::unique_ptr g(new Graph(OpRegistry::Global()));
+    TF_CHECK_OK(s.ToGraph(g.get()));
+    auto node_name_image = g->BuildNodeNameIndex();
+    node_name_image["const1"]->AddAttr("_oc", "0");
+
+    FunctionDef *xla_fdef = fdl.add_function();
+    TF_CHECK_OK(GraphToFunctionDef(*g, "cluster", xla_fdef));
+  }
+  FunctionLibraryDefinition fld(OpRegistry::Global(), fdl);
+
+  protobuf::Map attrs;
+  std::map host_compute_core = {{"0", 1}, {"1", 0}};
+  std::unique_ptr host_graph;
+  std::vector shape_inference_graphs;
+  bool has_outside_compilation;
+  NameAttrList name_attrs;
+  name_attrs.set_name("cluster");
+  *name_attrs.mutable_attr() = attrs;
+  TF_CHECK_OK(ExtractOutsideCompilationForFunction(
+      "_xla", "_oc", "cluster", name_attrs, "cluster_rewritten",
+      host_compute_core, &fld, &host_graph, &shape_inference_graphs,
+      &has_outside_compilation));
+
+  // Check rewritten XLA graph: verify that we have no XlaHostCompute.
+  FunctionBody *fbody = nullptr;
+  TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("cluster_rewritten"),
+                                      AttrSlice(), &fld,
+                                      [&](const string &op, const OpDef **sig) {
+                                        return fld.LookUpOpDef(op, sig);
+                                      },
+                                      &fbody));
+  std::unique_ptr fbody_deleter(fbody);
+  for (Node *n : fbody->graph->nodes()) {
+    EXPECT_NE(n->type_string(), "XlaHostCompute");
+  }
+
+  // Check `host_graph`: verify we have no placeholder, but we have "const1".
+  int num_key_placeholders = 0;
+  for (Node *n : host_graph->nodes()) {
+    if (n->type_string() == "Placeholder" &&
+        absl::EndsWith(n->name(), "_key_placeholder")) {
+      num_key_placeholders++;
+    }
+  }
+  EXPECT_EQ(num_key_placeholders, 0);
+  auto node_name_index = host_graph->BuildNodeNameIndex();
+  EXPECT_NE(node_name_index.find("const1"), node_name_index.end());
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc
new file mode 100644
index 0000000000000000000000000000000000000000..98e344b3a080aa8aab27cd41564a90427bac151e
--- /dev/null
+++ b/tensorflow/compiler/jit/flags.cc
@@ -0,0 +1,152 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include // NOLINT + +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { + +// Flag structs served by the accessors in this file. AllocateAndParseFlags() +// allocates each one with `new`, fills in its defaults, builds `flag_list` +// and parses TF_XLA_FLAGS; the accessors run it through `flags_init`. +BuildXlaOpsPassFlags* build_ops_flags; +DumpGraphFlags* dump_graph_flags; +MarkForCompilationPassFlags* mark_for_compilation_flags; +XlaDeviceFlags* device_flags; +XlaOpsCommonFlags* ops_flags; + +std::vector* flag_list; +std::once_flag flags_init; + +// Appends the tf_dump_graph_prefix flag definition to `flag_list`. The +// definition stores a pointer into `dump_graph_flags`, so that struct must +// already be allocated. +void AppendDumpGraphFlagsInternal(std::vector* flag_list) { + std::vector new_flags = { + Flag("tf_dump_graph_prefix", &dump_graph_flags->tf_dump_graph_prefix, + "Path prefix to which graphs dumped during debugging should be " + "written."), + }; + flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end()); +} + +// Appends the flag definitions backed by `mark_for_compilation_flags` to +// `flag_list`. The definitions store pointers into that struct, so it must +// already be allocated. +void AppendMarkForCompilationPassFlagsInternal(std::vector* flag_list) { + std::vector new_flags = { + Flag("tf_xla_auto_jit", &mark_for_compilation_flags->tf_xla_auto_jit, + "Control compilation of operators into XLA computations on CPU and " + "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " + "things very likely to be improved; 2 = on for everything. " + "Experimental."), + Flag("tf_xla_min_cluster_size", + &mark_for_compilation_flags->tf_xla_min_cluster_size, + "Minimum number of operators in an XLA compilation. 
Ignored for " + "operators placed on an XLA device or operators explicitly marked " + "for compilation."), + Flag("tf_xla_max_cluster_size", + &mark_for_compilation_flags->tf_xla_max_cluster_size, + "Maximum number of operators in an XLA compilation."), + Flag("tf_xla_clustering_debug", + &mark_for_compilation_flags->tf_xla_clustering_debug, + "Dump graphs during XLA compilation."), + Flag("tf_xla_cpu_global_jit", + &mark_for_compilation_flags->tf_xla_cpu_global_jit, + "Enables global JIT compilation for CPU via SessionOptions."), + Flag("tf_xla_clustering_fuel", + &mark_for_compilation_flags->tf_xla_clustering_fuel, + "Places an artificial limit on the number of ops marked as " + "eligible for clustering."), + Flag("tf_xla_fusion_only", + &mark_for_compilation_flags->tf_xla_fusion_only, + "enable fusion of element-wise operations only using XLA when " + "global_jit_level is ON*.")}; + flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end()); +} + +void AllocateAndParseFlags() { + build_ops_flags = new BuildXlaOpsPassFlags; + build_ops_flags->tf_xla_enable_lazy_compilation = true; + + dump_graph_flags = new DumpGraphFlags; + dump_graph_flags->tf_dump_graph_prefix = "/tmp/"; + + mark_for_compilation_flags = new MarkForCompilationPassFlags; + mark_for_compilation_flags->tf_xla_auto_jit = 0; + mark_for_compilation_flags->tf_xla_min_cluster_size = 2; + mark_for_compilation_flags->tf_xla_max_cluster_size = + std::numeric_limits::max(); + mark_for_compilation_flags->tf_xla_clustering_debug = false; + mark_for_compilation_flags->tf_xla_cpu_global_jit = false; + mark_for_compilation_flags->tf_xla_clustering_fuel = + std::numeric_limits::max(); + mark_for_compilation_flags->tf_xla_fusion_only = false; + + device_flags = new XlaDeviceFlags; + device_flags->tf_xla_compile_on_demand = false; + + ops_flags = new XlaOpsCommonFlags; + ops_flags->tf_xla_always_defer_compilation = false; + + flag_list = new std::vector({ + Flag("tf_xla_enable_lazy_compilation", + 
&build_ops_flags->tf_xla_enable_lazy_compilation, ""), + + Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand, + "Switch a device into 'on-demand' mode, where instead of " + "autoclustering ops are compiled one by one just-in-time."), + + Flag("tf_xla_always_defer_compilation", + &ops_flags->tf_xla_always_defer_compilation, ""), + }); + AppendDumpGraphFlagsInternal(flag_list); + AppendMarkForCompilationPassFlagsInternal(flag_list); + xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); +} + +} // namespace + +// Public accessors. The first call to any of them runs +// AllocateAndParseFlags() (std::call_once on `flags_init`); every call then +// returns the matching global struct. +const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return *build_ops_flags; +} + +DumpGraphFlags* GetDumpGraphFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return dump_graph_flags; +} + +MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return mark_for_compilation_flags; +} + +XlaDeviceFlags* GetXlaDeviceFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return device_flags; +} + +const XlaOpsCommonFlags& GetXlaOpsCommonFlags() { + std::call_once(flags_init, &AllocateAndParseFlags); + return *ops_flags; +} + +// Public wrappers around the *Internal helpers. They run +// AllocateAndParseFlags() first so that the flag structs the appended +// definitions point into exist. 
+void AppendMarkForCompilationPassFlags(std::vector* flag_list) { + std::call_once(flags_init, &AllocateAndParseFlags); + AppendMarkForCompilationPassFlagsInternal(flag_list); +} + +void AppendDumpGraphFlags(std::vector* flag_list) { + std::call_once(flags_init, &AllocateAndParseFlags); + AppendDumpGraphFlagsInternal(flag_list); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h new file mode 100644 index 0000000000000000000000000000000000000000..5ddea588eef5270880d91623dc05893da265960a --- /dev/null +++ b/tensorflow/compiler/jit/flags.h @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_FLAGS_H_ +#define TENSORFLOW_COMPILER_JIT_FLAGS_H_ + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { + +// Flags associated with the XLA bridge's mark_for_compilation_pass module. +struct MarkForCompilationPassFlags { + int32 tf_xla_auto_jit; // Control compilation of operators into XLA + // computations on CPU and GPU devices. 0 = use + // ConfigProto setting; -1 = off; 1 = on for things + // very likely to be improved; 2 = on for everything. + // Experimental. + int32 tf_xla_min_cluster_size; // Minimum number of operators in an XLA + // compilation. Ignored for operators placed + // on an XLA device or operators explicitly + // marked for compilation. + int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA + // compilation. + bool tf_xla_clustering_debug; // Dump graphs during XLA compilation. + bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU + // via SessionOptions. + int64 tf_xla_clustering_fuel; // "Compiler fuel" for clustering. Only this + // many ops will be marked as eligible for + // clustering. + bool tf_xla_fusion_only; // This flag is effective only when global_jit_level + // is set to ON* and overrides its behavior. 
If + // true, enable fusion of element-wise operations + // only using XLA. +}; + +// Flags associated with the XLA bridge's xla_device module. +struct XlaDeviceFlags { + // Switch the CPU device into "on-demand" mode, where instead of + // autoclustering ops are compiled one by one just-in-time. + // Enabling this mode by a legacy flag is a temporary mechanism. When this + // feature is battle-tested, we will switch this to be a session option. + bool tf_xla_compile_on_demand; +}; + +// Flags common to the _Xla* ops and their kernels. +struct XlaOpsCommonFlags { + // If true, _XlaCompile always refuses to compile the cluster, which means the + // XLA clusters always run in the TF executor. Defaults to false. + bool tf_xla_always_defer_compilation; +}; + +// Flags for the build_xla_ops pass. +struct BuildXlaOpsPassFlags { + // Enables lazy compilation for TF/XLA (only when auto-clustering) if true. + // Defaults to true. + bool tf_xla_enable_lazy_compilation; +}; + +// Flags for the XLA bridge's dump_graph module. +struct DumpGraphFlags { + // Path prefix to which graphs dumped during debugging should be written. + string tf_dump_graph_prefix; +}; + +// GetDumpGraphFlags() (declared below) returns a pointer to the DumpGraphFlags +// struct; repeated calls return the same pointer. No explicit Flags::Parse() +// call is needed beforehand: the getter parses TF_XLA_FLAGS on first use. + +// Getters for flags structs defined above. The first call to any of these +// parses TF_XLA_FLAGS for all of them. Those functions which return a pointer +// always return the same pointer. 
+MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); +const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags(); +XlaDeviceFlags* GetXlaDeviceFlags(); +const XlaOpsCommonFlags& GetXlaOpsCommonFlags(); +DumpGraphFlags* GetDumpGraphFlags(); + +// Appends the flag definitions associated with +// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`. +// +// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet. 
+void AppendMarkForCompilationPassFlags( + std::vector* flag_list); +void AppendDumpGraphFlags(std::vector* flag_list); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_ diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..ce53f70b79d97ab087fefe542920b33f883632a2 --- /dev/null +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -0,0 +1,364 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/types/optional.h" +#include "tensorflow/cc/framework/scope_internal.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace { + +// StatusOrOptional instances hold +// +// - A non-OK Status to indicate an error that needs to be propagated out of +// this pass (e.g. the Graph is malformed). +// +// - A nullopt to indicate the function that created the instance failed to do +// what it set out to do but this is not actually an error +// (e.g. TryToGetTensorFromConstOp was passed a non-Const node). +// +// - A T to indicate a successful operation. +template +using StatusOrOptional = xla::StatusOr>; + +StatusOrOptional TryToGetTensorFromConstOp(Node* n) { + if (n->type_string() != "Const") { + return {absl::nullopt}; + } + + const TensorProto* proto = nullptr; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto)); + Tensor tensor(proto->dtype()); + TF_RET_CHECK(tensor.FromProto(*proto)); + return {tensor}; +} + +struct SliceInputs { + Output slice_op; + Output input; + Output begin; + Output size; + + // The size of the TF slice operation as a std::vector. 
We can always compute + // this because we only manipulate slices with a Const size. + std::vector size_as_vector; +}; + +// Copies the elements of `t` (DCHECKed to be DT_INT32 or DT_INT64) into a +// vector, in flat order. +std::vector IntTensorAsVector(const Tensor& t) { + DCHECK(t.dtype() == DT_INT32 || t.dtype() == DT_INT64); + std::vector result; + result.reserve(t.NumElements()); + for (int i = 0; i < t.NumElements(); i++) { + int64 element = t.dtype() == DT_INT32 + ? static_cast(t.flat()(i)) + : t.flat()(i); + result.push_back(element); + } + return result; +} + +// Packages up the inputs to a Slice operation into an instance of +// `SliceInputs`. Returns nullopt (not an error) when the "size" input is not +// produced by a Const node or when its tensor is not 1-D. +StatusOrOptional GetSliceInputs(Node* slice) { + const int kSliceInputIndex = 0; + const int kSliceBeginIndex = 1; + const int kSliceSizeIndex = 2; + + const Edge* slice_input_edge; + TF_RETURN_IF_ERROR(slice->input_edge(kSliceInputIndex, &slice_input_edge)); + const Edge* slice_size_edge; + TF_RETURN_IF_ERROR(slice->input_edge(kSliceSizeIndex, &slice_size_edge)); + const Edge* slice_begin_edge; + TF_RETURN_IF_ERROR(slice->input_edge(kSliceBeginIndex, &slice_begin_edge)); + + SliceInputs slice_inputs; + slice_inputs.input = + Output(slice_input_edge->src(), slice_input_edge->src_output()); + slice_inputs.begin = + Output(slice_begin_edge->src(), slice_begin_edge->src_output()); + slice_inputs.size = + Output(slice_size_edge->src(), slice_size_edge->src_output()); + + TF_ASSIGN_OR_RETURN(absl::optional tf_slice_size, + TryToGetTensorFromConstOp(slice_inputs.size.node())); + if (!tf_slice_size.has_value()) { + return {absl::nullopt}; + } + + if (tf_slice_size->dims() != 1) { + return {absl::nullopt}; + } + + slice_inputs.size_as_vector = IntTensorAsVector(*tf_slice_size); + return {slice_inputs}; +} + +// Casts `x` to a DT_INT64 if it isn't one already. 
+Output MakeInt64(const Scope& host_scope, absl::string_view name, + const Output& x) { + return x.type() == DT_INT64 + ? 
x + : ops::Cast(host_scope.WithOpName(name, "_s64"), x, DT_INT64); +} + +// Returns `slice_inputs` with the index and size inputs cast to DT_INT64. +SliceInputs MakeSliceIndexAndSizeInt64(const Scope& host_scope, + const SliceInputs& slice_inputs) { + SliceInputs result; + result.input = slice_inputs.input; + result.begin = MakeInt64(host_scope, "begin", slice_inputs.begin); + result.size = MakeInt64(host_scope, "size", slice_inputs.size); + result.size_as_vector = slice_inputs.size_as_vector; + return result; +} + +// This class caches emitted constants to avoid creating multiple nodes for the +// same constant value. This helps make the generated GraphDef more readable. +class ConstantCache { + public: + explicit ConstantCache(const Scope& s) : scope_(s) {} + + Output Get1DHostConstant(int64 constant) { + auto it = cache_.find(constant); + if (it == cache_.end()) { + Output new_const = + ops::Const(scope_.WithOpName("const_", constant), {constant}); + it = cache_.insert({constant, new_const}).first; + } + return it->second; + } + + private: + Scope scope_; + std::unordered_map cache_; +}; + +// Returns a node computing the size of the Slice op with inputs `slice_inputs`. +Status ComputeSliceSize(const Scope& host_scope, + const SliceInputs& slice_inputs, Output* size) { + // If slice_size[i] >= 0 then slice_size[i] = slice_size[i]. + // + // If slice_size[i] == -1 then slice_size[i] = input_size[i] - + // begin[i]. + // + // If slice_size[i] < -1 then executing the slice will throw an error, and we + // don't do anything here. We've already filtered these cases out in + // IsRewritableSlice. 
+ + if (absl::c_all_of(slice_inputs.size_as_vector, + [](int64 i) { return i >= 0; })) { + *size = slice_inputs.size; + return Status::OK(); + } + + Output input_shape = + ops::Shape(host_scope.WithOpName("input_shape"), slice_inputs.input, + ops::Shape::OutType(DT_INT64)); + + ConstantCache constant_pool(host_scope); + + std::vector slice_size; + for (int i = 0; i < slice_inputs.size_as_vector.size(); i++) { + if (slice_inputs.size_as_vector[i] >= 0) { + slice_size.push_back( + constant_pool.Get1DHostConstant(slice_inputs.size_as_vector[i])); + continue; + } + + DCHECK_EQ(slice_inputs.size_as_vector[i], -1); + + Output begin_i = ops::Slice( + host_scope.WithOpName("begin_", i), slice_inputs.begin, + constant_pool.Get1DHostConstant(i), constant_pool.Get1DHostConstant(1)); + + Output input_shape_i = ops::Slice( + host_scope.WithOpName("input_shape_", i), input_shape, + constant_pool.Get1DHostConstant(i), constant_pool.Get1DHostConstant(1)); + + slice_size.push_back(ops::Sub(host_scope.WithOpName("slice_size_", i), + input_shape_i, begin_i)); + DCHECK_EQ(slice_size.back().type(), DT_INT64); + } + + // Trivial ConcatV2 nodes (with exactly one input) are disallowed. + *size = + slice_size.size() == 1 + ? slice_size[0] + : ops::Concat(host_scope.WithOpName("slice_size"), slice_size, + ops::Const(host_scope.WithOpName("concat_axis"), 0)); + return Status::OK(); +} + +// Terminology: "static sized" slice is a slice with the +// _XlaCompileTimeConstantInputs attribute set to {2}. The output shape of +// these slices can be solely determined by their "size" input. 
+// +// Creates the "static shaped" replacement for `slice` in `g` and stores it in +// `*result`. The begin/size inputs are cast to DT_INT64 and the explicit size +// is built under a scope assigned to the CPU device name that +// DeviceNameUtils::DeviceNameToCpuDeviceName derives from `slice`'s assigned +// device; the new Slice itself keeps `slice`'s assigned device. +Status ConvertTensorFlowSliceToStaticShapedSlice( + Graph* g, Node* slice, const SliceInputs& slice_inputs, + absl::string_view cluster_name, Node** result) { + string host_name; + TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( + slice->assigned_device_name(), &host_name)); + + Status status; + Scope main_scope = + NewInternalScope(g, &status, /*refiner=*/nullptr) + .WithXlaCluster(string(cluster_name)) + .NewSubScope(absl::StrCat(slice->name(), "/static_shaped_slice")); + Scope host_scope = main_scope.WithAssignedDevice(host_name); + + SliceInputs slice_inputs_int64 = + MakeSliceIndexAndSizeInt64(host_scope, slice_inputs); + + Output slice_size; + TF_RETURN_IF_ERROR( + ComputeSliceSize(host_scope, slice_inputs_int64, &slice_size)); + + *result = + ops::Slice(main_scope.WithAssignedDevice(slice->assigned_device_name()) + .WithOpName("static_shaped_slice"), + slice_inputs_int64.input, slice_inputs_int64.begin, slice_size) + .node(); + + TF_RETURN_IF_ERROR(main_scope.status()); + + // Mark the "size" input as a compile-time constant on the new Slice. + std::vector compile_time_const_inputs; + compile_time_const_inputs.push_back("size"); + (*result)->AddAttr(kXlaCompileTimeConstantInputsAttr, + compile_time_const_inputs); + return status; +} + +// Moves every out-edge of `slice` (data output 0 and control edges) over to +// `static_shaped_slice`, copies `slice`'s incoming control edges onto it, and +// then removes `slice` from `g`. 
+void ReplaceTensorFlowSliceWithStaticShapedSlice(Graph* g, Node* slice, + Node* static_shaped_slice) { + // Snapshot the out-edges first: the loop below removes edges from `slice`. + std::vector slice_out_edges; + absl::c_copy(slice->out_edges(), std::back_inserter(slice_out_edges)); + for (const Edge* e : slice_out_edges) { + DCHECK(e->src_output() == 0 || e->src_output() == Graph::kControlSlot); + + int src_output = e->src_output(); + int dst_input = e->dst_input(); + Node* dst = e->dst(); + g->RemoveEdge(e); + g->AddEdge(static_shaped_slice, src_output, dst, dst_input); + } + + for (const Edge* e : slice->in_edges()) { + if (e->IsControlEdge()) { + g->AddControlEdge(e->src(), static_shaped_slice); + } + } + + g->RemoveNode(slice); +} + +// Builds the "static shaped" replacement for `slice` within cluster +// `cluster_name` and swaps it into `g` in place of `slice`. +Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs, + 
absl::string_view cluster_name) { + VLOG(3) << "Rewriting slice " << slice->name() + << " to a \"static shaped\" Slice"; + Node* static_shaped_slice; + TF_RETURN_IF_ERROR(ConvertTensorFlowSliceToStaticShapedSlice( + g, slice, slice_inputs, cluster_name, &static_shaped_slice)); + ReplaceTensorFlowSliceWithStaticShapedSlice(g, slice, static_shaped_slice); + return Status::OK(); +} + +// Return true if `n` is a slice we can rewrite to have a static shape +// (i.e. have the output shape only depend on the "size" input). +xla::StatusOr IsRewritableSlice(Node* n) { + if (n->type_string() != "Slice") { + return false; + } + + if (!GetXlaClusterForNode(*n).has_value()) { + // There is no need to change slice ops outside XLA clusters. + return false; + } + + TF_ASSIGN_OR_RETURN(absl::optional slice_inputs, + GetSliceInputs(n)); + if (!slice_inputs.has_value()) { + return false; + } + + // If slice_size[i] < -1 for any i then executing the slice will throw an + // error, and we don't do anything here. + return absl::c_all_of(slice_inputs->size_as_vector, + [](int64 size_i) { return size_i >= -1; }); +} + +// Rewrites every rewritable Slice in `g` and sets `*changed` to whether any +// Slice was rewritten. All candidates are collected before the first rewrite, +// so the rewrites run outside the g->nodes() iteration. `*changed` is only +// written when the function returns OK. +Status FindAndRewriteSlices(Graph* g, bool* changed) { + std::vector slices_to_rewrite; + for (Node* n : g->nodes()) { + TF_ASSIGN_OR_RETURN(bool is_rewritable, IsRewritableSlice(n)); + if (is_rewritable) { + slices_to_rewrite.push_back(n); + } + } + + for (Node* n : slices_to_rewrite) { + TF_ASSIGN_OR_RETURN(absl::optional slice_inputs, + GetSliceInputs(n)); + TF_RET_CHECK(slice_inputs.has_value()); + TF_RETURN_IF_ERROR( + RewriteSlice(g, n, *slice_inputs, *GetXlaClusterForNode(*n))); + } + + if (!slices_to_rewrite.empty()) { + // We've added constants to the graph; hook them up to _SOURCE. 
+ FixupSourceAndSinkEdges(g); + } + + *changed = !slices_to_rewrite.empty(); + + return Status::OK(); +} +} // namespace + +// Runs the Slice rewrite on the graph in `options`. When the +// tf_xla_clustering_debug flag is set, the graph is dumped before the rewrite +// and dumped again afterwards if the rewrite changed it. +Status IncreaseDynamismForAutoJitPass::Run( + const GraphOptimizationPassOptions& options) { + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + if (flags->tf_xla_clustering_debug) { + dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass", + **options.graph, options.flib_def); + } + + bool changed; + TF_RETURN_IF_ERROR(FindAndRewriteSlices(options.graph->get(), &changed)); + if (changed && flags->tf_xla_clustering_debug) { + dump_graph::DumpGraphToFile("increase_dynamism_for_auto_jit_pass", + **options.graph, options.flib_def); + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..818ca948d64b0353b08f393c3bd7d874c9b2480b --- /dev/null +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// Increases the amount of "dynamism" representable by XLA clusters by rewriting +// the TensorFlow graph. This pass does the following rewrites: +// +// Slice +// ----- +// +// Slice(op, begin, size ) => +// Slice(op, begin, actual_size(op.shape(), size, begin)); +// _XlaCompileTimeConstantInputs={2} +// +// where +// +// actual_size(op_shape, size, begin)[i] = +// size[i] == -1 ? (op_shape[i] - begin[i]) +// : size[i] +// +// This pass, combined with jit/partially_decluster_pass, reduces the number of +// unnecessary cluster recompilations in some common cases. After the rewrite +// shown above jit/partially_decluster_pass extracts the actual_size(...) +// computation to outside the XLA cluster, causing the cluster to be versioned +// only on the actual size of the XlaDynamicSlice. This avoids recompilation +// due to superficial changes that don't affect tensor shapes. +// +// Future Work TODO(b/111210515) +// ----------------------------- +// +// In the future we will also translate StridedSlice and Pad a similar way. 
+class IncreaseDynamismForAutoJitPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_ diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a2f1b831ad7605237e23c15cc43b337e06265553 --- /dev/null +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc @@ -0,0 +1,405 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/compiler/jit/node_matchers.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace { + +using ::testing::_; +using testing::matchers::AssignedDevice; +using testing::matchers::Attr; +using testing::matchers::Const; +using testing::matchers::CtrlDeps; +using testing::matchers::Inputs; +using testing::matchers::Name; +using testing::matchers::NodeWith; +using testing::matchers::Op; +using testing::matchers::Out; + +// A fake device used to populate a DeviceSet. +class FakeDevice : public Device { + public: + explicit FakeDevice(const DeviceAttributes& device_attributes) + : Device(nullptr, device_attributes) {} + + Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); } + + Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } + + static std::unique_ptr Make(const string& name, const string& type) { + DeviceAttributes device_attributes; + device_attributes.set_name(name); + device_attributes.set_device_type(DeviceType(type).type()); + return absl::make_unique(device_attributes); + } +}; + +const char* kHostName = "/job:worker/replica:0/task:0/device:CPU:0"; +const char* kDeviceName = "/job:worker/replica:0/task:0/device:GPU:0"; + +Status IncreaseDynamismForAutoJit(const Scope& s, + std::unique_ptr* result) { + std::vector> devices; + devices.push_back(FakeDevice::Make(kDeviceName, DEVICE_GPU)); + devices.push_back(FakeDevice::Make(kHostName, DEVICE_CPU)); + + std::unique_ptr device_set(new DeviceSet()); + for (auto& device : 
devices) { + device_set->AddDevice(device.get()); + } + + auto graph = absl::make_unique(OpRegistry::Global()); + SessionOptions session_options; + session_options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_global_jit_level(OptimizerOptions::ON_2); + GraphOptimizationPassOptions options; + options.graph = &graph; + options.device_set = device_set.get(); + options.session_options = &session_options; + + // Scope::ToGraph seems to drop assigned devices, probably because it goes + // through a GraphDef. So explicitly maintain the device assignment. + std::unordered_map assigned_device_names; + for (Node* n : s.graph()->nodes()) { + assigned_device_names[n->name()] = n->assigned_device_name(); + } + TF_RETURN_IF_ERROR(s.ToGraph(graph.get())); + for (Node* n : graph->nodes()) { + n->set_assigned_device_name(assigned_device_names[n->name()]); + } + + IncreaseDynamismForAutoJitPass rewriter; + TF_RETURN_IF_ERROR(rewriter.Run(options)); + *result = std::move(graph); + return Status::OK(); +} + +TEST(SliceToDynamicSliceRewriteTest, Basic) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32); + Output size = ops::Const(root.WithOpName("size"), {-1, 500}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + const int64 zero_64 = 0; + const int32 zero_32 = 0; + const int64 one_64 = 1; + + auto m_input = Out(NodeWith(Op("Placeholder"), Name("input"))); + auto m_begin_s64 = Out(NodeWith( + Op("Cast"), Inputs(Out(NodeWith(Op("Placeholder"), Name("begin")))))); + auto m_input_shape = Out(NodeWith(Op("Shape"), Inputs(m_input))); + auto m_slice_size_0 = Out(NodeWith( + Op("Sub"), AssignedDevice(kHostName), + Inputs( + 
Out(NodeWith(Op("Slice"), AssignedDevice(kHostName), + Inputs(m_input_shape, Const(zero_64), Const(one_64)))), + Out(NodeWith(Op("Slice"), AssignedDevice(kHostName), + Inputs(m_begin_s64, Const(zero_64), Const(one_64))))))); + auto m_dynamic_slice_size = Out(NodeWith( + Op("ConcatV2"), AssignedDevice(kHostName), + Inputs(m_slice_size_0, Const(static_cast(500)), Const(zero_32)))); + + std::vector compile_time_constant_inputs; + compile_time_constant_inputs.push_back("size"); + auto m_dynamic_slice = NodeWith( + Op("Slice"), AssignedDevice(kDeviceName), + Attr(kXlaCompileTimeConstantInputsAttr, compile_time_constant_inputs), + Inputs(m_input, m_begin_s64, m_dynamic_slice_size)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), "slice/static_shaped_slice/static_shaped_slice"); + ASSERT_NE(static_shaped_slice, nullptr); + EXPECT_THAT(static_shaped_slice, m_dynamic_slice); +} + +TEST(SliceToDynamicSliceRewriteTest, SliceFromVector) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32); + Output size = ops::Const(root.WithOpName("size"), {-1}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), "slice/static_shaped_slice/static_shaped_slice"); + EXPECT_NE(static_shaped_slice, nullptr); + EXPECT_THAT(result->nodes(), Not(Contains(NodeWith(Op("ConcatV2"))))); +} + +TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = 
ops::Placeholder(root.WithOpName("begin"), DT_INT32); + Output size = ops::Const(root.WithOpName("size"), {-1, 500}); + Output control_pred = ops::Placeholder(root.WithOpName("control"), DT_BOOL); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + root.graph()->AddControlEdge(control_pred.node(), slice.node()); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), "slice/static_shaped_slice/static_shaped_slice"); + ASSERT_NE(static_shaped_slice, nullptr); + EXPECT_THAT(static_shaped_slice, + NodeWith(Op("Slice"), + CtrlDeps(NodeWith(Op("Placeholder"), Name("control"))))); +} + +int64 ToInt64(int v) { return static_cast(v); } + +TEST(SliceToDynamicSliceRewriteTest, Int64Indices) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); + Output size = + ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(500)}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + EXPECT_THAT(result->nodes(), Not(Contains(NodeWith(Op("Cast"))))); +} + +TEST(SliceToDynamicSliceRewriteTest, DontRewriteInvalidSlice) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32); + + // The shape refiner throws an error if we use a bogus constant value for + // size. So we first use a Placeholder to placate the shape refiner, and + // later replace it with a bogus constant. 
+ Output size_placeholder = + ops::Placeholder(root.WithOpName("size_placeholder"), DT_INT32); + Output slice = + ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder); + + Output size = ops::Const(root.WithOpName("size"), {-8, 500}); + TF_ASSERT_OK(root.graph()->UpdateEdge(/*new_src=*/size.node(), + /*new_src_index=*/0, + /*dst=*/slice.node(), /*dst_index=*/2)); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + EXPECT_THAT(result->nodes(), + Not(Contains(NodeWith(Op("Slice"), + Attr(kXlaCompileTimeConstantInputsAttr))))); +} + +TEST(SliceToDynamicSliceRewriteTest, DontRewriteUnclusteredSlice) { + Scope root = + Scope::NewRootScope().ExitOnError().WithAssignedDevice(kDeviceName); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32); + Output size = ops::Const(root.WithOpName("size"), {-1, 500}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + EXPECT_THAT(result->nodes(), + Not(Contains(NodeWith(Op("Slice"), + Attr(kXlaCompileTimeConstantInputsAttr))))); +} + +TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithNonConstSize) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); + Output size = ops::Placeholder(root.WithOpName("size"), DT_INT64); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + EXPECT_THAT(result->nodes(), + Not(Contains(NodeWith(Op("Slice"), + Attr(kXlaCompileTimeConstantInputsAttr))))); +} + +TEST(SliceToDynamicSliceRewriteTest, ScalarSlice) { + Scope root = 
Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); + Output size = ops::Const(root.WithOpName("size"), {}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), "slice/static_shaped_slice/static_shaped_slice"); + ASSERT_NE(static_shaped_slice, nullptr); + EXPECT_THAT(static_shaped_slice, + NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr), + Inputs(_, _, Out(NodeWith(Name(size.node()->name())))))); +} + +TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + auto ToInt64 = [](int v) { return static_cast(v); }; + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); + + // The C++ node bindings immediately error out when we try construct a bogus + // slice so we first use a placeholder to construct the Slice and then replace + // the input. 
+ Output size_placeholder = ops::Placeholder(root.WithOpName("size"), DT_INT64); + Output slice = + ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder); + + Output size = + ops::Const(root.WithOpName("size"), {{ToInt64(-1)}, {ToInt64(500)}}); + TF_ASSERT_OK(root.graph()->UpdateEdge(size.node(), 0, slice.node(), 2)); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + EXPECT_THAT(result->nodes(), + Not(Contains(NodeWith(Op("Slice"), + Attr(kXlaCompileTimeConstantInputsAttr))))); +} + +TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceInput) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); + Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32); + Output size_a = ops::Const(root.WithOpName("size_a"), {-1, 500}); + Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size_a); + + Output size_b = ops::Const(root.WithOpName("size_a"), {-1, 200}); + Output slice_with_slice_input = ops::Slice( + root.WithOpName("slice_with_slice_input"), slice, begin, size_b); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), + "slice_with_slice_input/static_shaped_slice/static_shaped_slice"); + ASSERT_NE(static_shaped_slice, nullptr); + EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT) + << "Expected DT_FLOAT, was " + << DataType_Name(static_shaped_slice->output_type(0)); + EXPECT_THAT( + static_shaped_slice, + NodeWith( + Op("Slice"), + Inputs(Out(NodeWith( + Op("Slice"), + Name("slice/static_shaped_slice/static_shaped_slice"))), + _, _))); +} + +TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceBegin) { + Scope root = Scope::NewRootScope() + .ExitOnError() + .WithAssignedDevice(kDeviceName) + .WithXlaCluster("cluster_0"); + + Output 
input_float = + ops::Placeholder(root.WithOpName("input_float"), DT_FLOAT); + Output input_i64 = ops::Placeholder(root.WithOpName("input_i64"), DT_INT64); + + Output begin_begin = + ops::Placeholder(root.WithOpName("begin_begin"), DT_INT32); + Output begin_size = ops::Const(root.WithOpName("begin_size"), {-1}); + Output begin = + ops::Slice(root.WithOpName("begin"), input_i64, begin_begin, begin_size); + + Output size = + ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(200)}); + Output slice_with_slice_begin = ops::Slice( + root.WithOpName("slice_with_slice_begin"), input_float, begin, size); + + std::unique_ptr result; + TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result)); + + Node* static_shaped_slice = testing::FindNodeByName( + result.get(), + "slice_with_slice_begin/static_shaped_slice/static_shaped_slice"); + ASSERT_NE(static_shaped_slice, nullptr); + EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT) + << "Expected DT_FLOAT, was " + << DataType_Name(static_shaped_slice->output_type(0)); + EXPECT_THAT( + static_shaped_slice, + NodeWith( + Op("Slice"), + Inputs(_, + Out(NodeWith( + Op("Slice"), + Name("begin/static_shaped_slice/static_shaped_slice"))), + _))); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 085c0e5adbb270e71ff3447a936555c99904e26c..f79bdc1e2e8d82c9144d1bb9923ad36d8541cbdb 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/build_xla_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" +#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" @@ -44,17 +45,20 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, + IncreaseDynamismForAutoJitPass); + +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, PartiallyDeclusterPass); // The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We // also need to run it after the graph been rewritten to have _Send nodes added // for fetches. Before the _Send nodes are added, fetch nodes are identified by // name, and encapsulation might remove that node from the graph. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, EncapsulateSubgraphsPass); // Must run after EncapsulateSubgraphsPass. 
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50, BuildXlaOpsPass); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 26cb3af9d69ba1877c67853cde28d2477d394efc..0583774714c6db7a2fa515fc8a0d304e1898db97 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -12,6 +12,7 @@ cc_library( hdrs = ["xla_ops.h"], deps = [ "//tensorflow/compiler/jit:common", + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:xla_compilation_cache", "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/jit:xla_launch_util", diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index accc86a86d9d3eca741994ee502bd7580ce49b2e..ad71df5a694a5f8da94675049df1062a7edb6253 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -38,12 +39,22 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/util/stream_executor_util.h" +// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that +// in error case, it returns RET instead of void. +#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) 
\ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return RET; \ + } \ + } while (0) + namespace tensorflow { namespace { -Status PlatformInfoFromContext(OpKernelConstruction* ctx, - XlaPlatformInfo* result) { +XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) { DeviceType device_type = ctx->device_type(); se::Platform::Id platform_id = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr; @@ -75,16 +86,16 @@ Status PlatformInfoFromContext(OpKernelConstruction* ctx, } if (!device_allocator) { - TF_ASSIGN_OR_RETURN(se::Platform* const platform, - se::MultiPlatformManager::PlatformWithId(platform_id)); + xla::StatusOr maybe_platform = + se::MultiPlatformManager::PlatformWithId(platform_id); + OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status()); + xla_allocator = absl::make_unique( - platform, ctx->device()->GetAllocator({})); + maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({})); } - *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata, - std::move(xla_allocator), device_allocator); - - return Status::OK(); + return XlaPlatformInfo(device_type, platform_id, xla_device_metadata, + std::move(xla_allocator), device_allocator); } // A closure describing how to run a compiled version of a TensorFlow function. 
@@ -178,9 +189,8 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, : OpKernel(ctx), constants_(constants), resources_(resources), - function_(function) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); -} + function_(function), + platform_info_(PlatformInfoFromContext(ctx)) {} static Status BuildCompilationCache(OpKernelContext* ctx, const XlaPlatformInfo& platform_info, @@ -219,7 +229,7 @@ static Status BuildCompilationCache(OpKernelContext* ctx, static Status CompileToLocalExecutable( OpKernelContext* ctx, const NameAttrList& function, const XlaPlatformInfo& platform_info, absl::Span resources, - absl::Span constants, xla::LocalClient** client, + absl::Span constants, bool lazy, xla::LocalClient** client, std::map* variables, const XlaCompiler::CompilationResult** kernel, xla::LocalExecutable** executable) { @@ -241,7 +251,7 @@ static Status CompileToLocalExecutable( // this is more obviously correct.) core::ScopedUnref cache_ref(cache); - *variables = SnapshotResourceVariables(ctx, resources); + TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables)); *client = static_cast(cache->client()); XlaCompiler::Options options; @@ -276,8 +286,13 @@ static Status CompileToLocalExecutable( // rather than a one-element tuple. compile_options.always_return_tuple = false; - return cache->Compile(options, function, constant_args, *variables, ctx, - compile_options, kernel, executable); + std::vector args; + TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_args, *variables, ctx, &args)); + return cache->Compile(options, function, args, compile_options, + lazy ? 
XlaCompilationCache::CompileMode::kLazy + : XlaCompilationCache::CompileMode::kStrict, + kernel, executable); } void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { @@ -291,8 +306,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK( ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, - constants_, &client, &variables, &kernel, - &executable)); + constants_, /*lazy=*/false, &client, + &variables, &kernel, &executable)); se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; @@ -329,18 +344,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { } namespace { - -// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that -// in error case, it returns RET instead of void. -#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ - do { \ - ::tensorflow::Status _s(__VA_ARGS__); \ - if (!TF_PREDICT_TRUE(_s.ok())) { \ - (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ - return RET; \ - } \ - } while (0) - // Helper static functions to construct parameters for // XlaLocalLaunchBase constructor from OpKernelConstruction. 
std::vector ConstantsVector(OpKernelConstruction* ctx) { @@ -377,7 +380,12 @@ NameAttrList FunctionAttr(OpKernelConstruction* ctx) { return *func; } -#undef OP_REQUIRES_OK_RETURN +bool MustCompileAttr(OpKernelConstruction* ctx) { + bool must_compile; + OP_REQUIRES_OK_RETURN(ctx, false, + ctx->GetAttr("must_compile", &must_compile)); + return must_compile; +} } // namespace XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) @@ -392,20 +400,59 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx), constants_(ConstantsVector(ctx)), resources_(ResourcesVector(ctx)), - function_(FunctionAttr(ctx)) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); -} + function_(FunctionAttr(ctx)), + platform_info_(PlatformInfoFromContext(ctx)), + must_compile_(MustCompileAttr(ctx)) {} void XlaCompileOp::Compute(OpKernelContext* ctx) { + VLOG(3) << "XlaCompileOp " << def().name() + << (must_compile_ ? "(must-compile)" : ""); xla::LocalClient* client; const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; std::map variables; - OP_REQUIRES_OK( - ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, - constants_, &client, &variables, &kernel, - &executable)); + bool cannot_compile_cluster; + { + mutex_lock guard(cannot_compile_cluster_mu_); + cannot_compile_cluster = cannot_compile_cluster_; + } + + if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation || + cannot_compile_cluster) { + executable = nullptr; + } else { + Status status = CompileToLocalExecutable( + ctx, function_, platform_info_, resources_, constants_, + /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable); + if (must_compile_ || status.code() != error::UNIMPLEMENTED) { + OP_REQUIRES_OK(ctx, status); + } + + if (status.code() == error::UNIMPLEMENTED) { + LOG(WARNING) << "Compilation failed:" << status.ToString() + << ". 
Falling back to TF function call."; + executable = nullptr; + mutex_lock guard(cannot_compile_cluster_mu_); + cannot_compile_cluster_ = true; + } + } + + AllocatorAttributes host_alloc_attrs; + host_alloc_attrs.set_gpu_compatible(true); + host_alloc_attrs.set_on_host(true); + Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs); + + if (!executable) { + DCHECK(!must_compile_); + Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); + + Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({})); + compilation_successful.scalar()() = false; + ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({}))); + ctx->set_output(1, compilation_successful); + return; + } // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even // if it didn't have to compile the cluster because of a compilation-cache @@ -415,13 +462,6 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure( client, executable, kernel, std::move(variables), constants_.size())); - Allocator* cpu_allocator = [&] { - AllocatorAttributes host_alloc_attrs; - host_alloc_attrs.set_gpu_compatible(true); - host_alloc_attrs.set_on_host(true); - return ctx->device()->GetAllocator(host_alloc_attrs); - }(); - Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); compilation_key.flat()(0) = key; @@ -432,11 +472,11 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { ctx->set_output(1, compilation_successful); } -XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); -} +XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) + : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {} void XlaRunOp::Compute(OpKernelContext* ctx) { + VLOG(3) << "XlaRunOp " << def().name(); Tensor key_tensor = ctx->input(ctx->num_inputs() - 1); const XlaExecutableClosureStore::KeyT& key = key_tensor.flat()(0); @@ 
-491,6 +531,8 @@ REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp); REGISTER_KERNEL_BUILDER(Name("_XlaCompile") .Device(DEVICE_GPU) .HostMemory("constants") + .HostMemory("key") + .HostMemory("compilation_successful") .HostMemory("resources"), XlaCompileOp); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index 489d26eb30a66646158f39ea3fc6f55759c7f88e..7b4d4b5b4737784d4fe277d5bbe9cab79cfaf4c9 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ #define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ +#include + #include "tensorflow/compiler/jit/xla_compilation_cache.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" @@ -33,6 +35,7 @@ namespace tensorflow { class XlaPlatformInfo { public: XlaPlatformInfo() : device_type_("") {} + XlaPlatformInfo(XlaPlatformInfo&&) = default; explicit XlaPlatformInfo(const DeviceType device_type, se::Platform::Id platform_id, const XlaDevice::Metadata* xla_device_metadata, @@ -110,12 +113,12 @@ class XlaLocalLaunchBase : public OpKernel { protected: // Indexes of compile-time constant inputs - std::vector constants_; + const std::vector constants_; // Indexes of resource inputs - std::vector resources_; + const std::vector resources_; - NameAttrList function_; - XlaPlatformInfo platform_info_; + const NameAttrList function_; + const XlaPlatformInfo platform_info_; }; // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph @@ -144,13 +147,23 @@ class XlaCompileOp : public OpKernel { private: // Indexes of compile-time constant inputs - std::vector constants_; + const std::vector constants_; // Indexes of resource inputs - std::vector resources_; + const std::vector resources_; - NameAttrList function_; + const NameAttrList function_; 
XlaPlatformInfo platform_info_; + + const bool must_compile_; + + // cannot_compile_cluster_ is set to true if XLA returns an Unimplemented + // error when compiling the cluster this _XlaCompile is supposed to compile. + // If `cannot_compile_cluster_` is true then we avoid compiling this cluster + // on any future calls to _XlaCompile. + bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false; + + mutex cannot_compile_cluster_mu_; }; class XlaRunOp : public OpKernel { @@ -160,7 +173,7 @@ class XlaRunOp : public OpKernel { void Compute(OpKernelContext* ctx) override; private: - XlaPlatformInfo platform_info_; + const XlaPlatformInfo platform_info_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD deleted file mode 100644 index 07c5b2318851ed506711b9ee00c66fe680a3afd8..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -# Legacy command line flags for the XLA bridge libraries. - -# Please do not add more flags to this package. - -# The XLA bridge libraries were written in an environment that allowed -# command-line flags to be scattered freely throughout the libraries. This -# model, while initially convenient, leads to a proliferation in unused command -# line flags in tests and binaries, and serious problems in servers, where one -# might wish parameters to be different in independent RPC calls to the same -# routine. -# -# Please don't add more flags. If you're a library author, pass options and -# parameters explicitly through the library's interface. 
- -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) - -cc_library( - name = "mark_for_compilation_pass_flags", - srcs = ["mark_for_compilation_pass_flags.cc"], - hdrs = ["mark_for_compilation_pass_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "xla_device_flags", - srcs = ["xla_device_flags.cc"], - hdrs = ["xla_device_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc deleted file mode 100644 index 7277a1d1f8ad5fa045645ead839ab9efa01e89c7..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's mark_for_compilation_pass module. 
- -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static MarkForCompilationPassFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new MarkForCompilationPassFlags; - flags->tf_xla_auto_jit = 0; - flags->tf_xla_min_cluster_size = 2; - flags->tf_xla_max_cluster_size = std::numeric_limits::max(); - flags->tf_xla_clustering_debug = false; - flags->tf_xla_cpu_global_jit = false; - flags->tf_xla_clustering_fuel = std::numeric_limits::max(); - flags->tf_xla_fusion_only = false; - flag_list = new std::vector( - {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, - "Control compilation of operators into XLA computations on CPU and " - "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " - "things very likely to be improved; 2 = on for everything. " - "Experimental."), - Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size, - "Minimum number of operators in an XLA compilation. 
Ignored for " - "operators placed on an XLA device or operators explicitly marked " - "for compilation."), - Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size, - "Maximum number of operators in an XLA compilation."), - Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, - "Dump graphs during XLA compilation."), - Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit, - "Enables global JIT compilation for CPU via SessionOptions."), - Flag("tf_xla_clustering_fuel", &flags->tf_xla_clustering_fuel, - "Places an artificial limit on the number of ops marked as " - "eligible for clustering."), - Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only, - "enable fusion of element-wise operations only using XLA when " - "global_jit_level is ON*.")}); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// mark_for_compilation_pass module. -void AppendMarkForCompilationPassFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the MarkForCompilationPassFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h deleted file mode 100644 index 2affda6ab4e0fbad32a246744fa5b38aeb629c1b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ - -// Legacy flags for the XLA bridge's mark_for_compilation_pass module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// mark_for_compilation_pass module. -void AppendMarkForCompilationPassFlags( - std::vector* flag_list); - -// The values of flags associated with the XLA bridge's -// mark_for_compilation_pass module. -typedef struct { - int32 tf_xla_auto_jit; // Control compilation of operators into XLA - // computations on CPU and GPU devices. 0 = use - // ConfigProto setting; -1 = off; 1 = on for things - // very likely to be improved; 2 = on for everything. - // Experimental. - int32 tf_xla_min_cluster_size; // Minimum number of operators in an XLA - // compilation. Ignored for operators placed - // on an XLA device or operators explicitly - // marked for compilation. - int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA - // compilation. - bool tf_xla_clustering_debug; // Dump graphs during XLA compilation. 
- bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU - // via SessionOptions. - int64 tf_xla_clustering_fuel; // "Compiler fuel" for clustering. Only this - // many ops will be marked as eligible for - // clustering. - bool tf_xla_fusion_only; // This flag is effective only when global_jit_level - // is set to ON* and overrides its behavior. If - // true, enable fusion of element-wise operations - // only using XLA. -} MarkForCompilationPassFlags; - -// Return a pointer to the MarkForCompilationPassFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc deleted file mode 100644 index 1bb2fce2dbad5bffce2e33b665b7222090d0855a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's xla_device module. 
- -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static XlaDeviceFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new XlaDeviceFlags; - flags->tf_xla_compile_on_demand = false; - flag_list = new std::vector({ - Flag("tf_xla_compile_on_demand", &flags->tf_xla_compile_on_demand, - "Switch a device into 'on-demand' mode, where instead of " - "autoclustering ops are compiled one by one just-in-time."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Return a pointer to the XlaDeviceFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -XlaDeviceFlags* GetXlaDeviceFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h deleted file mode 100644 index 27b22121ac1e089bd5d5a494e1e3fb60b05bc76d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ - -// Legacy flags for the XLA bridge's xla_device module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// The values of flags associated with the XLA bridge's -// xla_device module. -typedef struct { - // Switch the CPU device into "on-demand" mode, where instead of - // autoclustering ops are compiled one by one just-in-time. - // Enabling this mode by a legacy flag is a temporary mechanism. When this - // feature is battle-tested, we will switch this to be a session option. - bool tf_xla_compile_on_demand; -} XlaDeviceFlags; - -// Return a pointer to the XlaDeviceFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -XlaDeviceFlags* GetXlaDeviceFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 4f0c370e65159c89c91ea58733f20f852d9acc99..6618e3a58ab7b6374ed775cd6e4e18a6a4975588 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -24,8 +24,8 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" @@ -49,6 +49,51 @@ limitations under the License. namespace tensorflow { namespace { +// Aggregates information about what kinds of ops are allowed. +struct OperationFilter { + // Whether resource variable ops are allowed. We do not allow resource + // variable ops in called functions (either as direct TF calls or as higher + // order control flow ops) because we do not yet model their memory effects in + // jit/resource_variable_safety_analysis. + bool allow_resource_ops; + + // Whether stateful RNG ops are allowed. XLA's RNG does not have the same + // seeding behavior as TensorFlow's RNG (b/34749654). So we avoid + // auto-clustering stateful RNG ops. + bool allow_stateful_rng_ops; + + // TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound + // to cluster ControlTrigger because of how we use deadness analysis. + bool allow_control_trigger; + + // Whether ops with dummy implementations are allowed. We avoid + // auto-clustering these ops so that the user is not surprised when XLA is + // implicitly enabled. If the user explicitly specifies to use XLA, it is fine + // to resort to a dummy implementation. Currently Assert and CheckNumerics ops + // have dummy XLA implementations. + bool allow_dummy_ops; + + // Whether ops that produce or consume DT_VARIANT values are allowed. We + // don't auto-cluster these ops because we don't yet support live-in or + // live-out DT_VARIANT values. 
+ bool allow_ops_producing_or_consuming_variant; +}; + +bool IsDummyImplOp(absl::string_view op_name) { + return op_name == "Assert" || op_name == "CheckNumerics"; +} + +bool IsStatefulRandomOp(absl::string_view op_name) { + return op_name == "RandomUniform" || op_name == "RandomShuffle" || + op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" || + op_name == "TruncatedNormal" || op_name == "Multinomial"; +} + +bool OpProducesOrConsumesVariant(const Node& node) { + auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; }; + return absl::c_any_of(node.input_types(), is_variant) || + absl::c_any_of(node.output_types(), is_variant); +} bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient @@ -101,7 +146,7 @@ const int kMaxRecursionDepth = 10; bool IsCompilableCall(const NodeDef& call_def, const DeviceType& jit_device_type, - bool allow_resource_ops, int depth, + const OperationFilter& op_filter, int depth, FunctionLibraryRuntime* lib_runtime); // Tests whether 'while_node' is a completely compilable loop. @@ -109,7 +154,7 @@ bool IsCompilableCall(const NodeDef& call_def, // while loop to be compilable. 
bool IsCompilableWhile(const Node& while_node, const DeviceType& jit_device_type, - bool allow_resource_ops, int depth, + const OperationFilter& op_filter, int depth, FunctionLibraryRuntime* lib_runtime) { const NameAttrList* name_attr; NodeDef call; @@ -124,7 +169,7 @@ bool IsCompilableWhile(const Node& while_node, call.set_name("while_cond"); call.set_op(cond_func); *call.mutable_attr() = name_attr->attr(); - if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1, + if (!IsCompilableCall(call, jit_device_type, op_filter, depth + 1, lib_runtime)) { VLOG(2) << "Rejecting While " << while_node.name() << ": can't compile loop condition: " << cond_func; @@ -140,7 +185,7 @@ bool IsCompilableWhile(const Node& while_node, call.set_name("while_body"); call.set_op(body_func); *call.mutable_attr() = name_attr->attr(); - if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1, + if (!IsCompilableCall(call, jit_device_type, op_filter, depth + 1, lib_runtime)) { VLOG(2) << "Rejecting While " << while_node.name() << ": can't compile loop body: " << body_func; @@ -154,7 +199,7 @@ bool IsCompilableWhile(const Node& while_node, // compilable. bool IsCompilableCall(const NodeDef& call_def, const DeviceType& jit_device_type, - bool allow_resource_ops, int depth, + const OperationFilter& op_filter, int depth, FunctionLibraryRuntime* lib_runtime) { if (depth > kMaxRecursionDepth) { VLOG(2) << "Rejecting " << call_def.op() @@ -195,16 +240,30 @@ bool IsCompilableCall(const NodeDef& call_def, continue; if (node->type_string() == "While") { // Handle functional While loop. 
- return IsCompilableWhile(*node, jit_device_type, allow_resource_ops, - depth + 1, lib_runtime); + return IsCompilableWhile(*node, jit_device_type, op_filter, depth + 1, + lib_runtime); } - if (!allow_resource_ops && + if (!op_filter.allow_resource_ops && (HasResourceInput(*node) || HasResourceOutput(*node))) { return false; } + if (!op_filter.allow_stateful_rng_ops && + IsStatefulRandomOp(node->type_string())) { + return false; + } + if (!op_filter.allow_control_trigger && node->IsControlTrigger()) { + return false; + } + if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) { + return false; + } + if (!op_filter.allow_ops_producing_or_consuming_variant && + OpProducesOrConsumesVariant(*node)) { + return false; + } if (!HasXLAKernel(*node, jit_device_type) && - !IsCompilableCall(node->def(), jit_device_type, allow_resource_ops, - depth + 1, lib_runtime)) { + !IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1, + lib_runtime)) { VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op " << node->name() << ": " << node->def().ShortDebugString(); return false; @@ -383,8 +442,7 @@ Status FindCompilationCandidates( BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr, &compile_time_const_nodes)); - int64& fuel = - legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; + int64& fuel = GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; // Iterate over nodes in sorted order so that compiler fuel is deterministic. 
// We can't simply pass op_nodes().begin() and op_nodes().end to the @@ -426,14 +484,47 @@ Status FindCompilationCandidates( CHECK( XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)); DeviceType jit_device_type(registration->compilation_device_name); + + bool always_auto_cluster = registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways; + + OperationFilter op_filter; + op_filter.allow_resource_ops = registration->compile_resource_ops; + op_filter.allow_stateful_rng_ops = always_auto_cluster; + op_filter.allow_control_trigger = always_auto_cluster; + op_filter.allow_dummy_ops = always_auto_cluster; + op_filter.allow_ops_producing_or_consuming_variant = always_auto_cluster; + if (!HasXLAKernel(*node, jit_device_type) && - !IsCompilableCall(node->def(), jit_device_type, - registration->compile_resource_ops, 0, lib_runtime)) { + !IsCompilableCall(node->def(), jit_device_type, op_filter, 0, + lib_runtime)) { VLOG(2) << "Rejecting " << node->name() << ": unsupported op " << node->type_string(); continue; } - if (!registration->compile_resource_ops && + + if (!op_filter.allow_stateful_rng_ops && + IsStatefulRandomOp(node->type_string())) { + VLOG(2) << "Rejecting " << node->name() << ": stateful random operation"; + continue; + } + if (!op_filter.allow_control_trigger && node->IsControlTrigger()) { + VLOG(2) << "Rejecting " << node->name() << ": is a control trigger op"; + continue; + } + if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) { + VLOG(2) << "Rejecting " << node->name() << ": dummy op (" + << node->type_string() << ")"; + continue; + } + if (!op_filter.allow_ops_producing_or_consuming_variant && + OpProducesOrConsumesVariant(*node)) { + VLOG(2) << "Rejecting " << node->name() + << ": produces or consumes DT_VARIANT"; + continue; + } + + if (!op_filter.allow_resource_ops && (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { // We don't have a way of returning values of type 
DT_RESOURCE from XLA // computations so we avoid auto-clustering nodes producing DT_RESOURCE. @@ -444,6 +535,7 @@ Status FindCompilationCandidates( << node->type_string(); continue; } + if (compile_time_const_nodes[node->id()]) { const OpDef* op_def; TF_RETURN_IF_ERROR( @@ -501,9 +593,7 @@ Status FindCompilationCandidates( // registration->compile_resource_ops is true for XLA_CPU/XLA_GPU but not // for CPU/GPU. if (node->type_string() == "While" && - !IsCompilableWhile(*node, jit_device_type, - registration->compile_resource_ops, 0, - lib_runtime)) { + !IsCompilableWhile(*node, jit_device_type, op_filter, 0, lib_runtime)) { continue; } // _Arg nodes in a top-level function represent feeds. @@ -536,8 +626,7 @@ OptimizerOptions::GlobalJitLevel GetGlobalJitLevel( // To set compilation to be on by default, change the following line. global_jit_level = OptimizerOptions::OFF; } - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); if (flags->tf_xla_auto_jit == -1 || (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) { // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides @@ -563,10 +652,16 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { ®istration)); DeviceType jit_device_type(registration->compilation_device_name); - // We can always *compile* resource operations, even if we are sometimes - // unable to auto-cluster them. - const bool compile_resource_ops = true; - return IsCompilableCall(ndef, jit_device_type, compile_resource_ops, 0, flr); + // We can always *compile* resource operations, stateful RNGs and dummy ops, + // even if we are sometimes unable to auto-cluster them. 
+ OperationFilter op_filter; + op_filter.allow_resource_ops = true; + op_filter.allow_stateful_rng_ops = true; + op_filter.allow_control_trigger = true; + op_filter.allow_dummy_ops = true; + op_filter.allow_ops_producing_or_consuming_variant = true; + + return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr); } Status MarkForCompilationPass::Run( @@ -575,12 +670,9 @@ Status MarkForCompilationPass::Run( // device ahead of time. OptimizerOptions::GlobalJitLevel global_jit_level = GetGlobalJitLevel(options); - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); - bool cpu_global_jit = flags->tf_xla_cpu_global_jit; + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); bool fusion_only = flags->tf_xla_fusion_only; - VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; VLOG(1) << "flags->tf_xla_auto_jit = " << flags->tf_xla_auto_jit; const FunctionLibraryDefinition* fld = options.flib_def; @@ -599,9 +691,6 @@ Status MarkForCompilationPass::Run( return false; } - // If this device requires a JIT, we must say yes. - if (registration->requires_compilation) return true; - // If there is a _XlaCompile annotation, use its value. bool compile = false; Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); @@ -638,18 +727,21 @@ Status MarkForCompilationPass::Run( return false; } - // Otherwise use the value of global_jit_level. - // Ignore enable_jit_by_default if global jit compilation for CPU - // is explicitly requested via tf_xla_cpu_global_jit flag - bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; + // Otherwise use the value of global_jit_level and the device's + // autoclustering policy. 
bool should_compile = - (ignore_registration || registration->enable_jit_by_default) && - global_jit_level != OptimizerOptions::OFF; + registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways || + (registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally && + global_jit_level != OptimizerOptions::OFF); if (!should_compile) { if (global_jit_level == OptimizerOptions::OFF) { VLOG(2) << "Rejecting " << node->name() << ": global jit disabled."; } else { - VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled."; + VLOG(2) + << "Rejecting " << node->name() + << ": autoclustering for device only when requested explicitly."; } } return should_compile; @@ -879,8 +971,7 @@ Status MarkForCompilationPass::RunImpl( OptimizerOptions::GlobalJitLevel global_jit_level = GetGlobalJitLevel(options); - legacy_flags::MarkForCompilationPassFlags* flags = - legacy_flags::GetMarkForCompilationPassFlags(); + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); // Repeatedly contract edges between clusters that are on the same device, // provided the contraction would not create a cycle. @@ -952,6 +1043,28 @@ Status MarkForCompilationPass::RunImpl( continue; } + // If any of the consumer's producers are on a different device, do not + // cluster these nodes. This prevents other work on this device from being + // delayed by work on other devices. We consider predecessors of the + // entire cluster rather than just the inputs to the node to prevent the + // cluster still being combined in cases where the 'to' cluster has + // multiple dependencies on the 'from' cluster and another dependency + // leads to a merging of the clusters. + // + // TODO(b/117085735): We probably want to handle the reciprocal of this + // case where a cluster is producing data for multiple devices. 
+ bool found_split = false; + for (const auto& in_id : cycles.Predecessors(to)) { + if (in_id >= graph->num_node_ids()) continue; + + Node* in = graph->FindNodeId(in_id); + if (compilation_candidates.find(in) != compilation_candidates.cend() && + in->assigned_device_name() != node_to->assigned_device_name()) { + found_split = true; + } + } + if (found_split) continue; + // If contracting the edge would create a cycle, bail out. // However, just because we can't merge the clusters now does not mean // we won't be able to merge them in the future. @@ -1015,12 +1128,10 @@ Status MarkForCompilationPass::RunImpl( XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration); // Compile if this is a cluster of >= min_cluster_size compilable operators. - // Also, always compile if the operator is placed on a device that requires - // compilation, or if it contains at least one op that is marked for + // Also, always compile if it contains at least one op that is marked for // compilation that is not an Identity op. if (effective_cluster_sizes[cluster] >= min_cluster_size || - (effective_cluster_sizes[cluster] > 0 && marked_for_compilation) || - registration->requires_compilation) { + (effective_cluster_sizes[cluster] > 0 && marked_for_compilation)) { string& name = cluster_names[cluster]; if (name.empty()) { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 2a80c745e3fcebf97bcccb03551feb3d6fb9f831..bf2c5508ea9e987e80093f4c2e15d3ff5191126f 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/list_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -817,14 +818,10 @@ TEST(XlaCompilationTest, ClusterControlTrigger) { std::unordered_map clusters = GetClusters(*graph); - ASSERT_FALSE(clusters.empty()); - string cluster_name = clusters.begin()->second; - - // ctrl_trigger_a has inputs with mismatching deadness so it won't be - // clustered. ctrl_trigger_b is okay to cluster. - std::unordered_map expected_clusters( - {{"const_a", cluster_name}, {"ctrl_trigger_b", cluster_name}}); - EXPECT_EQ(clusters, expected_clusters); + // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so + // it won't be clustered. ctrl_trigger_b is okay to cluster but we don't + // cluster it because of b/118970344. + EXPECT_TRUE(clusters.empty()); } TEST(XlaCompilationTest, RandomShape) { @@ -923,9 +920,8 @@ TEST(XlaCompilationTest, RandomShapeOnXlaDevice) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map clusters = GetClusters(*graph); - EXPECT_NE(clusters["test/shape_rng"], ""); - EXPECT_NE(clusters["test/reshape"], ""); - EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]); + EXPECT_EQ(clusters["test/shape_rng"], ""); + EXPECT_EQ(clusters["test/reshape"], ""); } TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { @@ -961,5 +957,271 @@ TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]); } +TEST(XlaCompilationTest, DontClusterMergingNodes) { + // MatMulCombined below takes data from nodes on GPU0 and GPU1 and is placed + // on GPU1. 
However, it should not be clustered with the previous node on + // GPU1, because that will serialize production of its inputs that should be + // done in parallel. + // + // This graph is: + // (Const0, Const0) -> MatMul0 + // (Const1, Const1) -> MatMul1 + // (MatMul0, MatMul1) -> MatMulCombined + // + // Device0: [Const0, Const0, MatMul0] + // Device1: [Const1, Const1, MatMul1, MatMulCombined] + // + // Cluster0: [Const0, Const0, MatMul0] + // Cluster1: [Const1, Const1, MatMul1] + // Cluster2: [MatMulCombined] + Scope root = Scope::NewRootScope().ExitOnError(); + absl::string_view xla_gpu_dev0 = + "/job:worker/replica:0/task:0/device:XLA_GPU:0"; + absl::string_view xla_gpu_dev1 = + "/job:worker/replica:0/task:0/device:XLA_GPU:1"; + std::unique_ptr graph(new Graph(OpRegistry::Global())); + Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}); + Output b = ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}); + Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a); + Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b); + + Output combined = + ops::MatMul(root.WithOpName("MatMulCombined_dev1"), matmul0, matmul1); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) { + n->set_assigned_device_name(string(xla_gpu_dev0)); + } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) { + n->set_assigned_device_name(string(xla_gpu_dev1)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + // Each of the MatMuls should be in a separate cluster. 
+ std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]); + EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul0_dev0"]); + EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul1_dev1"]); + EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]); + EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]); +} + +// TODO(b/117085735): This form of clustering should be prevented. +TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) { + // MatMulSource below creates data for nodes on GPU0 and GPU1 and is placed + // on GPU0. However, it should not be clustered with the next node on + // GPU0, because that will prevent the node on GPU1 from beginning its work as + // soon as the data has been produced. + // + // This graph is: + // (Const0, Const0) -> MatMulSource + // MatMulSource -> (MatMul0, MatMul1) + // + // Device0: [Const0, Const1, MatMulSource, MatMul0] + // Device1: [MatMul1] + // + // Cluster0: [Const0, Const1, MatMulSource] + // Cluster1: [MatMul0] + // Cluster2: [MatMul1] + Scope root = Scope::NewRootScope().ExitOnError(); + absl::string_view xla_gpu_dev0 = + "/job:worker/replica:0/task:0/device:XLA_GPU:0"; + absl::string_view xla_gpu_dev1 = + "/job:worker/replica:0/task:0/device:XLA_GPU:1"; + std::unique_ptr graph(new Graph(OpRegistry::Global())); + Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}); + Output matmul_source = + ops::MatMul(root.WithOpName("MatMulSource_dev0"), a, a); + + Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), matmul_source, + matmul_source); + Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), matmul_source, + matmul_source); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + for (Node* n : graph->nodes()) { + if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) { + n->set_assigned_device_name(string(xla_gpu_dev0)); + } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) { + n->set_assigned_device_name(string(xla_gpu_dev1)); 
+ } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["A_dev0"], clusters["MatMulSource_dev0"]); + EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]); + EXPECT_NE(clusters["MatMulSource_dev0"], clusters["MatMul1_dev1"]); + + // Improved Heuristics should prevent this probably. + EXPECT_EQ(clusters["MatMulSource_dev0"], clusters["MatMul0_dev0"]); +} + +TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) { + absl::string_view xla_cpu_device = + "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + + Scope root = Scope::NewRootScope().ExitOnError(); + Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200}); + Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT); + Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT); + Output c = ops::Add(root.WithOpName("test/c"), a, b); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(string(xla_cpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/a"], ""); + EXPECT_NE(clusters["test/b"], ""); + EXPECT_NE(clusters["test/c"], ""); +} + +TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200}); + Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT); + Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT); + Output c = ops::Add(root.WithOpName("test/c"), a, b); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + 
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["test/a"], ""); + EXPECT_EQ(clusters["test/b"], ""); +} + +TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) { + absl::string_view xla_cpu_device = + "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + Output check = + ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check"); + Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b); + Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b}); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(string(xla_cpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/check"], ""); + EXPECT_NE(clusters["test/greaterequal"], ""); + EXPECT_NE(clusters["test/assert"], ""); +} + +TEST(XlaCompilationTest, DontAutoClusterDummyOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT); + Output check = + ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check"); + Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b); + Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b}); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + 
std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["test/assert"], ""); + EXPECT_EQ(clusters["test/check"], ""); +} + +TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64); + + Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32); + Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32); + + Output tensor_list_reserve = ops::TensorListReserve( + root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["test/tensor_list_reserve"], ""); +} + +TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output dummy_input = + ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64); + Output variant_input = + ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT); + + // Create one more node so that we don't avoid creating a cluster solely + // because it would be trivial. 
+ Output dummy_cast = + ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32); + + Output tensor_list_element_shape = ops::TensorListElementShape( + root.WithOpName("test/tensor_list_element_shape"), variant_input, + DT_INT32); + + root.graph()->AddControlEdge(dummy_cast.node(), + tensor_list_element_shape.node()); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["test/tensor_list_element_shape"], ""); +} + +TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64); + Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64); + + Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32); + Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32); + + Output tensor_list_reserve = ops::TensorListReserve( + root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(xla_cpu_device); + } + } + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/tensor_list_reserve"], ""); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc index d56d0f8ccfcdab40003be38059228cb255921b64..64a3301745790132fe3149bf8fb52d6c45ecc3c1 100644 --- 
a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -34,15 +34,9 @@ namespace tensorflow { // // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to // make this more direct, but probably not worth it solely for this test. - std::vector devices; + std::vector> devices; TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices)); - auto delete_devices = gtl::MakeCleanup([&] { - for (Device* d : devices) { - delete d; - } - }); - GraphOptimizationPassOptions opt_options; opt_options.graph = graph; opt_options.session_options = session_options; diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc index d8ace628e6b76e011ecddd4d526efc4db9c9237e..c788091724e443ba1e3bcd60515d68e71e2e0824 100644 --- a/tensorflow/compiler/jit/node_matchers.cc +++ b/tensorflow/compiler/jit/node_matchers.cc @@ -19,7 +19,10 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.pb.h" namespace tensorflow { @@ -28,6 +31,7 @@ namespace matchers { namespace { using impl::NodeMatcherProperties; +using impl::OutEdge; string IndentAllButFirstLine(absl::string_view text) { std::vector lines = absl::StrSplit(text, '\n'); @@ -99,8 +103,6 @@ bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor, } } -using Input = std::pair; - struct NodeMatcher : public ::testing::MatcherInterface { bool MatchAndExplain( const Node* node, @@ -191,6 +193,30 @@ struct NodeMatcher : public ::testing::MatcherInterface { } return false; } + + const AttrValueMap attr_value_map = node->def().attr(); + for (const auto& attr_kv_pair : attrs) { + auto it = attr_value_map.find(attr_kv_pair.first); + if (it == attr_value_map.end()) { + if (listener->IsInterested()) { + *listener << "did not find attribute named \"" << attr_kv_pair.first + << "\" in node"; + } + return false; + } + if (attr_kv_pair.second && + !AreAttrValuesEqual(it->second, *attr_kv_pair.second)) { + if (listener->IsInterested()) { + *listener << "attribute named " << attr_kv_pair.first + << " does not match value; expected: \"" + << SummarizeAttrValue(*attr_kv_pair.second) + << "\", found: \"" << SummarizeAttrValue(it->second) + << "\""; + } + return false; + } + } + return true; } @@ -232,7 +258,7 @@ struct NodeMatcher : public ::testing::MatcherInterface { *os << "matching " << ss.str(); } else { int edge_idx = 0; - for (const ::testing::Matcher& matcher : (*input_matchers)) { + for (const ::testing::Matcher& matcher : (*input_matchers)) { *os << "\n [" << edge_idx << "] matching ("; ::std::stringstream ss; matcher.DescribeTo(&ss); @@ -250,6 +276,21 @@ 
struct NodeMatcher : public ::testing::MatcherInterface { control_dep_set->DescribeTo(os); } + if (!attrs.empty()) { + printed_something = true; + std::vector attrs_str; + absl::c_transform( + attrs, std::back_inserter(attrs_str), + [](const std::pair>& attr_kv_pair) { + return absl::StrCat(attr_kv_pair.first, "->", + attr_kv_pair.second + ? SummarizeAttrValue(*attr_kv_pair.second) + : "*"); + }); + *os << " and attr values matching [" << absl::StrJoin(attrs_str, ", ") + << "]"; + } + if (!printed_something) { *os << "is any node"; } @@ -266,7 +307,7 @@ struct NodeMatcher : public ::testing::MatcherInterface { } ::testing::StringMatchResultListener inner_listener; - Input input = {edge->src(), edge->src_output()}; + OutEdge input = {edge->src(), edge->src_output()}; if ((*input_matchers)[input_idx].MatchAndExplain(input, &inner_listener)) { return true; } @@ -286,22 +327,24 @@ struct NodeMatcher : public ::testing::MatcherInterface { absl::optional name; absl::optional assigned_device; absl::optional constant_value; - absl::optional>> input_matchers; + absl::optional>> input_matchers; absl::optional<::testing::Matcher>> control_dep_set; + std::map> attrs; }; // Matches a dst and dst_output on an input edge. Today we only use this with // dst_output=0 but we will eventually need to support multi-output operations. 
-class InputMatcher : public ::testing::MatcherInterface { +class OutEdgeMatcher : public ::testing::MatcherInterface { public: - InputMatcher(::testing::Matcher src_matcher, int src_output) - : src_matcher_(std::move(src_matcher)), src_output_(src_output) {} + OutEdgeMatcher(::testing::Matcher src_matcher, int src_oidx) + : src_matcher_(std::move(src_matcher)), src_oidx_(src_oidx) {} bool MatchAndExplain( - Input input, ::testing::MatchResultListener* listener) const override { + OutEdge out_edge, + ::testing::MatchResultListener* listener) const override { ::testing::StringMatchResultListener inner_listener; - if (!src_matcher_.MatchAndExplain(input.first, &inner_listener)) { + if (!src_matcher_.MatchAndExplain(out_edge.first, &inner_listener)) { if (listener->IsInterested()) { *listener << "\nsource does not match expected "; src_matcher_.DescribeTo(listener->stream()); @@ -312,10 +355,10 @@ class InputMatcher : public ::testing::MatcherInterface { } return false; } - if (input.second != src_output_) { + if (out_edge.second != src_oidx_) { if (listener->IsInterested()) { - *listener << "\nexpected output slot to be " << src_output_ - << " but found " << input.second; + *listener << "\nexpected output slot to be " << src_oidx_ + << " but found " << out_edge.second; } return false; } @@ -324,31 +367,21 @@ class InputMatcher : public ::testing::MatcherInterface { } void DescribeTo(::std::ostream* os) const override { - if (src_output_) { - *os << "output slot: " << src_output_ << ", source: ("; + if (src_oidx_) { + *os << "output slot: " << src_oidx_ << ", source: ("; } src_matcher_.DescribeTo(os); - if (src_output_) { + if (src_oidx_) { *os << ")"; } } private: ::testing::Matcher src_matcher_; - int src_output_; + int src_oidx_; }; - -std::vector<::testing::Matcher> NodeMatchersToInputMatchers( - absl::Span> node_matchers) { - std::vector<::testing::Matcher> result; - absl::c_transform(node_matchers, std::back_inserter(result), - [](::testing::Matcher n) { - 
return ::testing::MakeMatcher(new InputMatcher(n, 0)); - }); - return result; -} } // namespace ::testing::Matcher impl::NodeWith( @@ -375,10 +408,9 @@ std::vector<::testing::Matcher> NodeMatchersToInputMatchers( matcher->assigned_device = prop.assigned_device(); } - if (prop.input_nodes()) { + if (prop.inputs()) { DCHECK(!matcher->input_matchers); - matcher->input_matchers = - NodeMatchersToInputMatchers(*prop.input_nodes()); + matcher->input_matchers = *prop.inputs(); } if (prop.control_deps()) { @@ -386,6 +418,11 @@ std::vector<::testing::Matcher> NodeMatchersToInputMatchers( matcher->control_dep_set = ::testing::UnorderedElementsAreArray(*prop.control_deps()); } + + if (prop.attr()) { + auto insert_result = matcher->attrs.insert(*prop.attr()); + DCHECK(insert_result.second); + } } return ::testing::MakeMatcher(matcher); @@ -412,12 +449,12 @@ impl::NodeMatcherProperties AssignedDevice(string assigned_device) { } impl::NodeMatcherProperties impl::Inputs( - absl::Span> inputs) { - std::vector<::testing::Matcher> inputs_vector; + absl::Span> inputs) { + std::vector<::testing::Matcher> inputs_vector; absl::c_copy(inputs, std::back_inserter(inputs_vector)); impl::NodeMatcherProperties props; - props.set_input_nodes(std::move(inputs_vector)); + props.set_inputs(std::move(inputs_vector)); return props; } @@ -431,6 +468,45 @@ impl::NodeMatcherProperties impl::CtrlDeps( return props; } +std::pair impl::AttrLiteralHelper( + const std::pair& bool_attr) { + AttrValue attr_value; + attr_value.set_b(bool_attr.second); + return {bool_attr.first, attr_value}; +} + +std::pair impl::AttrLiteralHelper( + const std::pair>& int_list_attr) { + AttrValue attr_value; + AttrValue::ListValue* list = attr_value.mutable_list(); + for (int i : int_list_attr.second) { + list->add_i(i); + } + return {int_list_attr.first, attr_value}; +} + +std::pair impl::AttrLiteralHelper( + const std::pair>& string_list_attr) { + AttrValue attr_value; + AttrValue::ListValue* list = 
attr_value.mutable_list(); + for (string s : string_list_attr.second) { + list->add_s(s); + } + return {string_list_attr.first, attr_value}; +} + +impl::NodeMatcherProperties impl::Attr(std::pair attr) { + impl::NodeMatcherProperties props; + props.set_attr(std::move(attr)); + return props; +} + +impl::NodeMatcherProperties impl::Attr(string name) { + impl::NodeMatcherProperties props; + props.set_attr({std::move(name), absl::nullopt}); + return props; +} + NodeMatcherProperties ConstantValue( const ::tensorflow::Input::Initializer& val) { TF_CHECK_OK(val.status); @@ -439,9 +515,13 @@ NodeMatcherProperties ConstantValue( return props; } -::testing::Matcher Const( +::testing::Matcher Const( const ::tensorflow::Input::Initializer& val) { - return NodeWith(ConstantValue(val)); + return Out(NodeWith(ConstantValue(val))); +} +::testing::Matcher Out( + int oidx, ::testing::Matcher node_matcher) { + return ::testing::MakeMatcher(new OutEdgeMatcher(node_matcher, oidx)); } } // namespace matchers @@ -455,4 +535,7 @@ Node* FindNodeByName(Graph* g, absl::string_view name) { return nullptr; } } // namespace testing + +void PrintTo(const Node* n, ::std::ostream* os) { *os << SummarizeNode(*n); } +void PrintTo(Node* n, ::std::ostream* os) { *os << SummarizeNode(*n); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/node_matchers.h b/tensorflow/compiler/jit/node_matchers.h index 0437a7e95c1eb3bdcdbe24a440dd90a5943c0894..0d4f02c236bba353799f75ee91cf03235b424b29 100644 --- a/tensorflow/compiler/jit/node_matchers.h +++ b/tensorflow/compiler/jit/node_matchers.h @@ -19,7 +19,7 @@ limitations under the License. // // tensorflow::Node* node = ...; // EXPECT_THAT(node, NodeWith(Name("name"), Op("op"), -// Inputs(NodeWith(Name("input"))))) +// Inputs(Out(3, NodeWith(Name("input")))))) // // Matchable node properties (the expressions that go inside NodeWith(...)) // are: @@ -32,7 +32,8 @@ limitations under the License. 
// - AssignedDevice(string): matches the assigned device exactly. // // - Inputs(): matches the list of non-control inputs to the node -// exactly (i.e. does not match a suffix or a prefix). +// exactly (i.e. does not match a suffix or a prefix) where each element +// matches an output of a node (see Out(idx, node) below). // // - CtrlDeps(): matches the list of control dependences on the // node exactly but in any order. @@ -40,10 +41,16 @@ limitations under the License. // - ConstantValue(tensorflow::Input::Initializer init): matches a Const node // with the constant value `init`. Implies Op("Const"). // -// Node properties may not be repeated in a single NodeWith(...) matcher. -// E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail. Since ConstantValue -// implies Op("Const"), a single NodeWith matcher can't have both -// ConstantValue(...) and Op(...). +// - Attr(name, value): Matches a single attribute with name `name` and value +// `value`. Right now only boolean values are supported. +// +// Overlapping node properties may not be repeated in a single NodeWith(...) +// matcher. E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail. Since +// ConstantValue implies Op("Const"), a single NodeWith matcher can't have both +// ConstantValue(...) and Op(...). Multiple Attr() values can be combined as +// long as the attribute names are different. +// +// Out(idx, node) matches the `idx`'th output of a node that matches `node`. #ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ #define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ @@ -66,6 +73,8 @@ namespace matchers { namespace impl { +using OutEdge = std::pair; + // ----------------------------------------------------------------------------- // Implementation details. 
@@ -74,6 +83,8 @@ namespace impl { class NodeMatcherProperties { public: using NodeSeqMatcher = std::vector<::testing::Matcher>; + using InputSeqMatcher = std::vector<::testing::Matcher>; + using AttrKeyValuePair = std::pair>; const absl::optional& name() const { return name_; } const absl::optional& op() const { return op_; } @@ -83,12 +94,13 @@ class NodeMatcherProperties { const absl::optional& constant_value() const { return constant_value_; } - const absl::optional& input_nodes() const { - return input_nodes_; + const absl::optional& inputs() const { + return input_matchers_; } const absl::optional& control_deps() const { return control_deps_; } + const absl::optional& attr() const { return attr_; } void set_name(string name) { DCHECK(IsEmpty()); @@ -111,9 +123,9 @@ class NodeMatcherProperties { op_ = "Const"; } - void set_input_nodes(NodeSeqMatcher input_nodes) { + void set_inputs(InputSeqMatcher inputs) { DCHECK(IsEmpty()); - input_nodes_ = std::move(input_nodes); + input_matchers_ = std::move(inputs); } void set_control_deps(NodeSeqMatcher control_deps) { @@ -121,9 +133,14 @@ class NodeMatcherProperties { control_deps_ = std::move(control_deps); } + void set_attr(AttrKeyValuePair attr) { + DCHECK(IsEmpty()); + attr_ = std::move(attr); + } + bool IsEmpty() const { - return !name().has_value() && !op().has_value() && - !input_nodes().has_value() && !control_deps().has_value(); + return !name().has_value() && !op().has_value() && !inputs().has_value() && + !control_deps().has_value() && !attr().has_value(); } private: @@ -131,18 +148,31 @@ class NodeMatcherProperties { absl::optional op_; absl::optional assigned_device_; absl::optional constant_value_; - absl::optional input_nodes_; + absl::optional input_matchers_; absl::optional control_deps_; + absl::optional attr_; }; ::testing::Matcher NodeWith( absl::Span props); impl::NodeMatcherProperties Inputs( - absl::Span> inputs); + absl::Span> inputs); impl::NodeMatcherProperties CtrlDeps( absl::Span> 
control_deps); + +impl::NodeMatcherProperties Attr(std::pair attrs); +impl::NodeMatcherProperties Attr(string name); + +std::pair AttrLiteralHelper( + const std::pair& bool_attr); + +std::pair AttrLiteralHelper( + const std::pair>& int_list_attr); + +std::pair AttrLiteralHelper( + const std::pair>& string_list_attr); } // namespace impl // ----------------------------------------------------------------------------- @@ -157,6 +187,17 @@ impl::NodeMatcherProperties Op(string op); // Matches a node with assigned device `assigned_device`. impl::NodeMatcherProperties AssignedDevice(string assigned_device); +// Matches a node with a boolean typed attrbute named `name` and with value +// `value`. +template +impl::NodeMatcherProperties Attr(const string& name, ValueTy value) { + return impl::Attr({impl::AttrLiteralHelper({name, value})}); +} + +inline impl::NodeMatcherProperties Attr(const string& name) { + return impl::Attr(name); +} + // Matches a node with inputs `inputs`. // // `inputs` are ordered; `inputs`[i] must match input i. @@ -165,6 +206,16 @@ impl::NodeMatcherProperties Inputs(Ts... inputs) { return impl::Inputs({inputs...}); } +// Matches the `idx`'th output of a node that matches `node`. +::testing::Matcher Out(int oidx, + ::testing::Matcher node); + +// Matches the first output of a node that matches `node`. +inline ::testing::Matcher Out( + ::testing::Matcher node) { + return Out(0, node); +} + // Matches a node with control dependences `control_deps`. // // `control_deps` are unordered and will match the control deps of a node in any @@ -185,13 +236,16 @@ template return impl::NodeWith(array); } -::testing::Matcher Const( +::testing::Matcher Const( const ::tensorflow::Input::Initializer& val); } // namespace matchers // If `g` has a node named `name` returns it, otherwise returns null. 
Node* FindNodeByName(Graph* g, absl::string_view name); } // namespace testing + +void PrintTo(const Node* n, ::std::ostream* os); +void PrintTo(Node* n, ::std::ostream* os); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ diff --git a/tensorflow/compiler/jit/node_matchers_test.cc b/tensorflow/compiler/jit/node_matchers_test.cc index 93a8994307b38ac240c22d0a18268638ac7620ae..c3f0dfece85573d71dbfa21eba5af70b674fe71e 100644 --- a/tensorflow/compiler/jit/node_matchers_test.cc +++ b/tensorflow/compiler/jit/node_matchers_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/control_flow_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/math_ops.h" namespace tensorflow { @@ -27,12 +29,14 @@ namespace { using ::testing::_; using testing::matchers::AssignedDevice; +using testing::matchers::Attr; using testing::matchers::ConstantValue; using testing::matchers::CtrlDeps; using testing::matchers::Inputs; using testing::matchers::Name; using testing::matchers::NodeWith; using testing::matchers::Op; +using testing::matchers::Out; template string Explain(const T& t, const M& m) { @@ -61,7 +65,7 @@ TEST(NodeMatchers, CheckAgainstConstant) { "\nexpected op Add but found Placeholder"); EXPECT_EQ(Explain(placeholder.node(), NodeWith(Name("add"))), "\nexpected name add but found placeholder"); - EXPECT_EQ(Explain(placeholder.node(), NodeWith(Inputs(NodeWith()))), + EXPECT_EQ(Explain(placeholder.node(), NodeWith(Inputs(Out(NodeWith())))), "\nexpected 1 inputs but node has 0"); } @@ -74,18 +78,19 @@ TEST(NodeMatchers, CheckAgainstBinary) { ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT); Output add = ops::Add(root.WithOpName("add"), placeholder_a, placeholder_b); - EXPECT_THAT(add.node(), NodeWith(Op("Add"), Name("add"), - 
Inputs(NodeWith(Name("placeholder_a")), - NodeWith(Name("placeholder_b"))))); + EXPECT_THAT(add.node(), + NodeWith(Op("Add"), Name("add"), + Inputs(Out(NodeWith(Name("placeholder_a"))), + Out(NodeWith(Name("placeholder_b")))))); EXPECT_EQ(Explain(add.node(), NodeWith(Inputs())), "\nexpected 0 inputs but node has 2"); EXPECT_EQ( - Explain(add.node(), NodeWith(Inputs(NodeWith(Name("blah")), _))), + Explain(add.node(), NodeWith(Inputs(Out(NodeWith(Name("blah"))), _))), "\ninput 0 does not match expected:\nname: blah, \nsource does not match " "expected name: blah\n\t\nexpected name blah but found placeholder_a"); EXPECT_EQ( - Explain(add.node(), NodeWith(Inputs(_, NodeWith(Name("blah"))))), + Explain(add.node(), NodeWith(Inputs(_, Out(NodeWith(Name("blah")))))), "\ninput 1 does not match expected:\nname: blah, \nsource does not match " "expected name: blah\n\t\nexpected name blah but found placeholder_b"); } @@ -174,6 +179,36 @@ TEST(NodeMatchers, AssignedDevice) { "/job:localhost/replica:0/task:0/device:CPU:0 but found \"\""); } +TEST(NodeMatchers, OutputIndices) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output pred = ops::Placeholder(root.WithOpName("pred"), DT_BOOL); + + Output data = ops::Placeholder(root.WithOpName("data"), DT_FLOAT); + ops::Switch sw(root.WithOpName("switch"), data, pred); + Output add = ops::Add(root.WithOpName("add"), sw.output_true, + ops::Placeholder(root.WithOpName("addend"), DT_FLOAT)); + + EXPECT_THAT(add.node(), NodeWith(Inputs(Out(1, NodeWith(Op("Switch"))), _))); + EXPECT_EQ( + Explain(add.node(), NodeWith(Inputs(Out(0, NodeWith(Op("Switch"))), _))), + "\ninput 0 does not match expected:\nop: Switch, \nexpected output slot " + "to be 0 but found 1"); +} + +TEST(NodeMatchers, Attrs) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output enter = ops::internal::Enter( + root.WithOpName("enter"), + ops::Placeholder(root.WithOpName("data"), DT_FLOAT), "frame_name", + ops::internal::Enter::Attrs{}.IsConstant(true)); + 
EXPECT_THAT(enter.node(), NodeWith(Attr("is_constant", true))); + EXPECT_EQ(Explain(enter.node(), NodeWith(Attr("is_constant", false))), + "attribute named is_constant does not match value; expected: " + "\"false\", found: \"true\""); + EXPECT_EQ(Explain(enter.node(), NodeWith(Attr("missing_attr", false))), + "did not find attribute named \"missing_attr\" in node"); +} + } // namespace } // namespace testing } // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index f72224545b25bc7100e0b6788e6fbf0a7ca63dad..64409d9334751e0edfce9091a4e5697dd2c712c5 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -18,3 +18,9 @@ tf_gen_op_wrapper_py( out = "xla_ops.py", deps = ["//tensorflow/compiler/jit/ops:xla_ops"], ) + +py_library( + name = "xla_ops_grad", + srcs = ["xla_ops_grad.py"], + deps = ["//tensorflow/python:framework_ops"], +) diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index bcd1a29b1ff789b5674a21ff66cc6d23a809afc5..95d12e95fd9a0d1cca513ee74a0651ea69eba89e 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -54,6 +54,7 @@ REGISTER_OP("XlaClusterOutput") REGISTER_OP("_XlaCompile") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") + .Attr("must_compile: bool") .Input("args: Targs") .Attr("Targs: list(type) >= 0") .Input("resources: Nresources * resource") @@ -71,8 +72,12 @@ that _XlaRun can use to look up the LocalExecutable and execute it. key: A key that can be used to look up the local executable compiled by the node and associated metadata. -compilation_successful: True iff the compilation was successful. Always true -for now. +compilation_successful: If the `must_compile` attr is false the _XlaCompile op + can decide not to compile the clusters based on some profitability + heuristics. 
In that case `compilation_successful` is false if _XlaCompile + chose not to compile the cluster. If the `must_compile` attr is true then + _XlaCompile always attempts to compile the cluster and + `compilation_successful` is always true. )"); REGISTER_OP("_XlaRun") diff --git a/tensorflow/compiler/jit/ops/xla_ops_grad.py b/tensorflow/compiler/jit/ops/xla_ops_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..2d31d8dc714307a48932d061fb1af643940a0872 --- /dev/null +++ b/tensorflow/compiler/jit/ops/xla_ops_grad.py @@ -0,0 +1,29 @@ +"""Gradients for XLA ops.""" +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops + + +@ops.RegisterGradient("XlaClusterOutput") +def _XlaClusterOutputGrad(_, grad): + del grad # unused + raise RuntimeError("Gradient computation of graph in xla.compile() is " + "prohibited because it can cause performance degradation." 
+ "Please move gradient computation inside xla.compile().") diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index b1f9e9088f391cb8813d2c82395ffcc0b2081cae..42ea3926e16ae791dbe1bede3b8742383db7667c 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -22,9 +22,14 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { namespace { + +bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } + +namespace reduce_device_to_host_copies { Status FindNodesToDecluster(const Graph& graph, absl::flat_hash_set* result, absl::Span post_order) { @@ -132,11 +137,13 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { graph->RemoveEdge(out_edge_to_clone); } + if (n->out_edges().empty()) { + graph->RemoveNode(n); + } + return Status::OK(); } -bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } - // Clones nodes to outside their cluster to avoid device-to-host copies. For // instance, converts this: // @@ -163,7 +170,7 @@ bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } // where the ===> arrow has a hostmem source and destination and would entail a // device to host copy if the source and destination were not in the same XLA // cluster. -Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { +Status PartiallyDeclusterGraph(Graph* graph) { // When deciding whether to decluster a particular node, we base our decision // on if we've decided that some of its consumers have to be declustered too. 
// Iterating the graph in post-order guarantees that consumers have been @@ -190,6 +197,10 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { } } + // Recompute post order since PartiallyDeclusterNode may have deleted nodes. + post_order.clear(); + GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/NotBackedge); nodes_to_partially_decluster.clear(); TF_RETURN_IF_ERROR( FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); @@ -197,7 +208,9 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { return Status::OK(); } +} // namespace reduce_device_to_host_copies +namespace reduce_recompilation { bool IsIntraClusterEdge(const Edge& edge) { absl::optional src_cluster_name = GetXlaClusterForNode(*edge.src()); @@ -206,18 +219,28 @@ bool IsIntraClusterEdge(const Edge& edge) { return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name; } -Status MustCompileNode(const Node* n, bool* result) { +bool IsMustCompileDevice(const DeviceType& device_type) { + const XlaOpRegistry::DeviceRegistration* registration; + if (XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + return registration->autoclustering_policy == + XlaOpRegistry::AutoclusteringPolicy::kAlways; + } + + return false; +} + +Status MustCompileNode(const Node* n, bool* must_compile) { DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceToDeviceType(n->assigned_device_name(), &device_type)); - const XlaOpRegistry::DeviceRegistration* registration; - if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { - *result = false; - } else { - *result = registration->requires_compilation; + if (IsMustCompileDevice(device_type)) { + *must_compile = true; + return Status::OK(); } + // We must compile `n` if it does not have a TensorFlow kernel. 
+ *must_compile = !FindKernelDef(device_type, n->def(), nullptr, nullptr).ok(); return Status::OK(); } @@ -250,7 +273,7 @@ Status MustCompileNode(const Node* n, bool* result) { // regress performance in any significant manner. We will have to revisit this // algorith with a more complex cost model if this assumption turns out to be // incorrect. -Status DeclusterNodesToReduceRecompilations(Graph* graph) { +Status PartiallyDeclusterGraph(Graph* graph) { std::vector compile_time_const_nodes(graph->num_node_ids()); TF_RETURN_IF_ERROR(BackwardsConstAnalysis( *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge)); @@ -303,7 +326,7 @@ Status DeclusterNodesToReduceRecompilations(Graph* graph) { return Status::OK(); } - +} // namespace reduce_recompilation } // namespace Status PartiallyDeclusterPass::Run( @@ -315,8 +338,9 @@ Status PartiallyDeclusterPass::Run( Graph* graph = options.graph->get(); - TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph)); - TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph)); + TF_RETURN_IF_ERROR( + reduce_device_to_host_copies::PartiallyDeclusterGraph(graph)); + TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(graph)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 0feb73a89e7050e8c413e5a733da1d87775b0ba3..38a54cc5efae35ad77b6dc8039c653e920cfc071 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" @@ -385,7 +386,7 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) { TF_ASSERT_OK(s.ToGraph(graph.get())); // This is needed to register the XLA_GPU device. - std::vector devices; + std::vector> devices; TF_ASSERT_OK(DeviceFactory::AddDevices( SessionOptions(), "/job:localhost/replica:0/task:0", &devices)); @@ -399,10 +400,64 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) { TF_ASSERT_OK(PartiallyDecluster(&graph)); EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); +} + +TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output dynamic_slice_operand = + ops::Placeholder(s.WithOpName("dynamic_slice_operand"), DT_INT32, + ops::Placeholder::Attrs{}); + Output dynamic_slice_begin = ops::Placeholder( + s.WithOpName("dynamic_slice_begin"), DT_INT32, ops::Placeholder::Attrs{}); + Output dynamic_slice_size = ops::Placeholder( + s.WithOpName("dynamic_slice_size"), DT_INT32, ops::Placeholder::Attrs{}); + Output dynamic_slice = + ops::XlaDynamicSlice(s.WithOpName("dynamic_slice"), dynamic_slice_operand, + dynamic_slice_begin, dynamic_slice_size); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = + ops::Reshape(s.WithOpName("reshape"), reshape_input, dynamic_slice); + + AddToCluster({dynamic_slice.node(), reshape.node()}, "cluster_0"); + + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + + Node* n = FindNodeByName(*graph, "dynamic_slice"); + ASSERT_NE(n, nullptr); + + 
TF_ASSERT_OK(PartiallyDecluster(&graph)); + + EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); +} - for (Device* d : devices) { - delete d; +TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) { + const char* const kClusteredProducer0Name = "ClusteredProducer0"; + const char* const kClusteredProducer1Name = "ClusteredProducer1"; + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer_0 = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName(kClusteredProducer0Name)); + Node* clustered_producer_1 = + ops::BinaryOp("FakeBinary", clustered_producer_0, input, + builder.opts().WithName(kClusteredProducer1Name)); + ops::BinaryOp("FakeBinary", clustered_producer_1, input, + builder.opts().WithName("UnclusteredConsumer")); + clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer0Name), nullptr); + EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer1Name), nullptr); } } // namespace diff --git a/tensorflow/compiler/jit/producer_consumer_queue.h b/tensorflow/compiler/jit/producer_consumer_queue.h deleted file mode 100644 index 7c8c04152d2f3a0fd46711df24756b7e68b967ea..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/producer_consumer_queue.h +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ -#define TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ - -#include -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mutex.h" - -namespace tensorflow { - -// A thread-safe, first-in-first-out queue. -template -class ProducerConsumerQueue { - public: - ProducerConsumerQueue() - : capacity_(std::numeric_limits::max()) {} - ~ProducerConsumerQueue() = default; - - // Wait until the queue is non-full, then append a copy of v. - void Put(const T &v); - - // Wait until the queue is non-empty, then remove and return the head value. - T Get(); - - // If the queue is non-empty, remove the head value, placing it in *pv, and - // return true; otherwise return false. - bool TryGet(T *pv); - - // Set the capacity of the queue; the queue is full whenever count() >= - // capacity(). The initial value is the maximum size_t. Requires size > 0. - void set_capacity(std::size_t size); - - // Return the capacity of the queue. - std::size_t capacity() const; - - // Return the number of elements in the queue. - std::size_t count() const; - - // Implementation details follow. Clients should ignore. 
- private: - mutable tensorflow::mutex mu_; // protects all fields below - tensorflow::condition_variable non_empty_ GUARDED_BY(mu_); - tensorflow::condition_variable non_full_ GUARDED_BY(mu_); - std::size_t capacity_ GUARDED_BY(mu_); - std::deque queue_ GUARDED_BY(mu_); - - TF_DISALLOW_COPY_AND_ASSIGN(ProducerConsumerQueue); -}; - -// ------------------------------------------------------ -// Implementation details follow. Clients should ignore. - -// Wait until the queue is non-full, then append a copy of v. -template -void ProducerConsumerQueue::Put(const T &v) { - mutex_lock lock(mu_); - while (queue_.size() >= capacity_) { - non_full_.wait(lock); - } - queue_.push_back(v); - non_empty_.notify_one(); -} - -// Wait until the queue is non-empty, then remove and return the head value. -template -T ProducerConsumerQueue::Get() { - mutex_lock lock(mu_); - while (queue_.empty()) { - non_empty_.wait(lock); - } - non_full_.notify_one(); - T result_value = queue_.front(); - queue_.pop_front(); - return result_value; -} - -// If the queue is non-empty, remove the head value, placing it in *pv, and -// return true; otherwise return false. -template -bool ProducerConsumerQueue::TryGet(T *pv) { - mutex_lock lock(mu_); - bool got_element = !queue_.empty(); - if (got_element) { - non_full_.notify_one(); - *pv = queue_.front(); - queue_.pop_front(); - } - return got_element; -} - -// Set the capacity of the queue; the queue is full whenever count() >= -// capacity(). The initial value is the maximum size_t. Requires size > 0. -template -void ProducerConsumerQueue::set_capacity(std::size_t size) { - mutex_lock lock(mu_); - CHECK_NE(size, 0); - capacity_ = size; - non_full_.notify_all(); -} - -// Return the capacity of the queue. -template -std::size_t ProducerConsumerQueue::capacity() const { - mutex_lock lock(mu_); - std::size_t max_elements = capacity_; - return max_elements; -} - -// Return the number of elements in the queue. 
-template -std::size_t ProducerConsumerQueue::count() const { - mutex_lock lock(mu_); - std::size_t num_elements = queue_.size(); - return num_elements; -} -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ diff --git a/tensorflow/compiler/jit/producer_consumer_queue_test.cc b/tensorflow/compiler/jit/producer_consumer_queue_test.cc deleted file mode 100644 index f61260c6e52756ee039829afdc7452f5f760c221..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/producer_consumer_queue_test.cc +++ /dev/null @@ -1,139 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/producer_consumer_queue.h" - -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { - -typedef ProducerConsumerQueue IntQueue; - -// Insert integers between low inclusive and high exclusive into q. -void PushRange(IntQueue *q, int low, int high) { - while (low != high) { - q->Put(low); - VLOG(2) << "Pushing " << low; - ++low; - } -} - -// Push the numbers between 0 and 999 inclusive from several threads in the -// pool. 
-void PushRanges(IntQueue *queue, thread::ThreadPool *pool) { - VLOG(1) << "Adding 20-36"; - pool->Schedule([queue] { PushRange(queue, 20, 36); }); - VLOG(1) << "Adding 7-20"; - pool->Schedule([queue] { PushRange(queue, 7, 20); }); - VLOG(1) << "Adding 36-501"; - pool->Schedule([queue] { PushRange(queue, 36, 501); }); - VLOG(1) << "Adding 501-1000"; - pool->Schedule([queue] { PushRange(queue, 501, 1000); }); - VLOG(1) << "Adding 0-5"; - pool->Schedule([queue] { PushRange(queue, 0, 5); }); - VLOG(1) << "Adding 5-7"; - pool->Schedule([queue] { PushRange(queue, 5, 7); }); -} - -// Pop elements from queue using Get(). Make sure that exactly elements -// were present and their values are all integers between 0 and high-1 -// inclusive. -void GetRange(IntQueue *queue, int high) { - VLOG(1) << "Testing Wait"; - std::vector results; - for (int i = 0; i != high; ++i) { - int r = queue->Get(); - VLOG(2) << "Waited and got " << r; - results.push_back(r); - } - CHECK_EQ(queue->count(), 0); - std::sort(results.begin(), results.end()); - for (int i = 0; i != high; ++i) { - CHECK(results[i] == i); - } -} - -// Pop elements from queue using TryGet(). Make sure that exactly -// elements were present and their values are all integers between 0 and high-1 -// inclusive. -void TryGetRange(IntQueue *queue, int high) { - std::vector results; - // Give up if we don't get all the elements back from the queue - // in 10 seconds. 
- int timeout = 10; - int r; - for (int i = 0; i != high; ++i) { - while (!queue->TryGet(&r)) { - if (!timeout--) { - LOG(FATAL) << "Can't find all elements in the queue"; - } - VLOG(1) << "Sleeping for a second..."; - sleep(1); - } - VLOG(2) << "Popped " << r; - results.push_back(r); - } - CHECK_EQ(queue->count(), 0); - CHECK(!queue->TryGet(&r)); - std::sort(results.begin(), results.end()); - for (int i = 0; i != high; ++i) { - CHECK_EQ(i, results[i]); - } -} - -const int kNumThreads = 15; - -TEST(ProducerConsumerQueue, GetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - PushRanges(&queue, &pool); - } - GetRange(&queue, 1000); -} - -TEST(ProducerConsumerQueue, TryGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - PushRanges(&queue, &pool); - } - TryGetRange(&queue, 1000); -} - -TEST(ProducerConsumerQueue, ParallelGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - pool.Schedule([&queue] { GetRange(&queue, 1000); }); - PushRanges(&queue, &pool); - } -} - -TEST(ProducerConsumerQueue, ParallelTryGetRange) { - IntQueue queue; - { - thread::ThreadPool pool(Env::Default(), "test", kNumThreads); - pool.Schedule([&queue] { TryGetRange(&queue, 1000); }); - PushRanges(&queue, &pool); - } -} - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index e039d46ec863920eb7deb5bc20525fdab866415c..c0897217bcbd895003ce3018835da93a779a51a2 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -39,8 +39,7 @@ limitations under the License. // resource variables). 
// // The result is incorrect around loops because we ignore edges from -// NextIteration to Merge, but that should be fine because we don't cluster -// these edges. For instance, in: +// NextIteration to Merge. For instance, in: // // Init -----> Merge <-------+ // | | @@ -55,21 +54,20 @@ limitations under the License. // // we won't put (Read, Write) in the returned set. This is fine if // auto-clustering can only cluster the Read->Write edge, but it is a problem if -// it clusters the Write->NextIteration->Merge->Read edges instead. The same -// problem is present for the functional version of the loop above. We rely on -// auto-clustering to not cluster control flow edges like NextIteration->Merge. -// This is enough to avoid the explicit-control-flow problem shown above. One -// way to think about this is that we only care about cases where two nodes, A -// and B, would normally have been put in the same cluster but cannot legally be -// in the same cluster because of resourcevar-dependencies. If A and B would +// it clusters the Write->NextIteration->Merge->Read edges instead. So we rely +// on auto-clustering to not cluster NextIteration->Merge edges. The same +// problem is present for the functional version of the loop above and we also +// rely on auto-clustering not clustering functional while loops containing +// resource operations. +// +// One way to think about this is that we only care about cases where two nodes, +// A and B, would normally have been put in the same cluster but cannot legally +// be in the same cluster because of resourcevar-dependencies. If A and B would // normally have been put in the same cluster then all paths between A and B // would have to be clusterable (otherwise we'd have introduced a cycle). Ergo // there could not have been a NextIteration->Merge edge between A and B since // we don't cluster these edges. 
// -// We also rely on auto-clustering to not cluster functional control flow nodes -// that contain resource operations. -// // IMPLEMENTATION // -------------- // @@ -152,13 +150,12 @@ Status XlaResourceOpKindForNode( // can be represented by an XLA cluster and needs no special handling around // auto-jit. bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) { - // XLA clusters forces all reads to happen before all writes, which means the - // kinds of edges it can faithfully represent are: Read->Write, Read->Modify, - // Modify->Write, Read->Read, Write->Write. - // - // TODO(b/112856632): We can, in theory, support Read->Read and Write->Write - // dependencies. - return from == XlaResourceOpKind::kRead && to == XlaResourceOpKind::kWrite; + // XLA clusters force all reads to happen before all writes. Moreover the set + // of reads are executed as one atomic operation, and the set of writes are as + // another atomic operation. This means we can faithfully represent the + // following edges: Read->*, *->Write. 
+ + return from == XlaResourceOpKind::kRead || to == XlaResourceOpKind::kWrite; } using ResourceOp = std::pair; diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc index e54b547abcfea698fe79e81dce547ea7858ff829..67304412fd384edde931fa2c5efb05f49e10411f 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc @@ -130,9 +130,7 @@ TEST(ResourceOperationSafetyAnalysisTest, ReadModify) { std::vector> incompatible_pairs; TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); - EXPECT_EQ(incompatible_pairs.size(), 1); - std::pair read_modify_pair = {read->id(), modify->id()}; - EXPECT_EQ(incompatible_pairs[0], read_modify_pair); + EXPECT_EQ(incompatible_pairs.size(), 0); } TEST(ResourceOperationSafetyAnalysisTest, ModifyRead) { @@ -162,9 +160,7 @@ TEST(ResourceOperationSafetyAnalysisTest, ModifyWrite) { std::vector> incompatible_pairs; TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); - EXPECT_EQ(incompatible_pairs.size(), 1); - std::pair modify_write_pair = {modify->id(), write->id()}; - EXPECT_EQ(incompatible_pairs[0], modify_write_pair); + EXPECT_EQ(incompatible_pairs.size(), 0); } TEST(ResourceOperationSafetyAnalysisTest, WriteModify) { @@ -196,11 +192,7 @@ TEST(ResourceOperationSafetyAnalysisTest, ReadModifyWrite) { std::vector> incompatible_pairs; TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); - EXPECT_EQ(incompatible_pairs.size(), 2); - std::pair modify_write_pair = {modify->id(), write->id()}; - std::pair read_modify_pair = {read->id(), modify->id()}; - EXPECT_EQ(incompatible_pairs[0], read_modify_pair); - EXPECT_EQ(incompatible_pairs[1], modify_write_pair); + EXPECT_EQ(incompatible_pairs.size(), 0); } TEST(ResourceOperationSafetyAnalysisTest, WriteModifyRead) { @@ -239,14 +231,12 @@ 
TEST(ResourceOperationSafetyAnalysisTest, WriteReadModify) { std::vector> incompatible_pairs; TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); - ASSERT_EQ(incompatible_pairs.size(), 3); + ASSERT_EQ(incompatible_pairs.size(), 2); std::pair write_modify_pair = {write->id(), modify->id()}; std::pair write_read_pair = {write->id(), read->id()}; - std::pair read_modify_pair = {read->id(), modify->id()}; - EXPECT_EQ(incompatible_pairs[0], read_modify_pair); - EXPECT_EQ(incompatible_pairs[1], write_read_pair); - EXPECT_EQ(incompatible_pairs[2], write_modify_pair); + EXPECT_EQ(incompatible_pairs[0], write_read_pair); + EXPECT_EQ(incompatible_pairs[1], write_modify_pair); } FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { @@ -307,9 +297,7 @@ TEST(ResourceOperationSafetyAnalysisTest, ReadCall) { std::vector> incompatible_pairs; TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); - ASSERT_EQ(incompatible_pairs.size(), 1); - std::pair read_call_edge = {read->id(), call->id()}; - EXPECT_EQ(incompatible_pairs[0], read_call_edge); + EXPECT_EQ(incompatible_pairs.size(), 0); } TEST(ResourceOperationSafetyAnalysisTest, CallWrite) { @@ -329,9 +317,7 @@ TEST(ResourceOperationSafetyAnalysisTest, CallWrite) { std::vector> incompatible_pairs; TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); - ASSERT_EQ(incompatible_pairs.size(), 1); - std::pair call_write_edge = {call->id(), write->id()}; - EXPECT_EQ(incompatible_pairs[0], call_write_edge); + EXPECT_EQ(incompatible_pairs.size(), 0); } TEST(ResourceOperationSafetyAnalysisTest, WriteCall) { @@ -429,18 +415,14 @@ TEST(ResourceOperationSafetyAnalysisTest, ChainOfOps) { std::vector> incompatible_pairs; TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); - ASSERT_EQ(incompatible_pairs.size(), 5); + ASSERT_EQ(incompatible_pairs.size(), 3); std::pair write_0_read_0_pair = {write_0->id(), read_0->id()}; std::pair 
write_0_read_1_pair = {write_0->id(), read_1->id()}; std::pair write_1_read_1_pair = {write_1->id(), read_1->id()}; - std::pair write_0_write_1_pair = {write_0->id(), write_1->id()}; - std::pair read_0_read_1_pair = {read_0->id(), read_1->id()}; EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair); - EXPECT_EQ(incompatible_pairs[1], write_0_write_1_pair); - EXPECT_EQ(incompatible_pairs[2], write_0_read_1_pair); - EXPECT_EQ(incompatible_pairs[3], read_0_read_1_pair); - EXPECT_EQ(incompatible_pairs[4], write_1_read_1_pair); + EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[2], write_1_read_1_pair); } TEST(ResourceOperationSafetyAnalysisTest, DagOfOps) { diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc new file mode 100644 index 0000000000000000000000000000000000000000..80c691fe490c1092315708a2da754d367d585300 --- /dev/null +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -0,0 +1,174 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/shape_inference.h" + +#include "tensorflow/compiler/jit/shape_inference_helpers.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +namespace { + +// Converts a shape inference handle to a PartialTensorShape. +Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context, + const shape_inference::ShapeHandle& handle, + PartialTensorShape* shape) { + // The default is already unknown + if (!context->RankKnown(handle)) return Status::OK(); + + std::vector dims(context->Rank(handle)); + for (int32 i = 0; i < dims.size(); ++i) { + dims[i] = context->Value(context->Dim(handle, i)); + } + return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape); +} + +Status PropagateShapes(const Graph& graph, + const std::map& arg_shapes, + ShapeRefiner* shape_refiner) { + // Visits the nodes in topological order (reverse post-order), inferring + // shapes. + // TODO(phawkins): handle cyclic graphs. + std::vector order; + GetReversePostOrder(graph, &order); + + for (Node* n : order) { + // Ignore the status returned by the shape_refiner. We want the best effort + // shapes, even if no shape function is registered for a node. 
+ Status status = shape_refiner->AddNode(n); + if (!status.ok()) { + VLOG(1) << "Shape inference failed for node: " << status; + } + + if (n->type_string() == "_Arg") { + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + auto it = arg_shapes.find(index); + if (it != arg_shapes.end()) { + const InferredShape& arg_shape = it->second; + shape_inference::InferenceContext* context = + shape_refiner->GetContext(n); + + if (arg_shape.handle_type != DT_INVALID) { + shape_inference::ShapeHandle handle; + TF_RETURN_IF_ERROR(context->MakeShapeFromPartialTensorShape( + arg_shape.handle_shape, &handle)); + + // Sets the shape and type of the variable's value. + context->set_output_handle_shapes_and_types( + 0, std::vector{ + {handle, arg_shape.handle_type}}); + } + + shape_inference::ShapeHandle handle; + TF_RETURN_IF_ERROR( + context->MakeShapeFromPartialTensorShape(arg_shape.shape, &handle)); + TF_RETURN_IF_ERROR(shape_refiner->SetShape(n, 0, handle)); + } + } + } + return Status::OK(); +} + +// Store the shapes of the output tensors in a map +Status StoreOutputShapes(const Graph& graph, const ShapeRefiner& shape_refiner, + GraphShapeInfo* shape_info) { + for (const Node* node : graph.nodes()) { + shape_inference::InferenceContext* context = shape_refiner.GetContext(node); + if (!context) continue; + + auto& outputs = (*shape_info)[node->name()]; + outputs.resize(context->num_outputs()); + for (int i = 0; i < context->num_outputs(); ++i) { + auto& output = outputs[i]; + TF_RETURN_IF_ERROR( + ShapeHandleToTensorShape(context, context->output(i), &output.shape)); + + const auto* handle_shapes_and_types = + context->output_handle_shapes_and_types(i); + if (handle_shapes_and_types != nullptr) { + if (handle_shapes_and_types->size() == 1) { + TF_RETURN_IF_ERROR(ShapeHandleToTensorShape( + context, (*handle_shapes_and_types)[0].shape, + &output.handle_shape)); + output.handle_type = (*handle_shapes_and_types)[0].dtype; + } else { + // otherwise, it may be 
resource like a Queue, which can have + // multiple shapes and types represented by a single handle. + } + } + VLOG(4) << node->name() << " output " << i << " shape" + << output.shape.DebugString() << " handle_type " + << DataTypeString(output.handle_type) << " handle_shape " + << output.handle_shape.DebugString(); + } + } + return Status::OK(); +} + +} // namespace + +Status InferShapes(Graph* graph, const std::map& arg_shapes, + const tensorflow::FunctionLibraryDefinition* fnlib_def, + GraphShapeInfo* shape_info) { + ShapeRefiner shape_refiner(graph->versions(), graph->op_registry()); + shape_refiner.set_require_shape_inference_fns(false); + // TODO(dlibenzi): Verify if it is worth trying to infer shaped within + // functions. Some functions can be called at multiple locations with + // difference shapes, which will trigger a shape inference based on the + // arguments passed at the first call. + // shape_refiner.set_function_library_for_shape_inference(fnlib_def); + + // ShapeRefiner requires that all inputs of a node are present when + // ShapeRefiner::AddNode is called. To get at least some shape information in + // loops, we temporarily remove loop backedges and add them back again after + // the shape inference is complete. + BackEdgeHelper back_edge; + TF_RETURN_IF_ERROR(back_edge.Remove(graph)); + TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes, &shape_refiner)); + TF_RETURN_IF_ERROR(back_edge.Replace()); + + // Currently information does not flow "backward" from consumers to producers + // in the shape inference, but we consume the shapes in a second pass in case + // backward information flow is added in the future. 
+ return StoreOutputShapes(*graph, shape_refiner, shape_info); +} + +xla::StatusOr MergeInferredShapes(const InferredShape& a, + const InferredShape& b) { + InferredShape result; + TF_RETURN_IF_ERROR(a.shape.MergeWith(b.shape, &result.shape)); + + if (a.handle_type == DT_INVALID) { + result.handle_type = b.handle_type; + } else if (b.handle_type == DT_INVALID) { + result.handle_type = a.handle_type; + } else if (a.handle_type == b.handle_type) { + result.handle_type = a.handle_type; + } else { + return errors::InvalidArgument( + "Mismatched resource types: ", DataTypeString(a.handle_type), " vs. ", + DataTypeString(b.handle_type)); + } + TF_RETURN_IF_ERROR( + a.handle_shape.MergeWith(b.handle_shape, &result.handle_shape)); + return result; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/shape_inference.h b/tensorflow/compiler/jit/shape_inference.h new file mode 100644 index 0000000000000000000000000000000000000000..8668dbca55c2cf84729d81086bde45757e54f8ab --- /dev/null +++ b/tensorflow/compiler/jit/shape_inference.h @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_ +#define TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_ + +#include +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +struct InferredShape { + // Shape of the argument tensor. + PartialTensorShape shape; + + // If the argument is a resource variable, the type and shape of the + // variable's value. + DataType handle_type = DT_INVALID; + PartialTensorShape handle_shape; +}; +typedef std::unordered_map> GraphShapeInfo; + +// Infer shapes for all Tensors in a graph, and save them in a map. The vector +// for a Node contains the information about each of its outputs. +// TODO(phawkins): this code does not infer accurate shapes for cyclic graphs. +Status InferShapes(Graph* graph, const std::map& arg_shapes, + const tensorflow::FunctionLibraryDefinition* fnlib_def, + GraphShapeInfo* shape_info); + +// Merges two InferredShapes. Return an error if the two shapes cannot be +// merged. +xla::StatusOr MergeInferredShapes(const InferredShape& a, + const InferredShape& b); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_ diff --git a/tensorflow/compiler/jit/shape_inference_test.cc b/tensorflow/compiler/jit/shape_inference_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9268172b1c4a4a717b608a52041219d54383a3ff --- /dev/null +++ b/tensorflow/compiler/jit/shape_inference_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests for ShapeInference. + +#include "tensorflow/compiler/jit/shape_inference.h" + +#include +#include + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/test_util.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(ShapeInferenceTest, Basics) { + Scope root = Scope::NewRootScope().ExitOnError(); + auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT, + ops::Placeholder::Shape({2, 3})); + auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT, + ops::Placeholder::Shape({3})); + auto c = ops::Placeholder(root.WithOpName("C"), DT_FLOAT); + auto d = ops::Add(root.WithOpName("D"), a, b); + auto e = ops::Add(root.WithOpName("E"), d, c); + auto f = ops::Neg(root.WithOpName("F"), e); + auto g = ops::AddN(root.WithOpName("G"), std::initializer_list{e, f}); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_CHECK_OK(root.ToGraph(graph.get())); + + GraphShapeInfo shape_info; + TF_ASSERT_OK(InferShapes(graph.get(), /*arg_shapes=*/{}, + /*fnlib_def=*/nullptr, &shape_info)); + + std::map> expected = { + {"A", {PartialTensorShape({2, 3})}}, {"B", {PartialTensorShape({3})}}, + {"C", {PartialTensorShape()}}, {"D", {PartialTensorShape({2, 3})}}, + {"E", {PartialTensorShape()}}, {"F", 
{PartialTensorShape()}}, + {"G", {PartialTensorShape()}}, + }; + TF_EXPECT_OK(ShapeAnnotationsMatch(*graph, shape_info, expected)); +} + +TEST(ShapeInferenceTest, WhileLoop) { + // Graph: + // x = array_ops.placeholder(dtypes.int32) + // y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x]) + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32, + ops::Placeholder::Shape({})); + + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32, + ops::Placeholder::Shape({})); + auto enter = + ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop"); + // Add an unused Enter node. These should be ignored. + auto enter2 = + ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop"); + auto merge = ops::Merge(scope.WithOpName("while/Merge"), + std::initializer_list{enter, dummy}); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), + 10); + auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + auto switch_node = + ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); + auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"), + switch_node.output_false); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), + switch_node.output_true); + auto identity_shape = + ops::Const(scope.WithOpName("while/Identity/shape"), {}); + auto identity_reshaped = ops::Reshape( + scope.WithOpName("while/Identity/reshaped"), identity, identity_shape); + + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity_reshaped, one); + auto next_iteration = + ops::NextIteration(scope.WithOpName("while/NextIteration"), add); + + auto sink = 
ops::Identity(scope.WithOpName("sink"), exit); + + // Remove the dummy node and add the loop backedge. + scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + GraphShapeInfo shape_info; + TF_ASSERT_OK(InferShapes(&graph, /*arg_shapes=*/{}, /*fnlib_def=*/nullptr, + &shape_info)); + std::map> expected = { + {"while/Identity", {PartialTensorShape()}}, + {"while/add", {PartialTensorShape({})}}, + }; + TF_EXPECT_OK(ShapeAnnotationsMatch(graph, shape_info, expected)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/test_util.cc b/tensorflow/compiler/jit/test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..cada272090a1f613baea8f6d111866d8bb9cd55b --- /dev/null +++ b/tensorflow/compiler/jit/test_util.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/test_util.h" + +#include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace tensorflow { + +Status ShapeAnnotationsMatch( + const Graph& graph, const GraphShapeInfo& shape_info, + std::map> expected_shapes) { + for (Node* node : graph.op_nodes()) { + auto sit = shape_info.find(node->name()); + TF_RET_CHECK(sit != shape_info.end()) + << "Missing shape information for node " << node->name(); + std::vector shapes; + for (const auto& output : sit->second) shapes.push_back(output.shape); + + auto it = expected_shapes.find(node->name()); + if (it != expected_shapes.end()) { + if (!PartialTensorShapeUtils::AreIdentical(shapes, it->second)) { + return errors::InvalidArgument( + "Shape mismatch for ", node->name(), ". Expected: ", + PartialTensorShapeUtils::PartialShapeListString(it->second), + ", actual: ", + PartialTensorShapeUtils::PartialShapeListString(shapes)); + } + expected_shapes.erase(it); + } + } + if (!expected_shapes.empty()) { + std::vector missing; + missing.reserve(expected_shapes.size()); + for (const auto& entry : expected_shapes) { + missing.push_back(entry.first); + } + return errors::InvalidArgument("Missing shapes for nodes: ", + str_util::Join(missing, ",")); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/test_util.h b/tensorflow/compiler/jit/test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0c9fee8f2446d41f792a6cfbf8fc808d9d679c09 --- /dev/null +++ b/tensorflow/compiler/jit/test_util.h @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions for tests. + +#ifndef TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Tests that the shapes in 'shape_info' for the nodes in `graph` match +// `expected_shapes`. Returns an error if there are nodes in `expected_shapes` +// that do not have shape information. Ignores nodes in `graph` that do not have +// `expected_shapes` entries. 
+Status ShapeAnnotationsMatch( + const Graph& graph, const GraphShapeInfo& shape_info, + std::map> expected_shapes); + +} // namespace tensorflow + + +#endif // TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index f85121ca27ad3da918315f93b28e9000dfd65e67..fef28fc810cb4e544fe3f271f0b96cebd8a96779 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -28,6 +28,8 @@ namespace tensorflow { const char* const kXlaClusterAttr = "_XlaCluster"; const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; +const char* const kXlaCompileTimeConstantInputsAttr = + "_XlaCompileTimeConstantInputs"; namespace { // Returns a string describing how an edge from src to dst would diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index ba218f3315d2607c47342fdade0403678faa2362..fa6eaab3900b37baf7271c8c431c8384ceeda59f 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -32,6 +32,15 @@ extern const char* const kXlaClusterAttr; // compilation by the encapsulate subgraphs pass. extern const char* const kXlaOutsideCompilationAttr; +// The attribute that marks certain inputs to a Node as required to be a +// constant at compile time. If this attribute is present then the +// CompileTimeConstantInput information in the corresponding XlaOpKernel is +// ignored. +// +// The value for this attribute, if present, has to be a list of strings naming +// the inputs to the node that must be constant. +extern const char* const kXlaCompileTimeConstantInputsAttr; + using OrderedNodeSet = std::set; // Returns the DeviceType corresponding to 'device'. 
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 0471995015bb080016b523305c90a3e42163a039..3df5479a55e841380ca7b8cdd0add9fd17487091 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -40,6 +41,7 @@ namespace tensorflow { XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client, DeviceType device_type) : client_(client), device_type_(std::move(device_type)) {} + XlaCompilationCache::~XlaCompilationCache() { // Ensure any use of our programs have completed by waiting for all stream // executors to complete. @@ -64,14 +66,14 @@ string XlaCompilationCache::DebugString() { // Compute a string signature which encodes the shapes of the // arguments in the supplied list. 
-string XlaCompilationCache::SignatureDebugString(const Signature& sig) { - string result = sig.name; - for (const auto& a : sig.arg_types) { +string XlaCompilationCache::Signature::HumanString() const { + string result = name; + for (const auto& a : arg_types) { absl::StrAppend(&result, ",", DataTypeString(a.first), a.second.DebugString()); } - for (const auto& v : sig.arg_values) { + for (const auto& v : arg_values) { absl::StrAppend(&result, "; ", v.DebugString()); } return result; @@ -83,7 +85,9 @@ bool XlaCompilationCache::Signature::operator==(const Signature& other) const { if (arg_values.size() != other.arg_values.size()) return false; for (int i = 0; i < arg_values.size(); ++i) { - if (arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) { + if (arg_values[i].dtype() != other.arg_values[i].dtype() || + arg_values[i].shape() != other.arg_values[i].shape() || + arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) { return false; } } @@ -107,96 +111,30 @@ uint64 XlaCompilationCache::Signature::Hash::operator()( return h; } -Status XlaCompilationCache::BuildSignature( - const NameAttrList& function, const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, - Signature* signature) { - signature->name = Canonicalize(function.name(), AttrSlice(&function.attr())); - signature->arg_values.reserve(constant_args.size()); - - signature->arg_types.reserve(ctx->num_inputs() - constant_args.size()); - - for (int i = 0; i < ctx->num_inputs(); ++i) { - if (constant_args.count(i) > 0) { - // Use the values of compile time constants in the signature. 
- signature->arg_values.push_back(constant_args.at(i)); - } else if (variable_args.count(i) > 0) { - const OptionalTensor& variable = variable_args.at(i); - if (variable.present) { - signature->arg_types.emplace_back(variable.value.dtype(), - variable.value.shape()); - } else { - signature->arg_types.emplace_back(DT_INVALID, TensorShape()); - } - } else { - signature->arg_types.emplace_back(ctx->input_dtype(i), - ctx->input(i).shape()); - } - } - return Status::OK(); -} - -namespace { - -// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op. -Status BuildArguments(const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, - std::vector* args) { - args->resize(ctx->num_inputs()); - - for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { - XlaCompiler::Argument& arg = (*args)[input_num]; - if (constant_args.count(input_num) > 0) { - // Handles compile-time constants. - const Tensor& input = constant_args.at(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); - arg.kind = XlaCompiler::Argument::kConstant; - arg.type = input.dtype(); - arg.shape = input.shape(); - arg.constant_value = input; - } else if (variable_args.count(input_num) == 0) { - // Handles the non-constant arguments. - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); - if (input.NumElements() > 0) { - arg.kind = XlaCompiler::Argument::kParameter; - } else { - arg.kind = XlaCompiler::Argument::kConstant; - arg.constant_value = input; - } - arg.type = input.dtype(); - arg.shape = input.shape(); - } else { - // Handles resource variables. 
- const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() == DT_RESOURCE); - const OptionalTensor& variable = variable_args.at(input_num); - arg.name = variable.name; - arg.kind = XlaCompiler::Argument::kResource; - arg.resource_kind = XlaResource::kVariable; - if (variable.present) { - const Tensor& value = variable.value; - arg.type = value.dtype(); - arg.shape = value.shape(); - arg.initialized = true; - } else { - // The values of uninitialized variables are not passed as inputs, since - // they are meaningless. However, it is legal to assign to a resource - // variable for the first time inside the XLA computation, so we do - // permit uninitialized variables. - arg.initialized = false; - arg.type = DT_INVALID; - arg.shape = TensorShape(); - } +xla::StatusOr +XlaCompilationCache::BuildSignature( + const NameAttrList& function, + absl::Span args) { + Signature signature; + signature.name = Canonicalize(function.name(), AttrSlice(&function.attr())); + for (const XlaCompiler::Argument& arg : args) { + switch (arg.kind) { + case XlaCompiler::Argument::kConstant: + signature.arg_values.push_back(arg.constant_value); + break; + case XlaCompiler::Argument::kParameter: + case XlaCompiler::Argument::kResource: + signature.arg_types.emplace_back(arg.type, arg.shape); + break; + default: + return errors::InvalidArgument( + "Unhandled argument kind in XlaCompilationCache: ", + arg.HumanString()); } } - - return Status::OK(); + return std::move(signature); } -} // namespace - Status XlaCompilationCache::BuildExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& result, @@ -226,20 +164,38 @@ Status XlaCompilationCache::BuildExecutable( Status XlaCompilationCache::Compile( const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, + absl::Span args, const XlaCompiler::CompileOptions& compile_options, + CompileMode compile_mode, 
const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { - return CompileImpl(options, function, constant_args, variable_args, ctx, - compile_options, /*compile_single_op=*/false, + absl::optional compile_threshold; + if (compile_mode == CompileMode::kLazy) { + compile_threshold = kDefaultCompilationThreshold; + } + auto compile_fn = [&](XlaCompiler* compiler, + XlaCompiler::CompilationResult* result) { + return compiler->CompileFunction(compile_options, function, args, result); + }; + return CompileImpl(options, function, args, compile_fn, + /*compile_threshold=*/compile_threshold, out_compilation_result, out_executable); } +static bool IsMegamorphic(int64 compile_count, int64 execution_count) { + const int64 kCompileThreshold = 10; + const int64 kMinExecutionsPerCompile = 50; + + // This heuristic is trying to capture the following property: have we sunk a + // certain minimum amount of compile time into the cluster that didn't quite + // "pay off"? 
+ return compile_count > kCompileThreshold && + execution_count < kMinExecutionsPerCompile * compile_count; +} + Status XlaCompilationCache::CompileSingleOp( const XlaCompiler::Options& options, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, + absl::Span args, OpKernelContext* ctx, const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { @@ -247,52 +203,41 @@ Status XlaCompilationCache::CompileSingleOp( NameAttrList name; name.set_name(def.op()); *name.mutable_attr() = def.attr(); - return CompileImpl( - options, name, constant_args, variable_args, ctx, compile_options, - /*compile_single_op=*/true, out_compilation_result, out_executable); + auto compile_op = [&](XlaCompiler* compiler, + XlaCompiler::CompilationResult* result) { + std::vector result_dtypes(ctx->num_outputs()); + for (int i = 0; i < result_dtypes.size(); ++i) { + result_dtypes[i] = ctx->expected_output_dtype(i); + } + return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(), + args, result_dtypes, result); + }; + return CompileImpl(options, name, args, compile_op, + /*compile_threshold=*/absl::nullopt, + out_compilation_result, out_executable); } Status XlaCompilationCache::CompileImpl( const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompileOptions& compile_options, bool compile_single_op, + absl::Span args, + const std::function& compile_fn, + absl::optional compile_threshold, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { DCHECK_NE(out_executable, nullptr); VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { - VLOG(2) << "num_inputs=" << ctx->num_inputs() - << " num_constant_args=" << constant_args.size() - << " num_variable_args=" 
<< variable_args.size(); - for (int i = 0; i < ctx->num_inputs(); i++) { - TensorShape shape = ctx->input(i).shape(); - VLOG(2) << i << ": dtype=" << DataTypeString(ctx->input_dtype(i)) - << " present=" << ctx->has_input(i) - << " shape=" << shape.DebugString(); - } - for (auto& iterator : variable_args) { - const OptionalTensor& variable = iterator.second; - VLOG(2) << "variable present=" << variable.present - << " type=" << DataTypeString(variable.value.dtype()) - << " shape=" << variable.value.shape().DebugString() - << " TF arg= " << iterator.first; - } - VLOG(2) << "num_outputs = " << ctx->num_outputs(); - for (int i = 0; i < ctx->num_outputs(); i++) { - VLOG(2) << i << ": dtype=" << ctx->expected_output_dtype(i); + VLOG(2) << "num_inputs=" << args.size(); + for (int i = 0; i < args.size(); i++) { + VLOG(2) << i << ": " << args[i].HumanString(); } } - TF_RET_CHECK(constant_args.size() + variable_args.size() <= - ctx->num_inputs()); - - Signature signature; - TF_RETURN_IF_ERROR( - BuildSignature(function, constant_args, variable_args, ctx, &signature)); + TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args)); + VLOG(2) << "Signature: " << signature.HumanString(); - VLOG(2) << "Signature: " << SignatureDebugString(signature); // The outer lock protects the existence of the cache entry. It does not // protect the contents of the cache entry. Entry* entry; @@ -306,32 +251,87 @@ Status XlaCompilationCache::CompileImpl( entry = e.get(); } + // We always compile a cluster the very first time it is executed. This is an + // optimistic guess that pays off for statically shaped TensorFlow graphs + // (since they get the benefit of XLA right away without waiting for warmup) + // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at + // most one cluster-compilation's worth of compile time). + bool is_first_execution; + + // We avoid compiling clusters that have "gone megamorphic" i.e. 
have an + // excessive amount of shape dynamism. + bool is_megamorphic; + + { + mutex_lock lock(cluster_compile_stats_mu_); + auto it = + cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{}) + .first; + is_first_execution = it->second.execution_count++ == 0; + + // The is_megamorphic bit is "sticky". We assume clusters that have been + // observed to be megamorphic once stay megamorphic forever. + it->second.is_megamorphic |= + IsMegamorphic(/*compile_count=*/it->second.compile_count, + /*execution_count=*/it->second.execution_count); + is_megamorphic = it->second.is_megamorphic; + } + // Acquire the cache entry lock and compile, if necessary. // TODO(phawkins): this locking will need to be restructured when we implement // cache eviction. mutex_lock entry_lock(entry->mu); + int64 current_request_count = ++entry->request_count; + VLOG(2) << "Compilation cache entry hit: " << entry->compiled + << " signature: " << signature.HumanString() << " with request count " + << current_request_count << " and compile threshold " + << compile_threshold.value_or(0); if (!entry->compiled) { - VLOG(2) << "Compilation cache miss for signature: " - << SignatureDebugString(signature); + const bool should_compile = [&] { + if (!compile_threshold.has_value()) { + // Lazy compilation is disabled. 
+ return true; + } + + if (is_megamorphic) { + VLOG(3) << "Not compiling cluster " << function.name() + << " because it is megamorphic."; + return false; + } + + if (is_first_execution) { + return true; + } + + bool reached_compile_threshold = + current_request_count >= *compile_threshold; + if (!reached_compile_threshold) { + VLOG(3) + << "Not compiling cluster " << function.name() + << " because it has not reached compile threshold; threshold is " + << *compile_threshold << " execution count " + << current_request_count << "."; + } + return reached_compile_threshold; + }(); + + if (!should_compile) { + VLOG(2) << "Not compiling for signature: " << signature.HumanString(); + *out_compilation_result = nullptr; + *out_executable = nullptr; + return Status::OK(); + } + tensorflow::Env* env = tensorflow::Env::Default(); const uint64 compile_start_us = env->NowMicros(); // Do the actual JIT compilation without holding the lock (it can take // a long time.) - std::vector args; - TF_RETURN_IF_ERROR( - BuildArguments(constant_args, variable_args, ctx, &args)); XlaCompiler compiler(options); entry->compiled = true; - if (compile_single_op) { - entry->compilation_status = - compiler.CompileSingleOp(compile_options, signature.name, ctx, args, - &entry->compilation_result); - } else { - entry->compilation_status = compiler.CompileFunction( - compile_options, function, args, &entry->compilation_result); - } + entry->compilation_status = + compile_fn(&compiler, &entry->compilation_result); TF_RETURN_IF_ERROR(entry->compilation_status); CHECK_EQ(entry->executable.get(), nullptr); entry->compilation_status = @@ -340,8 +340,8 @@ Status XlaCompilationCache::CompileImpl( const uint64 compile_end_us = env->NowMicros(); const uint64 compile_time_us = compile_end_us - compile_start_us; { - mutex_lock lock(compile_stats_mu_); - auto it = compile_stats_.emplace(function.name(), CompileStats{}).first; + mutex_lock lock(cluster_compile_stats_mu_); + auto it = 
cluster_compile_stats_.find(function.name()); it->second.compile_count++; it->second.cumulative_compile_time_us += compile_time_us; VLOG(1) << "compiled " << function.name() << " " diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 75c7758f730f9f2f8251c02e7fac1a01f8cc9c2b..846d0c963dbfdf55f51120f2f138d12f5f63839b 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -17,9 +17,12 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ #include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/graph.pb.h" @@ -30,13 +33,6 @@ limitations under the License. namespace tensorflow { -// Struct that represents a possibly-absent Tensor. -struct OptionalTensor { - string name; // A descriptive name - bool present = false; // Is the tensor present? - Tensor value; // If present, what is the Tensor's value? -}; - // The XlaCompilationCache class caches the results of the XlaCompiler class, // which converts a Tensorflow graph into a compiled XLA compilation. // @@ -50,14 +46,23 @@ class XlaCompilationCache : public ResourceBase { XlaCompilationCache(xla::LocalClient* client, DeviceType device_type); ~XlaCompilationCache() override; + enum class CompileMode { + kLazy, + kStrict, + }; + // Compiles a function into a XlaCompiler::CompilationResult that can be used // to execute an XLA Computation. Compilation results are cached. // `function` is the name of a Tensorflow function to compile. 
- // `constant_args` is a map of tensorflow argument number to its constant - // value. - // `variable_args` is a snapshot of the current values of the - // resource variable arguments to `function`; uninitialized variables are - // represented by an absent OptionalTensor. + // `args` is a description of the arguments to the computation. + // + // `compile_mode` controls the behavior of the compilation cache on a cache + // miss. If `compile_mode` is `kLazy` then, based on some profitability + // heuristics, the compilation cache may decide not to compile the cluster at + // this time. In this case it returns null into both `out_compilation_result` + // and `out_executable`. If `compile_mode` is `kStrict` then the compilation + // cache always attempts the compilation on a cache miss. + // // The result of compilation is written to `*compilation_result`, which must // be non-null. If `executable` is non-null, also builds an // xla::LocalExecutable and sets `executable` to point to it. The resulting @@ -65,10 +70,9 @@ class XlaCompilationCache : public ResourceBase { // outputs. Status Compile(const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, + absl::Span args, const XlaCompiler::CompileOptions& compile_options, + CompileMode compile_mode, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable); @@ -76,8 +80,7 @@ class XlaCompilationCache : public ResourceBase { // XlaCompiler::CompileFunction. 
Status CompileSingleOp( const XlaCompiler::Options& options, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, + absl::Span args, OpKernelContext* ctx, const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable); @@ -87,26 +90,6 @@ class XlaCompilationCache : public ResourceBase { string DebugString() override; - private: - // Common implementation of Compile and CompileSingleOp. - Status CompileImpl( - const XlaCompiler::Options& options, const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompileOptions& compile_options, - bool compile_single_op, - const XlaCompiler::CompilationResult** out_compilation_result, - xla::LocalExecutable** out_executable); - - // Takes `result` which has been compiled from a Tensorflow subgraph to a - // XLA computation already, and generates an XLA LocalExecutable `executable`. - Status BuildExecutable(const XlaCompiler::Options& options, - const XlaCompiler::CompilationResult& result, - std::unique_ptr* executable); - - xla::LocalClient* const client_; - const DeviceType device_type_; - // Describes the types, shapes and any compile-time constant arguments // to a kernel. Key that uniquely identifies a compilation output. struct Signature { @@ -123,14 +106,35 @@ class XlaCompilationCache : public ResourceBase { struct Hash { uint64 operator()(const Signature& signature) const; }; + + // Returns a human-readable description of the signature. + string HumanString() const; }; - static string SignatureDebugString(const Signature& sig); // Builds the signature for a compilation. 
- Status BuildSignature(const NameAttrList& function, - const std::map& constant_args, - const std::map& variable_args, - OpKernelContext* ctx, Signature* signature); + static xla::StatusOr BuildSignature( + const NameAttrList& function, + absl::Span args); + + private: + // Common implementation of Compile and CompileSingleOp. + Status CompileImpl( + const XlaCompiler::Options& options, const NameAttrList& function, + absl::Span args, + const std::function& compile_fn, + absl::optional compile_threshold, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); + + // Takes `result` which has been compiled from a Tensorflow subgraph to a + // XLA computation already, and generates an XLA LocalExecutable `executable`. + Status BuildExecutable(const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, + std::unique_ptr* executable); + + xla::LocalClient* const client_; + const DeviceType device_type_; // The value associated with a cache entry. struct Entry { @@ -139,6 +143,9 @@ class XlaCompilationCache : public ResourceBase { // Have we tried compiling this entry? bool compiled = false; + // The number of times a compilation with this signature has been requested. + int64 request_count = 0; + // Did compilation succeed? Status compilation_status GUARDED_BY(mu); @@ -154,18 +161,31 @@ class XlaCompilationCache : public ResourceBase { absl::flat_hash_map, Signature::Hash> cache_ GUARDED_BY(compile_cache_mu_); - struct CompileStats { + struct ClusterCompileStats { // Number of times the cluster has been (re-)compiled. int64 compile_count = 0; + // The number of times this cluster has been executed. + int64 execution_count = 0; + // Cumulative time spent compiling the cluster. int64 cumulative_compile_time_us = 0; + + // True if we have decided that this cluster is too dynamic (i.e. its shapes + // change too frequently) to profitably JIT compile. 
Once a cluster is + // tagged megamorphic, it stays megamorphic forever. + bool is_megamorphic = false; }; - mutex compile_stats_mu_; + + mutex cluster_compile_stats_mu_; // Maps cluster names to compilation statistics for said cluster. - absl::flat_hash_map compile_stats_ - GUARDED_BY(compile_stats_mu_); + absl::flat_hash_map cluster_compile_stats_ + GUARDED_BY(cluster_compile_stats_mu_); + + // The number of times a lazy compilation must be requested for a specific + // signature before we attempt to compile it. + static constexpr int64 kDefaultCompilationThreshold = 2; TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); }; diff --git a/tensorflow/compiler/jit/xla_compilation_cache_test.cc b/tensorflow/compiler/jit/xla_compilation_cache_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..018c7c219f445bdca17f4f8b060e3678fe1be9ee --- /dev/null +++ b/tensorflow/compiler/jit/xla_compilation_cache_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +TEST(XlaCompilationCacheTest, SignatureEquality) { + NameAttrList fn; + fn.set_name("afunction"); + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kConstant; + args[0].type = DT_INT32; + args[0].shape = TensorShape({4, 0}); + args[0].constant_value = Tensor(DT_INT32, {4, 0}); + TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s1, + XlaCompilationCache::BuildSignature(fn, args)); + + args[0].type = DT_FLOAT; + args[0].constant_value = Tensor(DT_FLOAT, {4, 0}); + TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s2, + XlaCompilationCache::BuildSignature(fn, args)); + + args[0].shape = TensorShape({0, 4}); + args[0].constant_value = Tensor(DT_FLOAT, {0, 4}); + TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s3, + XlaCompilationCache::BuildSignature(fn, args)); + + std::vector signatures = {s1, s2, s3}; + for (int i = 0; i < signatures.size(); ++i) { + for (int j = 0; j < signatures.size(); ++j) { + EXPECT_EQ(i == j, signatures[i] == signatures[j]) + << signatures[i].HumanString() << " " << signatures[j].HumanString(); + } + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 79976c85dff200ce993ebb06e7a20a15b71f6085..c7e8d61d280a33a83c3386d8ef801018634d31ec 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -16,6 +16,8 @@ limitations under the License. // Defines the XlaCompileOnDemandOp. 
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" + +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -86,29 +88,26 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, return Status::OK(); } -bool XlaCompileOnDemandOp::MustArgumentBeConstant(const OpKernel* op_kernel, - int64 argument_idx) { +Status XlaCompileOnDemandOp::MustArgumentBeConstant(const OpKernel* op_kernel, + int64 argument_idx, + bool* result) { + *result = false; + // TODO(jmolloy): This could be expensive, so memoize. - auto* constant_inputs = tensorflow::XlaOpRegistry::CompileTimeConstantInputs( - op_kernel->def().op()); - CHECK(constant_inputs); - std::set constant_input_indices; - for (const auto& name : *constant_inputs) { - int start, stop; - TF_CHECK_OK(op_kernel->InputRange(name, &start, &stop)); - for (int i = start; i < stop; ++i) { - constant_input_indices.insert(i); - } - } - return constant_input_indices.count(argument_idx) > 0; + std::vector constant_input_indices; + TF_RETURN_IF_ERROR(XlaOpRegistry::CompileTimeConstantInputs( + *op_kernel, &constant_input_indices)); + *result = absl::c_binary_search(constant_input_indices, argument_idx); + return Status::OK(); } -bool XlaCompileOnDemandOp::ShouldArgumentBeConstant(const OpKernel* op_kernel, - int64 argument_idx) { +Status XlaCompileOnDemandOp::ShouldArgumentBeConstant(const OpKernel* op_kernel, + int64 argument_idx, + bool* result) { // Right now we only create kConstant arguments when absolutely required, but // there may be benefit in eagerly constant-folding a larger subset of // arguments in the future. 
- return MustArgumentBeConstant(op_kernel, argument_idx); + return MustArgumentBeConstant(op_kernel, argument_idx, result); } Status XlaCompileOnDemandOp::Compile( @@ -119,27 +118,48 @@ Status XlaCompileOnDemandOp::Compile( for (int64 i = 0; i < ctx->num_inputs(); ++i) { const Tensor& device_tensor = ctx->input(i); if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) { - if (xla_tensor->has_host_tensor() && - ShouldArgumentBeConstant(&ctx->op_kernel(), i)) { - constant_arguments[i] = xla_tensor->host_tensor(); + if (xla_tensor->has_host_tensor()) { + bool should_arg_be_const; + TF_RETURN_IF_ERROR(ShouldArgumentBeConstant(&ctx->op_kernel(), i, + &should_arg_be_const)); + if (should_arg_be_const) { + constant_arguments[i] = xla_tensor->host_tensor(); + } } } - if (constant_arguments.count(i) == 0 && - MustArgumentBeConstant(&ctx->op_kernel(), i)) { - // Slow path; the argument is not available as a host constant so we must - // fetch it synchronously. - Tensor host_tensor; - AllocatorAttributes attrs; - attrs.set_on_host(true); - TF_RETURN_IF_ERROR(ctx->allocate_temp( - device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs)); - Notification n; - ctx->op_device_context()->CopyDeviceTensorToCPU( - &device_tensor, "ConstantArgument", - reinterpret_cast(ctx->device()), &host_tensor, - [&](Status status) { n.Notify(); }); - n.WaitForNotification(); - constant_arguments[i] = host_tensor; + + if (constant_arguments.count(i) == 0) { + bool must_argument_be_const; + TF_RETURN_IF_ERROR(MustArgumentBeConstant(&ctx->op_kernel(), i, + &must_argument_be_const)); + + if (must_argument_be_const) { + // Slow path; the argument is not available as a host constant so we + // must fetch it synchronously. 
+ Tensor host_tensor; + AllocatorAttributes attrs; + attrs.set_on_host(true); + TF_RETURN_IF_ERROR(ctx->allocate_temp( + device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs)); + Notification n; + Status status; + ctx->op_device_context()->CopyDeviceTensorToCPU( + &device_tensor, "ConstantArgument", + reinterpret_cast(ctx->device()), &host_tensor, + [&](Status s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + if (!status.ok()) { + LOG(ERROR) << "Copying tensor of shape " + << device_tensor.shape().DebugString() << " from " + << ctx->device()->name() << "to CPU failed with " + << status.ToString(); + return status; + } + constant_arguments[i] = host_tensor; + } } } @@ -164,8 +184,7 @@ Status XlaCompileOnDemandOp::Compile( XlaCompiler::Options options; options.device_type = metadata.jit_device_type(); options.client = metadata.client(); - options.flib_def = - new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); + options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.shape_representation_fn = metadata.shape_representation_fn(); XlaCompiler::CompileOptions compile_options; @@ -179,8 +198,14 @@ Status XlaCompileOnDemandOp::Compile( compile_options.always_return_tuple = false; std::map variable_args = GetVariables(ctx); - return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - compile_options, result, executable); + + std::vector args; + + TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_arguments, variable_args, ctx, &args)); + + return cache->CompileSingleOp(options, args, ctx, compile_options, result, + executable); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h index 7cc3d0e007ba2974fbfbe6fbabc4aa08f9fa910f..b93bb15ce34688f26316e22bf59f448e787df9fc 100644 --- 
a/tensorflow/compiler/jit/xla_compile_on_demand_op.h +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -38,8 +38,10 @@ class XlaCompileOnDemandOp : public OpKernel { private: XlaCompiler::Argument CreateCompilerArgument(OpKernelContext* ctx, int64 i); - bool ShouldArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx); - bool MustArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx); + Status ShouldArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx, + bool* result); + Status MustArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx, + bool* result); Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, const XlaCompiler::CompilationResult** result, xla::LocalExecutable** executable); diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 003c1d8081a3313fd042cdcaea14508ed1048da3..e9770647e7ba96cc1db026d12d5f11f52ce98d35 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -16,8 +16,9 @@ limitations under the License. // Registers the XLA_CPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "Host" (CPU) backend. 
+#include "absl/memory/memory.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h" -#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" @@ -30,34 +31,51 @@ namespace tensorflow { class XlaCpuDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector* devices) override; + std::vector>* devices) override; }; -Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, - const string& name_prefix, - std::vector* devices) { - legacy_flags::XlaDeviceFlags* flags = legacy_flags::GetXlaDeviceFlags(); +Status XlaCpuDeviceFactory::CreateDevices( + const SessionOptions& session_options, const string& name_prefix, + std::vector>* devices) { + XlaDeviceFlags* flags = GetXlaDeviceFlags(); bool compile_on_demand = flags->tf_xla_compile_on_demand; XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_CPU_XLA_JIT; - registration.requires_compilation = !compile_on_demand; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + compile_on_demand + ? 
XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested + : XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; + XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration); static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(DEVICE_XLA_CPU, DEVICE_CPU_XLA_JIT); (void)registrations; - std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, - DEVICE_CPU_XLA_JIT, options, name_prefix, - registration, - /*transfer_as_literal=*/false, - /*use_multiple_streams=*/false, - /*shape_representation_fn=*/{}, - /*padded_shape_fn=*/{}, &device)); - devices->push_back(device.release()); + TF_ASSIGN_OR_RETURN(auto platform, + se::MultiPlatformManager::PlatformWithName("Host")); + + XlaDevice::Options options; + options.platform = platform; + options.device_name_prefix = name_prefix; + options.device_name = DEVICE_XLA_CPU; + options.device_ordinal = 0; + options.compilation_device_name = DEVICE_CPU_XLA_JIT; + options.use_multiple_streams = false; + auto device = absl::make_unique(session_options, options); + + // Setting GpuDeviceInfo because eager runtime relies on the device + // context in tensorflow_gpu_device_info(). Also, + // tensorflow_gpu_device_info() == nullptr is used as an IsCPU test. + // We need XlaCpuDevice to be treated not as CPU because it allocates + // XlaTensors, not regular Tensors. 
+ Status status = device->UseGpuDeviceInfo(); + if (!status.ok()) { + errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT); + return status; + } + devices->push_back(std::move(device)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 0824c4644e3e5d8e1390b99f12de824bfcdfec24..4201ff91a89b1bee370e6a43337c51abe3bf974a 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -125,41 +125,17 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { return Status::OK(); } -} // namespace - -/* static */ Status XlaDevice::Create( - const string& platform_name, const string& device_name, int device_ordinal, - const string& jit_device_name, const SessionOptions& options, - const string& name_prefix, - const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, bool use_multiple_streams, - const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, - const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device) { - VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" - << device_ordinal; - - // These are no-ops if they have already been done previously for - // this device_name/compilation_device_name pair. 
- XlaOpRegistry::RegisterCompilationDevice(device_name, registration); - - auto platform = se::MultiPlatformManager::PlatformWithName(platform_name); - if (!platform.ok()) { - return platform.status(); - } - - const DeviceAttributes attrs = Device::BuildDeviceAttributes( +static DeviceAttributes BuildXlaDeviceAttributes(const string& name_prefix, + const string& device_name, + int device_ordinal) { + return Device::BuildDeviceAttributes( absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal), DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), absl::StrCat("device: ", device_name, " device")); - - device->reset( - new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name), - platform.ValueOrDie(), transfer_as_literal, - use_multiple_streams, shape_representation_fn, - padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn)); - return Status::OK(); } +} // namespace + XlaDevice::Metadata::Metadata( int device_ordinal, se::Platform* platform, const DeviceType& device_type, XlaCompiler::ShapeRepresentationFn shape_representation_fn, @@ -209,30 +185,42 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return GetMetadataFromDevice(ctx->device(), metadata); } -XlaDevice::XlaDevice( - const SessionOptions& options, const DeviceAttributes& attrs, - int device_ordinal, const DeviceType& jit_device_name, - se::Platform* platform, bool transfer_as_literal, bool use_multiple_streams, - const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, - const PaddedShapeFn& padded_shape_fn) - : LocalDevice(options, attrs), - xla_metadata_(device_ordinal, platform, jit_device_name, - shape_representation_fn, padded_shape_fn, - use_multiple_streams), - device_ordinal_(device_ordinal), - jit_device_name_(jit_device_name), - platform_(platform), - use_multiple_streams_(use_multiple_streams), - transfer_as_literal_(transfer_as_literal), - shape_representation_fn_(shape_representation_fn) { - VLOG(1) << "Created XLA 
device " << jit_device_name << " " << this; - thread_pool_.reset(new thread::ThreadPool(options.env, "xla_device", +XlaDevice::XlaDevice(const SessionOptions& session_options, + const Options& options) + : LocalDevice(session_options, + BuildXlaDeviceAttributes(options.device_name_prefix, + options.device_name, + options.device_ordinal)), + xla_metadata_(options.device_ordinal, options.platform, + DeviceType(options.compilation_device_name), + options.shape_representation_fn, + options.padded_shape_fn ? options.padded_shape_fn + : DefaultPaddedShapeFn, + options.use_multiple_streams), + device_ordinal_(options.device_ordinal), + jit_device_name_(options.compilation_device_name), + platform_(options.platform), + use_multiple_streams_(options.use_multiple_streams), + shape_representation_fn_(options.shape_representation_fn) { + VLOG(1) << "Created XLA device " << options.compilation_device_name << " " + << this; + thread_pool_.reset(new thread::ThreadPool(session_options.env, "xla_device", /*num_threads=*/1)); + + // We have multiple device to device streams to allow for some concurrency + // between transfers. The particular value of '4' is chosen fairly + // arbitrarily. It may be necessary to make this tunable via + // XlaDevice::Options. 
+ static constexpr int kNumDeviceToDeviceStreams = 4; + device_to_device_streams_.resize(kNumDeviceToDeviceStreams); } XlaDevice::~XlaDevice() { VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this; mutex_lock lock(mu_); + while (outstanding_asynchronous_operations_ > 0) { + outstanding_asynchronous_operations_cv_.wait(lock); + } if (device_context_) { device_context_->Unref(); } @@ -295,8 +283,9 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_, &need_new_device_context)); - std::shared_ptr host_to_device_stream = stream_; - std::shared_ptr device_to_host_stream = stream_; + std::shared_ptr host_to_device_stream; + std::shared_ptr device_to_host_stream; + std::vector> device_to_device_streams; if (use_multiple_streams_) { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream", &host_to_device_stream_, @@ -304,8 +293,18 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream", &device_to_host_stream_, &need_new_device_context)); + for (std::shared_ptr& stream : device_to_device_streams_) { + TF_RETURN_IF_ERROR( + EnsureStreamOkLocked(backend, "device_to_device_stream", &stream, + &need_new_device_context)); + } host_to_device_stream = host_to_device_stream_; device_to_host_stream = device_to_host_stream_; + device_to_device_streams = device_to_device_streams_; + } else { + host_to_device_stream = stream_; + device_to_host_stream = stream_; + device_to_device_streams = {stream_}; } if (!need_new_device_context) { @@ -323,8 +322,9 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { // ensures that the streams remain live for the duration of a run, even if // an error is encountered and the streams are replaced with new ones. 
device_context_ = new XlaDeviceContext( - stream_, host_to_device_stream, device_to_host_stream, client(), - transfer_as_literal_, shape_representation_fn_, thread_pool_.get()); + stream_, std::move(host_to_device_stream), + std::move(device_to_host_stream), std::move(device_to_device_streams), + client(), shape_representation_fn_, thread_pool_.get()); VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext " << device_context_; @@ -387,6 +387,7 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, Status XlaDevice::Sync() { VLOG(1) << "XlaDevice::Sync"; + tracing::ScopedActivity activity("XlaDevice::Sync", /*is_expensive=*/true); std::shared_ptr stream; { mutex_lock lock(mu_); @@ -394,13 +395,46 @@ Status XlaDevice::Sync() { } if (!stream) return Status::OK(); - if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) { + Status status = stream->BlockHostUntilDone(); + { + mutex_lock lock(mu_); + while (outstanding_asynchronous_operations_ > 0) { + outstanding_asynchronous_operations_cv_.wait(lock); + } + } + TF_RETURN_IF_ERROR(status); + if (!stream->ok()) { return errors::Internal("XlaDevice::Sync() failed."); } VLOG(1) << "XlaDevice::Sync completed"; return Status::OK(); } +void XlaDevice::Sync(const DoneCallback& done) { + VLOG(1) << "XlaDevice::Sync (asynchronous)"; + std::shared_ptr stream; + { + mutex_lock lock(mu_); + stream = stream_; + } + if (!stream) { + done(Status::OK()); + return; + } + + stream->ThenEnqueueOnBackgroundThread( + [this, stream, done](se::StreamExecutor*) { + tracing::ScopedActivity activity("XlaDevice::Sync::Callback", + /*is_expensive=*/true); + mutex_lock lock(mu_); + while (outstanding_asynchronous_operations_ > 0) { + outstanding_asynchronous_operations_cv_.wait(lock); + } + done(stream->ok() ? 
Status::OK() + : errors::Internal("XlaDevice::Sync() failed.")); + }); +} + Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { @@ -444,12 +478,55 @@ bool XlaDevice::RequiresSyncOnCompletion() const { return sync_on_completion_; } +XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( + XlaDevice* device) + : device_(device) { + mutex_lock lock(device_->mu_); + ++device_->outstanding_asynchronous_operations_; +} + +XlaDevice::AsynchronousOperationHandle::~AsynchronousOperationHandle() { + if (device_) { + mutex_lock lock(device_->mu_); + --device_->outstanding_asynchronous_operations_; + device_->outstanding_asynchronous_operations_cv_.notify_all(); + } +} + +XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( + const XlaDevice::AsynchronousOperationHandle& other) + : device_(other.device_) { + mutex_lock lock(device_->mu_); + ++device_->outstanding_asynchronous_operations_; +} + +XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle( + XlaDevice::AsynchronousOperationHandle&& other) + : device_(other.device_) { + other.device_ = nullptr; +} + +XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle:: +operator=(const XlaDevice::AsynchronousOperationHandle& other) { + device_ = other.device_; + mutex_lock lock(device_->mu_); + ++device_->outstanding_asynchronous_operations_; + return *this; +} + +XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle:: +operator=(XlaDevice::AsynchronousOperationHandle&& other) { + device_ = other.device_; + other.device_ = nullptr; + return *this; +} + XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device) { // Any op assigned to the device that isn't rewritten by the graph rewriter // gets executed by a n XlaCompileOnDemandOp, which compiles it and executes // it just-in-time. 
- kernel_factory::OpKernelRegistrar::Factory factory = + OpKernel* (*factory)(OpKernelConstruction*) = [](OpKernelConstruction* context) -> OpKernel* { return new XlaCompileOnDemandOp(context); }; diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 0f06b3fc80b7c844dae5643127bdabba8a53b35e..c8bb276cdb9673fdcba4cc15a9f33ecd3ae96dbb 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -92,34 +92,41 @@ class XlaDevice : public LocalDevice { static Status GetMetadata(OpKernelConstruction* ctx, const Metadata** metadata); - // Factory function. 'platform_name' is the name of the XLA platform. - // 'device_name' is the name of the Tensorflow device to create. - // 'jit_device_name' is the name of the corresponding JIT device. - // 'transfer_as_literal' is true if device<->host transfers must be done using - // XLA's TransferLiteral{To,From}Device interface. If false, we can use - // ThenMemcpy instead. - // If 'use_multiple_streams' is true, we create separate streams for - // host-to-device and device-to-host communication. - // If padded_shape_fn is empty, a default implementation that returns - // the on-host shape is used. - static Status Create( - const string& platform_name, const string& device_name, - int device_ordinal, const string& jit_device_name, - const SessionOptions& options, const string& name_prefix, - const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, bool use_multiple_streams, - const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, - const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device); + struct Options { + // The StreamExecutor platform. Not owned. Must be non-null. + se::Platform* platform = nullptr; + + // The device name's prefix (e.g., "/task:7") + string device_name_prefix; + + // The name of the XLA device (e.g., "XLA_CPU") + string device_name; + + // The number of the device. 
+ int device_ordinal = -1; + + // The name of the compilation device (e.g., "XLA_CPU_JIT"); + string compilation_device_name; + + // If 'use_multiple_streams' is true, we create separate streams for + // compute, host-to-device, and device-to-host communication. + bool use_multiple_streams = false; + + // A function that describes how the on-host shapes of + // a) argument and return value, for entry computations + // b) variables, for all computations, + // should be represented in XLA. Parameters/return values will be shaped + // according to this function, and reshaped back to/from their declared + // shapes for computations. Must be non-null. + XlaCompiler::ShapeRepresentationFn shape_representation_fn; + + // If padded_shape_fn is empty, a default implementation that returns + // the logical on-device shape without padding is used. + PaddedShapeFn padded_shape_fn; + }; // Creates a new XLA Device. - // If padded_shape_fn is empty, a default implementation that returns - // the logical on-device shape without padding is used. 
- XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, - int device_ordinal, const DeviceType& jit_device_name, - se::Platform* platform, bool transfer_as_literal, - bool use_multiple_streams, - const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, - const PaddedShapeFn& padded_shape_fn); + XlaDevice(const SessionOptions& session_options, const Options& options); ~XlaDevice() override; Allocator* GetAllocator(AllocatorAttributes attr) override @@ -128,6 +135,7 @@ class XlaDevice : public LocalDevice { void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; Status Sync() override; + void Sync(const DoneCallback& done) override; Status FillContextMap(const Graph* graph, DeviceContextMap* device_context_map) override @@ -157,7 +165,30 @@ class XlaDevice : public LocalDevice { bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); + // A simple RAII handle. On construction the device's + // outstanding_asynchronous_operations_ field is incremented; on destruction + // it is decremented. + class AsynchronousOperationHandle { + public: + AsynchronousOperationHandle(XlaDevice* device); + ~AsynchronousOperationHandle(); + AsynchronousOperationHandle(const AsynchronousOperationHandle& other); + AsynchronousOperationHandle(AsynchronousOperationHandle&& other); + AsynchronousOperationHandle& operator=( + const AsynchronousOperationHandle& other); + AsynchronousOperationHandle& operator=(AsynchronousOperationHandle&& other); + + private: + XlaDevice* device_ = nullptr; + }; + + AsynchronousOperationHandle CreateAsynchronousOperationHandle() { + return AsynchronousOperationHandle(this); + } + private: + friend class AsynchronousOperationHandle; + xla::LocalClient* client() const; Allocator* GetAllocatorLocked(AllocatorAttributes attr) EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -182,6 +213,7 @@ class XlaDevice : public LocalDevice { se::Platform* const platform_; // Not owned. 
// Memory allocator associated with this device. Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned. + // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and @@ -197,9 +229,11 @@ class XlaDevice : public LocalDevice { // If use_multiple_streams_, device to host transfers are performed using this // stream. std::shared_ptr device_to_host_stream_ GUARDED_BY(mu_); - // Must we use XLA's transfer manager for correct host<->device transfers? if - // false, we can use ThenMemcpy() instead. - const bool transfer_as_literal_; + // If use_multiple_streams_, transfers between different devices are performed + // using these streams. + std::vector> device_to_device_streams_ + GUARDED_BY(mu_); + const XlaCompiler::ShapeRepresentationFn shape_representation_fn_; // The device context accessed by all users of the XlaDevice, set by calls to @@ -217,6 +251,11 @@ class XlaDevice : public LocalDevice { // True if the device requires XlaDevice::Sync to be called on completion // regardless of status. bool sync_on_completion_ GUARDED_BY(mu_) = false; + + // Count of outstanding asynchronous operations which must be zero on Sync() + // completion. 
+ int64 outstanding_asynchronous_operations_ GUARDED_BY(mu_) = 0; + condition_variable outstanding_asynchronous_operations_cv_; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index af83c792e5e11d8596c521c6a3aed332a1f42e5b..6e6532731e64bd42ee56aa719748988f321e0f17 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -50,94 +50,39 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } -XlaTransferManager::XlaTransferManager( +XlaDeviceContext::XlaDeviceContext( std::shared_ptr compute_stream, std::shared_ptr host_to_device_stream, - std::shared_ptr device_to_host_stream, xla::LocalClient* client, - bool transfer_as_literal, + std::shared_ptr device_to_host_stream, + std::vector> device_to_device_streams, + xla::LocalClient* client, XlaCompiler::ShapeRepresentationFn shape_representation_fn, thread::ThreadPool* thread_pool) : stream_(std::move(compute_stream)), host_to_device_stream_(std::move(host_to_device_stream)), device_to_host_stream_(std::move(device_to_host_stream)), + device_to_device_streams_(std::move(device_to_device_streams)), client_(client), transfer_manager_(client->backend().transfer_manager()), - transfer_as_literal_(transfer_as_literal), shape_representation_fn_(std::move(shape_representation_fn)), thread_pool_(thread_pool) { CHECK(host_to_device_stream_ != nullptr); CHECK(device_to_host_stream_ != nullptr); CHECK(stream_ != nullptr); if (!shape_representation_fn_) { - shape_representation_fn_ = - [](const TensorShape& shape, - DataType dtype) -> xla::StatusOr { return shape; }; + shape_representation_fn_ = [](const TensorShape& shape, + DataType dtype) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + 
}; } } -Status XlaTransferManager::TransferLiteralToDevice( - const Tensor& host_tensor, Tensor* device_tensor) const { - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), - host_tensor.shape(), &xla_shape)); - // Create a reference to hold onto host_tensor until after the literal has - // been transferred. Also make sure the literal exists until the function - // asynchronously completes, as it will be wrapped in an xla::LiteralSlice. - TensorReference ref(host_tensor); - auto literal = std::make_shared( - static_cast(DMAHelper::base(&host_tensor)), xla_shape); - - XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); - const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); - VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " " - << shaped_buffer.ToString(); - if (UseMultipleStreams() && !transfer_manager_->CanShapedBufferBeAccessedNow( - stream_->parent(), shaped_buffer)) { - // Initially wait for the compute stream so that memory allocations are - // synchronized. - host_to_device_stream_->ThenWaitFor(stream_.get()); - } - TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( - host_to_device_stream_.get(), *literal, shaped_buffer)); - if (UseMultipleStreams()) { - auto event = std::make_shared(stream_->parent()); - TF_RET_CHECK(event->Init()) << "Event failed to initialize!"; - host_to_device_stream_->ThenRecordEvent(event.get()); - xla_tensor->SetDefinedOn(host_to_device_stream_.get(), std::move(event)); - } - // Unref the host tensor, and capture the literal shared_ptr too so it goes - // out of scope when the lambda completes. 
- host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); }); - - return Status::OK(); -} - -void XlaTransferManager::TransferLiteralFromDevice( - Tensor* host_tensor, const Tensor& device_tensor, - const StatusCallback& done) const { - xla::MutableBorrowingLiteral literal; - TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(host_tensor, &literal)); - - const xla::ShapedBuffer& shaped_buffer = - XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); - - TensorReference ref(device_tensor); - transfer_manager_->TransferLiteralFromDevice( - device_to_host_stream_.get(), shaped_buffer, literal, - [=, &shaped_buffer](xla::Status status) { - ref.Unref(); - done([&]() -> Status { - VLOG(1) << "Transfer from device as literal: " - << shaped_buffer.ToString(); - return status; - }()); - }); -} - -void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, - Device* device, - Tensor* device_tensor, - StatusCallback done) const { +void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, + Device* device, + Tensor* device_tensor, + StatusCallback done) const { if (cpu_tensor->NumElements() == 0) { VLOG(2) << "CopyCPUTensorToDevice empty tensor"; done(Status::OK()); @@ -152,61 +97,85 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, << cpu_tensor->shape().DebugString() << " " << device_tensor->shape().DebugString(); - void* src_ptr = const_cast(DMAHelper::base(cpu_tensor)); - const int64 total_bytes = cpu_tensor->TotalBytes(); XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); CHECK(xla_tensor); - xla::StatusOr shape_or_status = - shape_representation_fn_(device_tensor->shape(), device_tensor->dtype()); - if (!shape_or_status.ok()) { - done(shape_or_status.status()); - return; - } - TensorShape shape = shape_or_status.ValueOrDie(); - if (!xla_tensor->has_shaped_buffer()) { - Status s = + Status status = [&]() -> Status { + TF_ASSIGN_OR_RETURN(xla::Shape shape, + 
shape_representation_fn_(device_tensor->shape(), + device_tensor->dtype())); + + // The device tensor should always be fresh. + TF_RET_CHECK(!xla_tensor->has_shaped_buffer()); + + xla_tensor->set_host_tensor(*cpu_tensor); + TF_RETURN_IF_ERROR( xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_, - stream_->parent()->device_ordinal()); - if (!s.ok()) { - done(s); - return; + stream_->parent()->device_ordinal())); + + // The cpu_tensor and literal that we created here hold the data of host + // tensor in descending layout. The layout could be different from layout in + // device_tensor (but the logical shape has to be the same). The + // transfer_manager is responsible to do corresponding transposing when + // transferring the data to device. + xla::BorrowingLiteral literal( + static_cast(DMAHelper::base(cpu_tensor)), + xla::ShapeUtil::MakeShape(shape.element_type(), + xla::AsInt64Slice(shape.dimensions()))); + + VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " + << xla_tensor->shaped_buffer().ToString(); + if (UseMultipleStreams() && + !transfer_manager_->CanShapedBufferBeAccessedNow( + stream_->parent(), xla_tensor->shaped_buffer())) { + // Initially wait for the compute stream so that memory allocations are + // synchronized. + host_to_device_stream_->ThenWaitFor(stream_.get()); } - } - Status status; - if (transfer_as_literal_) { - Tensor reshaped_cpu_tensor; - if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { - done(errors::Internal( - "Tensor::CopyFrom failed when copying from CPU to XLA device")); - return; - } - status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); - } else { - se::DeviceMemoryBase dev_dst_ptr = - XlaTensor::DeviceMemoryFromTensor(*device_tensor); - host_to_device_stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. 
- Status block_status = host_to_device_stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", - host_to_device_stream_.get(), block_status.error_message().c_str()); + TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( + host_to_device_stream_.get(), literal, xla_tensor->shaped_buffer())); + + if (UseMultipleStreams()) { + auto event = std::make_shared(stream_->parent()); + TF_RET_CHECK(event->Init()) << "Event failed to initialize!"; + host_to_device_stream_->ThenRecordEvent(event.get()); + xla_tensor->ResetDefinitionEvent(std::move(event), + host_to_device_stream_.get()); } + + return Status::OK(); + }(); + if (!status.ok()) { + done(status); + return; } - if (status.ok()) { - xla_tensor->set_host_tensor(*cpu_tensor); + + // Create a reference to hold onto cpu_tensor until after the literal has + // been transferred + TensorReference ref(*cpu_tensor); + if (UseMultipleStreams()) { + // Unref the host tensor when the transfer completes. + // We don't defer the call to done() onto the stream here, and the reasons + // why this is correct are subtle. We assume that: + // a) all consumers of the device tensor will wait for its definition event. + // b) if the tensor is destroyed, then the memory allocator will not hand + // out the same buffers until the transfer has completed. 
+ host_to_device_stream_->ThenDoHostCallback([ref]() { ref.Unref(); }); + done(status); + } else { + host_to_device_stream_->ThenDoHostCallback([ref, done]() { + ref.Unref(); + done(Status::OK()); + }); } - done(status); } -void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, - absl::string_view tensor_name, - Device* device, - Tensor* cpu_tensor, - StatusCallback done) { +void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, + absl::string_view tensor_name, + Device* device, Tensor* cpu_tensor, + StatusCallback done) { if (device_tensor->NumElements() == 0) { VLOG(2) << "CopyDeviceTensorToCPU empty tensor"; done(Status::OK()); @@ -220,136 +189,38 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, << cpu_tensor->shape().DebugString() << " " << device_tensor->shape().DebugString(); - const int64 total_bytes = cpu_tensor->TotalBytes(); - se::DeviceMemoryBase dev_src_ptr = - XlaTensor::DeviceMemoryFromTensor(*device_tensor); - void* dst_ptr = DMAHelper::base(cpu_tensor); XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); + xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream_.get()); - if (se::Event* event = - xla_tensor->GetDefinitionEvent(device_to_host_stream_.get())) { - device_to_host_stream_->ThenWaitFor(event); - xla_tensor->SetDefinedOn(device_to_host_stream_.get()); - } - - Status status; - if (transfer_as_literal_) { - TransferLiteralFromDevice(cpu_tensor, *device_tensor, done); - return; - } else { - device_to_host_stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. 
- Status block_status = device_to_host_stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_.get(), - block_status.error_message().c_str()); - } - } - - done(status); -} - -void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, - Tensor* dst_tensor, - const StatusCallback& done) { - VLOG(2) << "CopyDeviceTensorToDevice " - << reinterpret_cast(src_tensor.tensor_data().data()) - << " " - << reinterpret_cast(dst_tensor->tensor_data().data()); - // Perform memory allocation now, and enqueue the device-to-device transfer. - Status status = [&]() -> Status { - if (src_tensor.NumElements() == 0) { - return Status::OK(); - } - // TODO(jmolloy): We co-opt the device_to_host stream for device to device - // transfers; perhaps we should have a dedicated device to device stream? or - // one per device? - auto device_to_device_stream = stream_; - XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor); - XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor); - CHECK(xla_src && xla_dst) - << "Missing destination tensor for device-to-device copy"; - if (!xla_dst->has_shaped_buffer()) { - TF_ASSIGN_OR_RETURN( - TensorShape shape, - shape_representation_fn_(src_tensor.shape(), src_tensor.dtype())); - TF_RETURN_IF_ERROR( - xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_, - stream_->parent()->device_ordinal())); - if (stream_ != device_to_device_stream) { - // Initially wait for the compute stream so that memory allocations are - // synchronized. 
- device_to_device_stream->ThenWaitFor(stream_.get()); - } - } - - if (se::Event* event = - xla_src->GetDefinitionEvent(device_to_device_stream.get())) { - device_to_device_stream->ThenWaitFor(event); - xla_src->SetDefinedOn(device_to_device_stream.get()); - } - - auto from_iter = xla_src->shaped_buffer().buffers().begin(); - auto to_iter = xla_dst->shaped_buffer().buffers().begin(); - for (auto end_iter = xla_src->shaped_buffer().buffers().end(); - from_iter != end_iter; ++from_iter, ++to_iter) { - device_to_device_stream->ThenMemcpyD2D( - &to_iter->second, from_iter->second, to_iter->second.size()); - } - - if (UseMultipleStreams()) { - auto event = std::make_shared(stream_->parent()); - TF_RET_CHECK(event->Init()) << "Event failed to initialize"; - device_to_device_stream->ThenRecordEvent(event.get()); - xla_dst->SetDefinedOn(device_to_device_stream.get(), std::move(event)); - } - return Status::OK(); - }(); - if (!status.ok()) { - return done(status); - } else { - stream_->ThenDoHostCallback([this, done]() { - // We must not call the done closure directly from DoHostCallback to avoid - // a deadlock. If done() is the callback that ends an Executor's run, the - // Executor may call XlaDevice::Sync() inside the callback. This - // deadlocks, because XlaDevice::Sync() waits for all stream activity to - // complete. 
- thread_pool_->Schedule([done]() { done(Status::OK()); }); - }); - } -} - -XlaDeviceContext::XlaDeviceContext( - std::shared_ptr compute_stream, - std::shared_ptr host_to_device_stream, - std::shared_ptr device_to_host_stream, xla::LocalClient* client, - bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - thread::ThreadPool* thread_pool) - : manager_(std::move(compute_stream), std::move(host_to_device_stream), - std::move(device_to_host_stream), client, transfer_as_literal, - std::move(shape_representation_fn), thread_pool) {} - -void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, - Device* device, - Tensor* device_tensor, - StatusCallback done) const { - manager_.CopyCPUTensorToDevice(cpu_tensor, device, device_tensor, done); -} + // Transfer manager requires the shape of the shaped buffer to be the same as + // literal shape except for the layout. Set the literal to use xla_tensor's + // shape as it is derived from the cpu_tensor's shape using + // shape_representation_fn_. 
+ xla::MutableBorrowingLiteral literal; + TF_CHECK_OK(HostTensorToMutableBorrowingLiteral( + xla::LayoutUtil::GetWithDefaultLayout( + xla_tensor->shaped_buffer().on_host_shape()), + cpu_tensor, &literal)); -void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - absl::string_view tensor_name, - Device* device, Tensor* cpu_tensor, - StatusCallback done) { - manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor, - done); + TensorReference ref(*device_tensor); + transfer_manager_->TransferLiteralFromDevice( + device_to_host_stream_.get(), xla_tensor->shaped_buffer(), literal, + [ref, xla_tensor, done](xla::Status status) { + done([&]() -> Status { + VLOG(1) << "Transfer from device as literal: " + << xla_tensor->shaped_buffer().ToString(); + return status; + }()); + ref.Unref(); + }); } -void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor, - Tensor* dst_tensor, - const StatusCallback& done) { - manager_.CopyDeviceTensorToDevice(src_tensor, dst_tensor, done); +se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() { + DCHECK_GT(device_to_device_streams_.size(), 0); + absl::MutexLock lock(&mu_); + int stream = next_stream_; + next_stream_ = (next_stream_ + 1) % device_to_device_streams_.size(); + return device_to_device_stream(stream); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index df824212948ac96a5df5228cecd9a8c864bbec9a..1e18df197a2dd65590c5181b4dae4481dca36641 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -18,6 +18,7 @@ limitations under the License. 
#include +#include "absl/synchronization/mutex.h" #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -44,33 +45,44 @@ class XlaDeviceAllocator : public Allocator { }; // Helper class for managing data transfers between host and XLA devices. -class XlaTransferManager { +class XlaDeviceContext : public DeviceContext { public: - explicit XlaTransferManager( + explicit XlaDeviceContext( std::shared_ptr compute_stream, std::shared_ptr host_to_device_stream, std::shared_ptr device_to_host_stream, - xla::LocalClient* client, bool transfer_as_literal, + std::vector> device_to_device_streams, + xla::LocalClient* client, XlaCompiler::ShapeRepresentationFn shape_representation_fn, thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, - Tensor* device_tensor, StatusCallback done) const; + Tensor* device_tensor, + StatusCallback done) const override; void CopyDeviceTensorToCPU(const Tensor* device_tensor, absl::string_view tensor_name, Device* device, - Tensor* cpu_tensor, StatusCallback done); - - void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, - const StatusCallback& done); + Tensor* cpu_tensor, StatusCallback done) override; + xla::LocalClient* client() const { return client_; } se::Stream* stream() const { return stream_.get(); } + se::Stream* host_to_device_stream() const { + return host_to_device_stream_.get(); + } + se::Stream* device_to_host_stream() const { + return device_to_host_stream_.get(); + } + se::Stream* device_to_device_stream(int index) const { + return device_to_device_streams_.at(index).get(); + } + xla::TransferManager* transfer_manager() const { return transfer_manager_; } + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const { + return shape_representation_fn_; + } + + // Returns a device-to-device stream, in round-robin fashion. 
+ se::Stream* GetDeviceToDeviceStream(); private: - Status TransferLiteralToDevice(const Tensor& host_tensor, - Tensor* device_tensor) const; - void TransferLiteralFromDevice(Tensor* host_tensor, - const Tensor& device_tensor, - const StatusCallback& done) const; bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; } // The main compute stream of the device, used to synchronize the transfer @@ -82,44 +94,22 @@ class XlaTransferManager { // The stream to use for transferring data from device to host. Can be // idential to stream_, but must not be nullptr. std::shared_ptr device_to_host_stream_; + // Streams to use for transferring data directly between different devices, + // e.g., over NVLINK. + std::vector> device_to_device_streams_; + // For the underlying memory allocator and XLA's TransferManager. xla::LocalClient* client_; // Transfer manager, for marshalling data to and from the device. xla::TransferManager* transfer_manager_; - // True if we must use XLA's TransferManager for correct device transfers. - const bool transfer_as_literal_; + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; // Thread pool used for running closures thread::ThreadPool* thread_pool_; -}; -// DeviceContext for operators assigned to XlaDevice devices. The -// implementation must inherit from DeviceContext but otherwise just -// wraps the methods in XlaTransferManager. 
-class XlaDeviceContext : public DeviceContext { - public: - explicit XlaDeviceContext( - std::shared_ptr compute_stream, - std::shared_ptr host_to_device_stream, - std::shared_ptr device_to_host_stream, - xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - thread::ThreadPool* thread_pool); - - void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, - Tensor* device_tensor, - StatusCallback done) const override; - void CopyDeviceTensorToCPU(const Tensor* device_tensor, - absl::string_view tensor_name, Device* device, - Tensor* cpu_tensor, StatusCallback done) override; - void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, - const StatusCallback& done); - - se::Stream* stream() const override { return manager_.stream(); } - - private: - XlaTransferManager manager_; + absl::Mutex mu_; + int next_stream_ GUARDED_BY(mu_) = 0; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc index 5ecb1afa7bcec910ca843ccd3a782745f2bb6ca8..f56c26ba0103fed152322f0c8971a449610cdc2b 100644 --- a/tensorflow/compiler/jit/xla_device_ops.cc +++ b/tensorflow/compiler/jit/xla_device_ops.cc @@ -30,81 +30,43 @@ void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) { } XlaAssignVariableOp::XlaAssignVariableOp(OpKernelConstruction* c) - : AsyncOpKernel(c) { + : OpKernel(c) { OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_)); } -void XlaAssignVariableOp::ComputeAsync(OpKernelContext* context, - DoneCallback done) { - OP_REQUIRES_ASYNC(context, dtype_ == context->input(1).dtype(), - errors::InvalidArgument( - "Variable and value dtypes don't match; respectively, ", - dtype_, " and ", context->input(1).dtype()), - done); +void XlaAssignVariableOp::Compute(OpKernelContext* context) { + OP_REQUIRES(context, dtype_ == context->input(1).dtype(), + errors::InvalidArgument( + "Variable and value dtypes don't match; respectively, ", + 
DataTypeString(dtype_), " and ", + DataTypeString(context->input(1).dtype()))); Var* variable = nullptr; - OP_REQUIRES_OK_ASYNC( - context, - LookupOrCreateResource( - context, HandleFromInput(context, 0), &variable, - [this, context](Var** ptr) { - *ptr = new Var(dtype_); - PersistentTensor unused; - Tensor* tmp; - AllocatorAttributes attr; - TF_RETURN_IF_ERROR(context->allocate_persistent( - dtype_, context->input(1).shape(), &unused, &tmp, attr)); - *(*ptr)->tensor() = *tmp; - return Status::OK(); - }), - done); - core::ScopedUnref s(variable); - - OP_REQUIRES_ASYNC(context, variable->tensor()->dtype() == dtype_, - errors::InvalidArgument( - "Trying to assign variable with wrong dtype. Expected ", - DataTypeString(variable->tensor()->dtype()), " got ", - DataTypeString(dtype_)), - done); - const Tensor& value = context->input(1); - AllocatorAttributes attr; - - // Copying is unnecessary if we are the last user of the value tensor, we can - // just adopt the input tensor's buffer instead. - std::unique_ptr input_alias = context->forward_input( - 1, /*output_index=*/OpKernelContext::Params::kNoReservation, dtype_, - value.shape(), DEVICE_MEMORY, attr); + // Note: every resource-variable-manipulating op assumes copy-on-write + // semantics, and creates a copy of the variable's Tensor if its refcount is + // bigger than 1 when we try to modify it. This means we never need to copy + // the original tensor for AssignVariableOp; even if there are other live + // users of it we know none can modify it so this is always safe (even in + // esoteric cases where the same tensor is used to initialize multiple + // variables or the tensor is a constant this is safe, as future writes will + // trigger copies). 
+ OP_REQUIRES_OK(context, LookupOrCreateResource( + context, HandleFromInput(context, 0), &variable, + [this, &value](Var** ptr) { + *ptr = new Var(dtype_); + *(*ptr)->tensor() = value; + (*ptr)->is_initialized = true; + return Status::OK(); + })); + core::ScopedUnref s(variable); mutex_lock ml(*variable->mu()); + OP_REQUIRES(context, variable->tensor()->dtype() == dtype_, + errors::InvalidArgument( + "Trying to assign variable with wrong dtype. Expected ", + DataTypeString(variable->tensor()->dtype()), " got ", + DataTypeString(dtype_))); variable->is_initialized = true; - if (input_alias) { - *variable->tensor() = *input_alias; - done(); - return; - } - - // Need to copy, but maybe we can re-use variable's buffer? - if (!XlaTensor::RefCountIsOne(*variable->tensor()) || - !variable->tensor()->shape().IsSameSize(value.shape())) { - // Copy to new buffer - PersistentTensor unused; - Tensor* tmp; - OP_REQUIRES_OK_ASYNC(context, - context->allocate_persistent(dtype_, value.shape(), - &unused, &tmp, attr), - done); - *variable->tensor() = *tmp; - } - - XlaDeviceContext* device_context = - static_cast(context->op_device_context()); - - variable->Ref(); - device_context->CopyDeviceTensorToDevice( - value, variable->tensor(), [context, variable, done](Status status) { - variable->Unref(); - context->SetStatus(status); - done(); - }); + *variable->tensor() = value; } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 6967ad1f03fb5dd962d5b41f0c7ab1dfa42fab94..927f983ba9ef23c8509523f42366c0c89c29db9f 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" #include "tensorflow/core/kernels/shape_ops.h" +#include "tensorflow/core/kernels/stack.h" #include "tensorflow/core/kernels/variable_ops.h" namespace tensorflow { @@ -49,10 +50,10 @@ class XlaDeviceDummyOp : public OpKernel { void Compute(OpKernelContext* ctx) override; }; -class XlaAssignVariableOp : public AsyncOpKernel { +class XlaAssignVariableOp : public OpKernel { public: explicit XlaAssignVariableOp(OpKernelConstruction* c); - void ComputeAsync(OpKernelContext* context, DoneCallback done) override; + void Compute(OpKernelContext* context) override; private: DataType dtype_; @@ -65,11 +66,13 @@ class XlaAssignVariableOp : public AsyncOpKernel { .HostMemory("resources"), \ KERNEL); -#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \ - REGISTER_KERNEL_BUILDER(Name("_XlaCompile") \ - .Device(DEVICE) \ - .HostMemory("constants") \ - .HostMemory("resources"), \ +#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("_XlaCompile") \ + .Device(DEVICE) \ + .HostMemory("constants") \ + .HostMemory("key") \ + .HostMemory("compilation_successful") \ + .HostMemory("resources"), \ KERNEL); #define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \ @@ -91,6 +94,9 @@ class XlaAssignVariableOp : public AsyncOpKernel { ConstantOp); \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Identity").Device(DEVICE).TypeConstraint("T", DT_STRING), \ + IdentityOp); \ REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE), IdentityNOp); \ REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \ REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ @@ -197,6 +203,8 @@ class XlaAssignVariableOp : public AsyncOpKernel { .HostMemory("output") \ .TypeConstraint("T"), \ ArgOp); \ + REGISTER_KERNEL_BUILDER( \ + 
Name(kArgOp).Device(DEVICE).TypeConstraint("T"), ArgOp); \ \ REGISTER_KERNEL_BUILDER(Name(kRetOp) \ .Device(DEVICE) \ @@ -208,6 +216,8 @@ class XlaAssignVariableOp : public AsyncOpKernel { .TypeConstraint("T") \ .HostMemory("input"), \ RetvalOp); \ + REGISTER_KERNEL_BUILDER( \ + Name(kDeviceRetOp).Device(DEVICE).TypeConstraint("T"), RetvalOp); \ \ REGISTER_KERNEL_BUILDER( \ Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp); \ @@ -250,9 +260,27 @@ class XlaAssignVariableOp : public AsyncOpKernel { .Device(DEVICE) \ .TypeConstraint("T") \ .HostMemory("input"), \ - RetvalOp); + RetvalOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("StackV2") \ + .Device(DEVICE) \ + .HostMemory("max_size") \ + .HostMemory("handle"), \ + StackOp); \ + REGISTER_KERNEL_BUILDER(Name("StackPushV2") \ + .Device(DEVICE) \ + .HostMemory("handle") \ + .TypeConstraint("T", TYPES), \ + TemplatedStackPushOp); \ + REGISTER_KERNEL_BUILDER(Name("StackPopV2") \ + .Device(DEVICE) \ + .HostMemory("handle") \ + .TypeConstraint("elem_type", TYPES), \ + StackPopOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("StackCloseV2").Device(DEVICE).HostMemory("handle"), StackCloseOp); -// TODO(phawkins): currently we do not register the QueueEnqueueMany, +// TODO(b/118881356): currently we do not register the QueueEnqueueMany, // QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read // and write the tensors they access in order to concatenate them into a batch. // We would need either to call out to an XLA computation to perform the diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 60979556a3245f4a9984cde889835ce31154fe18..0191315a66f4d331e54fadc9dc6a073a05fd67ef 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -16,6 +16,10 @@ limitations under the License. 
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "CUDA" (GPU) backend. +#include +#include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" @@ -28,45 +32,76 @@ namespace tensorflow { class XlaGpuDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector* devices) override; + std::vector>* devices) override; }; -Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, - const string& name_prefix, - std::vector* devices) { +Status XlaGpuDeviceFactory::CreateDevices( + const SessionOptions& session_options, const string& name_prefix, + std::vector>* devices) { XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; + XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration); static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT); (void)registrations; - std::unique_ptr device; - Status status = - XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, - name_prefix, registration, - /*transfer_as_literal=*/false, - /*use_multiple_streams=*/false, - /*shape_representation_fn=*/{}, - /*padded_shape_fn=*/{}, &device); - if (!status.ok()) { + auto platform = se::MultiPlatformManager::PlatformWithName("CUDA"); + if (!platform.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. 
- VLOG(1) << "Failed to create XLA_GPU device: " << status; + VLOG(1) << "Failed to create XLA_GPU device: " << platform.status(); return Status::OK(); } - - // TODO(b/78468222): Uncomment after fixing this bug - // status = device->UseGpuDeviceInfo(); - // if (!status.ok()) { - // errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, - // " device"); - // return status; - // } - - devices->push_back(device.release()); + string allowed_gpus = + session_options.config.gpu_options().visible_device_list(); + std::set gpu_ids; + int num_visible_devices = platform.ValueOrDie()->VisibleDeviceCount(); + if (allowed_gpus.empty()) { + for (int i = 0; i < num_visible_devices; ++i) { + gpu_ids.insert(i); + } + } else { + // For loop below is copied from gpu/gpu_device.cc. It validates + // the visible_device_list and populates gpu_ids set. + const std::vector visible_devices = + absl::StrSplit(allowed_gpus, ','); + for (const string& platform_gpu_id_str : visible_devices) { + int32 platform_gpu_id; + if (!absl::SimpleAtoi(platform_gpu_id_str, &platform_gpu_id)) { + return errors::InvalidArgument( + "Could not parse entry in 'visible_device_list': '", + platform_gpu_id_str, "'. 
visible_device_list = ", allowed_gpus); + } + if (platform_gpu_id < 0 || platform_gpu_id >= num_visible_devices) { + return errors::InvalidArgument( + "'visible_device_list' listed an invalid GPU id '", platform_gpu_id, + "' but visible device count is ", num_visible_devices); + } + gpu_ids.insert(platform_gpu_id); + } + } + for (int i : gpu_ids) { + XlaDevice::Options options; + options.platform = platform.ValueOrDie(); + options.device_name_prefix = name_prefix; + options.device_name = DEVICE_XLA_GPU; + options.device_ordinal = i; + options.compilation_device_name = DEVICE_GPU_XLA_JIT; + options.use_multiple_streams = true; + auto device = absl::make_unique(session_options, options); + + Status status = device->UseGpuDeviceInfo(); + if (!status.ok()) { + errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, + " device number ", i); + return status; + } + + devices->push_back(std::move(device)); + } return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 19e681af0c940023de2ce82b3b337babe2f3dd5a..4007309ed1c57b663dca5bac0df11260bf1327f3 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -15,6 +15,7 @@ limitations under the License. // Registers the XLA_INTERPRETER device which exposes the XLA Interpreter. 
+#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" @@ -25,37 +26,43 @@ namespace tensorflow { const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER"; const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT"; -constexpr std::array kExecAllTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kExecAllTypes = { + {DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_BOOL, DT_BFLOAT16}}; class XlaInterpreterDeviceFactory : public DeviceFactory { public: Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector* devices) override; + std::vector>* devices) override; }; Status XlaInterpreterDeviceFactory::CreateDevices( - const SessionOptions& options, const string& name_prefix, - std::vector* devices) { + const SessionOptions& session_options, const string& name_prefix, + std::vector>* devices) { static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels( DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT); (void)registrations; XlaOpRegistry::DeviceRegistration registration; registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kAlways; registration.compile_resource_ops = true; + XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER, + registration); + + TF_ASSIGN_OR_RETURN( + auto platform, se::MultiPlatformManager::PlatformWithName("Interpreter")); + + XlaDevice::Options options; + options.platform = platform; + options.device_name_prefix = name_prefix; + options.device_name = DEVICE_XLA_INTERPRETER; + options.device_ordinal = 0; + options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; + 
options.use_multiple_streams = false; + devices->push_back(absl::make_unique(session_options, options)); - std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0, - DEVICE_INTERPRETER_XLA_JIT, options, - name_prefix, registration, - /*transfer_as_literal=*/false, - /*use_multiple_streams=*/false, - /*shape_representation_fn=*/{}, - /*padded_shape_fn=*/{}, &device)); - devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 4f6fc4e068e3ba125ddbca264c1affa1f09f5896..3b0bda4caa161a7561a3098b89420329998ff8a7 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -41,22 +42,127 @@ using xla::ScopedShapedBuffer; using xla::ShapedBuffer; } // anonymous namespace -std::map SnapshotResourceVariables( - OpKernelContext* ctx, absl::Span variables) { - std::map snapshot; - for (int i : variables) { - Var* variable = nullptr; - ResourceHandle handle = HandleFromInput(ctx, i); - OptionalTensor& tensor = snapshot[i]; - if (LookupResource(ctx, handle, &variable).ok()) { - core::ScopedUnref scoped_unref(variable); - tf_shared_lock lock(*variable->mu()); - tensor.name = handle.name(); +VariableInfo::VariableInfo(int index, Var* var) : index_(index), var_(var) {} +VariableInfo::VariableInfo(VariableInfo&& other) + : index_(other.index_), var_(other.var_), lock_held_(other.lock_held_) { + other.index_ = -1; + other.var_ = nullptr; +} + +VariableInfo& VariableInfo::operator=(VariableInfo&& other) { + index_ = other.index_; + var_ = other.var_; + lock_held_ = other.lock_held_; + + other.index_ = -1; + other.var_ = nullptr; + + return *this; +} + 
+VariableInfo::~VariableInfo() { + // Release the variable's lock if we hold it. Ensures that the lock is + // released even on error. It does not matter in what order we release the + // locks. + if (var()) { + if (lock_held()) { + var()->mu()->unlock(); + } + + // Unref the variable so it can be released by ResourceManager. + var()->Unref(); + } +} + +// Returns a vector of VaribleInfo instances for the resource variable inputs to +// the kernel with context `ctx`. The input indices for the resource variable +// inputs are in `variable_indices`. +static Status GetVariableInfosFromCtxInputs( + OpKernelContext* ctx, absl::Span variable_indices, + std::vector* result) { + std::vector resource_handles; + absl::c_transform( + variable_indices, std::back_inserter(resource_handles), + [&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); }); + + std::vector> variables; + TF_RETURN_IF_ERROR(LookupResources(ctx, resource_handles, &variables)); + + result->clear(); + result->reserve(variable_indices.size()); + for (int i = 0; i < variable_indices.size(); i++) { + // *Release* the variable because we're going to unref it later in + // ~VariableInfo. + Var* variable = variables[i].release(); + result->emplace_back(variable_indices[i], variable); + } + + return Status::OK(); +} + +Status LockVariables(absl::Span variables) { + std::vector lock_order(variables.size()); + std::iota(lock_order.begin(), lock_order.end(), 0); + + // VariableInfoComparator orders all empty VariableInfo instances as + // equivalent so it looks like we may want to stable sort these to maintain a + // deterministic order between the empty VariableInfo instances. However + // since we're sorting by pointer value the sort is pretty non-deterministic + // anyway so we don't bother using std::stable_sort for now. 
+ absl::c_sort(lock_order, [&](int a, int b) { + if (variables[a].var() && variables[b].var()) { + return variables[a].var()->mu() < variables[b].var()->mu(); + } + + // Move all the empty VariableInfo instances to the end. + return variables[a].var() != nullptr; + }); + + mutex* prev = nullptr; + for (int i : lock_order) { + Var* variable = variables[i].var(); + if (variable == nullptr) { + // All empty VariableInfo instances are at the end of the order + // so we're done. + break; + } + mutex* mu = variable->mu(); + if (prev == mu) { + // It is an error to pass the same variable handle twice to the same XLA + // cluster because we would not handle variable updates correctly. Any + // locks we have already acquired will be released when the VariableInfo + // objects are destroyed. + return errors::Internal("Duplicate variable passed to XLA cluster"); + } + VLOG(4) << "Acquiring lock for variable " + << reinterpret_cast(variable); + mu->lock(); + variables[i].set_lock_held(); + prev = mu; + } + VLOG(4) << "Finished acquiring variable locks."; + return Status::OK(); +} + +Status SnapshotResourceVariables(OpKernelContext* ctx, + absl::Span variable_indices, + std::map* result) { + std::vector variable_infos; + TF_RETURN_IF_ERROR( + GetVariableInfosFromCtxInputs(ctx, variable_indices, &variable_infos)); + TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos))); + + for (int i = 0; i < variable_indices.size(); i++) { + if (variable_infos[i].var()) { + OptionalTensor& tensor = (*result)[variable_indices[i]]; + tensor.name = HandleFromInput(ctx, variable_indices[i]).name(); tensor.present = true; - tensor.value = *variable->tensor(); + tensor.value = *variable_infos[i].var()->tensor(); + } else { + (*result)[variable_indices[i]] = OptionalTensor(); } } - return snapshot; + return Status::OK(); } XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped) @@ -85,40 +191,6 @@ Status XlaAllocator::Deallocate(int device_ordinal, 
se::DeviceMemoryBase mem) { return Status::OK(); } -namespace internal { -// Return the 'index''th subtree of the given ShapedBuffer as a -// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the -// subtree, and sets the input's buffer pointers to nullptr for the subtree. -ScopedShapedBuffer ExtractSubShapedBuffer( - ShapedBuffer* shaped_buffer, int index, - xla::DeviceMemoryAllocator* allocator) { - const xla::Shape& on_host_shape = xla::ShapeUtil::GetTupleElementShape( - shaped_buffer->on_host_shape(), index); - const xla::Shape& on_device_shape = xla::ShapeUtil::GetTupleElementShape( - shaped_buffer->on_device_shape(), index); - - ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape, - shaped_buffer->platform(), - shaped_buffer->device_ordinal()); - - auto& shape_tree = shaped_buffer->buffers(); - auto& sub_shape_tree = sub_shaped_buffer.buffers(); - sub_shape_tree.CopySubtreeFrom(shape_tree, - /*source_base_index=*/{index}, - /*target_base_index=*/{}); - shape_tree.ForEachMutableElement( - [index](const xla::ShapeIndex& shape_index, - tensorflow::se::DeviceMemoryBase* data) { - // shape_index is empty for the root node. Ignore that. 
- if (!shape_index.empty() && shape_index[0] == index) { - *data = tensorflow::se::DeviceMemoryBase(nullptr, 0); - } - }); - return ScopedShapedBuffer(std::move(sub_shaped_buffer), allocator); -} -} // namespace internal -using internal::ExtractSubShapedBuffer; - XlaComputationLaunchContext::XlaComputationLaunchContext( xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors, bool use_multiple_streams) @@ -160,10 +232,7 @@ void XlaComputationLaunchContext::PopulateInputs( CHECK(stream) << "Must have a stream available when using XLA tensors!"; XlaTensor* xla_tensor = XlaTensor::FromTensor(t); CHECK(xla_tensor); - if (se::Event* event = xla_tensor->GetDefinitionEvent(stream)) { - stream->ThenWaitFor(event); - xla_tensor->SetDefinedOn(stream); - } + xla_tensor->WaitForDefinitionEventOnStream(stream); } const xla::Shape on_device_shape = @@ -288,10 +357,9 @@ Status XlaComputationLaunchContext::PopulateOutputs( TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); if (xla_tensor) { - xla_tensor->set_shaped_buffer(ScopedShapedBuffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num})); if (use_multiple_streams_) { - xla_tensor->SetDefinedOn(stream, definition_event); + xla_tensor->ResetDefinitionEvent(definition_event, stream); } } else { // xla_tensor wasn't valid, which must mean this is a zero-element @@ -315,30 +383,35 @@ Status XlaComputationLaunchContext::PopulateOutputs( // Apply variable updates, if any. 
VLOG(2) << "Applying variable updates"; + std::vector variable_infos; + variable_infos.reserve(kernel->resource_updates.size()); + for (int i = 0; i < kernel->resource_updates.size(); ++i) { - Allocator* allocator = ctx->device()->GetAllocator({}); const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; int actual_input_index = write.input_index - missing_ctx_input_prefix; if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) { return errors::Internal("Invalid input index for variable write."); } - se::DeviceMemoryBase buffer = output.buffer({output_num}); - - Var* variable = nullptr; // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, // not a Tensor. + Var* variable = nullptr; TF_RETURN_IF_ERROR(LookupOrCreateResource( ctx, HandleFromInput(ctx, actual_input_index), &variable, [&write](Var** ptr) { *ptr = new Var(write.type); return Status::OK(); })); + variable_infos.emplace_back(actual_input_index, variable); + } - core::ScopedUnref s(variable); + TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos))); - mutex_lock ml(*variable->mu()); - if (variable->tensor()->dtype() != write.type) { + for (int i = 0; i < kernel->resource_updates.size(); ++i) { + Allocator* allocator = ctx->device()->GetAllocator({}); + const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; + + if (variable_infos[i].var()->tensor()->dtype() != write.type) { return errors::Internal("Mismatched type in variable write"); } @@ -346,23 +419,81 @@ Status XlaComputationLaunchContext::PopulateOutputs( Tensor output_tensor; TF_RETURN_IF_ERROR( ctx->allocate_temp(write.type, write.shape, &output_tensor)); - XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); - CHECK(xla_tensor); - xla_tensor->set_shaped_buffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_)); - if (use_multiple_streams_) { - xla_tensor->SetDefinedOn(stream, definition_event); + if (write.shape.num_elements() > 0) { + XlaTensor* 
xla_tensor = XlaTensor::FromTensor(&output_tensor); + CHECK(xla_tensor); + xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num})); + if (use_multiple_streams_) { + xla_tensor->ResetDefinitionEvent(definition_event, stream); + } } - *variable->tensor() = output_tensor; + *variable_infos[i].var()->tensor() = output_tensor; } else { + se::DeviceMemoryBase buffer = output.buffer({output_num}); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); Tensor output_tensor = XlaTensorBuffer::MakeTensor( write.type, write.shape, buffer, allocator); - output.set_buffer(xla::OwningDeviceMemory(), {output_num}); - *variable->tensor() = output_tensor; + *variable_infos[i].var()->tensor() = output_tensor; } ++output_num; } return Status::OK(); } +Status XlaComputationLaunchContext::BuildXlaCompilerArguments( + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + std::vector* args) { + args->resize(ctx->num_inputs()); + + for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { + XlaCompiler::Argument& arg = (*args)[input_num]; + if (constant_args.count(input_num) > 0) { + // Handles compile-time constants. + const Tensor& input = constant_args.at(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); + arg.kind = XlaCompiler::Argument::kConstant; + arg.type = input.dtype(); + arg.shape = input.shape(); + arg.constant_value = input; + } else if (variable_args.count(input_num) == 0) { + // Handles the non-constant arguments. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); + if (input.NumElements() > 0) { + arg.kind = XlaCompiler::Argument::kParameter; + } else { + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = input; + } + arg.type = input.dtype(); + arg.shape = input.shape(); + } else { + // Handles resource variables. 
+ const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() == DT_RESOURCE); + const OptionalTensor& variable = variable_args.at(input_num); + arg.name = variable.name; + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = XlaResource::kVariable; + if (variable.present) { + const Tensor& value = variable.value; + arg.type = value.dtype(); + arg.shape = value.shape(); + arg.initialized = true; + } else { + // The values of uninitialized variables are not passed as inputs, since + // they are meaningless. However, it is legal to assign to a resource + // variable for the first time inside the XLA computation, so we do + // permit uninitialized variables. + arg.initialized = false; + arg.type = DT_INVALID; + arg.shape = TensorShape(); + } + } + } + + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 326d70a027564343408df356833c97e131495da0..554227f09de0ab4d9e07f199b957657f3121ff06 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ #define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ +#include "absl/base/thread_annotations.h" #include "tensorflow/compiler/jit/xla_compilation_cache.h" #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -34,17 +35,75 @@ limitations under the License. namespace tensorflow { class XlaAllocator; -// Takes a snapshot of the values of resource variable arguments, whose -// indices are specified in `variables` argument. We snapshot tensors that back +// Struct that represents a possibly-absent Tensor. +struct OptionalTensor { + string name; // A descriptive name + bool present = false; // Is the tensor present? + Tensor value; // If present, what is the Tensor's value? 
+}; + +// Takes a snapshot of the values of resource variable arguments, whose indices +// are specified in `variable_indices` argument. We snapshot tensors that back // resource variables since concurrent updates may modify the shape, and it is // important that the shapes used for compilation match the true shapes of the // buffers. // +// We snapshot the entire set of resource variables as one atomic operation. +// This models Read->* dependencies between resource variable operations. See +// jit/resource_operation_safety_analysis for details. +// // Returns a map of TensorFlow argument index to resource variable. If a // resource variable is not initialized, the corresponding OptionalTensor // will have its `present` field set to false. -std::map SnapshotResourceVariables( - OpKernelContext* ctx, absl::Span variables); +Status SnapshotResourceVariables(OpKernelContext* ctx, + absl::Span variable_indices, + std::map* result); + +// Information about the state of a variable passed as input to the _XlaCompile +// and _XlaRun operators. Unlocks the resource variable and decrements its +// refcount on destruction. +class VariableInfo { + public: + explicit VariableInfo(int index, Var* var); + VariableInfo(VariableInfo&& other); + + VariableInfo& operator=(VariableInfo&& other); + + VariableInfo(const VariableInfo&) = delete; + VariableInfo& operator=(const VariableInfo&) = delete; + + // The index of the DT_RESOURCE input to the _XlaCompile/_XlaRun operator. + // Note that the indices can be different between _XlaCompile and _XlaRun. + int index() const { return index_; } + + // A pointer to the resource variable. May be null if this VariableInfo is + // "empty", i.e. it does not track a resource variable. + Var* var() const { return var_; } + + // Returns true if the resource variable lock was successfully acquired by + // this thread. 
+ bool lock_held() const { return lock_held_; } + void set_lock_held() { lock_held_ = true; } + + ~VariableInfo(); + + private: + int index_; + Var* var_; + + // We can't use a optional here because it confuses the compiler's + // thread safety analysis. Instead we use a boolean flag and release the lock + // in the VariableInfo destructor. + bool lock_held_ = false; +}; + +// Acquires the mutexes for all the variables in `variables` using a +// deadlock-safe protocol (acquire the mutexes in increasing-address order). +// +// `variables` is allowed to contain instances that don't track a resource +// variable (i.e. variables[i].var() can be null for some i). +Status LockVariables(absl::Span variables) + EXCLUSIVE_LOCK_FUNCTION(); // Adapter class that wraps a Tensorflow allocator as an XLA allocator. // Assumes that the Tensorflow allocator permits asynchronous deallocation: @@ -87,6 +146,13 @@ class XlaComputationLaunchContext { bool allocate_xla_tensors, bool use_multiple_streams); + // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch + // op. + static Status BuildXlaCompilerArguments( + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + std::vector* args); + // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. // @@ -99,7 +165,13 @@ class XlaComputationLaunchContext { const std::map& variables, int missing_ctx_input_prefix); - // Given the XLA output in `output`, populate all outputs of `ctx`. + // Given the XLA output in `output`, populate all outputs of `ctx`. Also + // writes out the resource variable updates. + // + // Updates to all resource variables are written in a single atomic operation. + // This models *->Write dependencies between resource variable operations. + // See jit/resource_operation_safety_analysis for details. 
+ // // // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are // missing and adjusts input indices accordingly. @@ -127,19 +199,17 @@ class XlaTensorBuffer : public TensorBuffer { public: XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size, Allocator* allocator) - : expected_size_(expected_size), + : TensorBuffer(const_cast(ptr)), + expected_size_(expected_size), actual_size_(actual_size), - allocator_(allocator) { - data_ = const_cast(ptr); - } + allocator_(allocator) {} ~XlaTensorBuffer() override { - if (data_) { - allocator_->DeallocateRaw(data_); + if (data()) { + allocator_->DeallocateRaw(data()); } } - void* data() const override { return data_; } size_t size() const override { return expected_size_; } TensorBuffer* root_buffer() override { return this; } @@ -159,23 +229,11 @@ class XlaTensorBuffer : public TensorBuffer { } private: - void* data_; size_t expected_size_; size_t actual_size_; Allocator* allocator_; }; -// Exposed in this header file for microbenchmarking purposes, but this is an -// internal implementation detail. -namespace internal { -// Return the 'index''th subtree of the given ShapedBuffer as a -// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the -// subtree, and sets the input's buffer pointers to nullptr for the subtree. -xla::ScopedShapedBuffer ExtractSubShapedBuffer( - xla::ShapedBuffer* shaped_buffer, int index, - xla::DeviceMemoryAllocator* allocator); -} // namespace internal - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc deleted file mode 100644 index a45932403ec1760d6b985d5357fd6d84fbf257a2..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/xla_launch_util_test.cc +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Contains microbenchmarks for performance critical functions in -// xla_launch_util.cc. - -#include "tensorflow/compiler/jit/xla_launch_util.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" - -// Test ExtractSubBuffer with different depths (depth of ShapeTree) and fan-outs -// (cardinality of each non-leaf node's children). -void BM_ExtractSubBuffer(int iters, int depth, int fan_out) { - tensorflow::testing::StopTiming(); - xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128}); - for (int i = 0; i < depth; ++i) { - std::vector shapes(fan_out, shape); - shape = xla::ShapeUtil::MakeTupleShape(shapes); - } - xla::ShapedBuffer shaped_buffer(shape, shape, /*platform=*/nullptr, - /*device_ordinal=*/0); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - // Extract a buffer from approximately the middle of the first level of the - // tree. 
- (void)tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer, - /*index=*/fan_out / 2, - /*allocator=*/nullptr) - .release(); - } -} - -BENCHMARK(BM_ExtractSubBuffer) - ->ArgPair(1, 4) - ->ArgPair(1, 8) - ->ArgPair(1, 32) - ->ArgPair(1, 64) - ->ArgPair(1, 128) - ->ArgPair(1, 256) - ->ArgPair(1, 512) - ->ArgPair(2, 4) - ->ArgPair(2, 8) - ->ArgPair(2, 32) - ->ArgPair(2, 64) - ->ArgPair(2, 128); - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - tensorflow::testing::RunBenchmarks(); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index 92ba7de1b7d32fcf693cd12a380d7a1e0d861d71..d1f7f754c8338487557eda512c56be34c9e958b7 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -43,11 +43,10 @@ namespace tensorflow { } } -Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, +Status XlaTensor::AllocateShapedBuffer(DataType dtype, + const xla::Shape& on_host_shape, xla::LocalClient* client, int device_ordinal) { - xla::Shape on_host_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &on_host_shape)); xla::Shape on_device_shape = client->backend().transfer_manager()->HostShapeToDeviceShape( on_host_shape); @@ -73,10 +72,10 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, return Status::OK(); } -se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) { +void XlaTensor::WaitForDefinitionEventOnStream(se::Stream* stream) { mutex_lock lock(mu_); if (!definition_event_) { - return nullptr; + return; } // The set of defined streams is expected to be very small indeed (usually @@ -84,24 +83,20 @@ se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) { if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(), stream) != streams_defined_on_.end()) { // stream is in streams_defined_on_; it doesn't need to be waited on. 
- return nullptr; + return; } - return definition_event_.get(); + stream->ThenWaitFor(definition_event_.get()); + streams_defined_on_.push_back(stream); } -void XlaTensor::SetDefinedOn(se::Stream* stream, - std::shared_ptr event) { +void XlaTensor::ResetDefinitionEvent(std::shared_ptr event, + se::Stream* stream) { mutex_lock lock(mu_); definition_event_ = std::move(event); streams_defined_on_ = {stream}; } -void XlaTensor::SetDefinedOn(se::Stream* stream) { - mutex_lock lock(mu_); - streams_defined_on_.push_back(stream); -} - // The pointer tag, OR-ed into the XlaTensor's address to distinguish it from // device-side tensors, which are either CPU or GPU memory pointers. This works // because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits. diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index d95da63405889dfd0c279b17789a2195072c7277..77e80aa2527ecc2221ac61f7b7e6ebcce0982931 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -50,7 +50,7 @@ class XlaTensor { // Assign the internal ShapedBuffer to new memory for the given dtype and // shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it // is replaced and the managed memory deallocated. - Status AllocateShapedBuffer(DataType dtype, const TensorShape& shape, + Status AllocateShapedBuffer(DataType dtype, const xla::Shape& on_host_shape, xla::LocalClient* client, int device_ordinal); // Some Tensors can have complex on-device shapes, including tuple shapes. To @@ -88,23 +88,19 @@ class XlaTensor { host_tensor_.reset(new Tensor(tensor)); } - // If the tensor's content is not yet defined on 'stream', and there exists an - // se::Event declaring when the tensor's content is defined, return it. - // Otherwise, return nullptr. If this function returns nullptr then the - // tensor's content can be read on 'stream' without additional - // synchronization. 
- se::Event* GetDefinitionEvent(se::Stream* stream); - - // Assert that the tensor's content is defined on 'stream' by the time 'event' - // triggers. - void SetDefinedOn(se::Stream* stream, std::shared_ptr event); - - // Assert that the tensor's content is defined on 'stream'. This version does - // not provide an event, and must be called *after* SetDefinedOn(Stream, - // Event). This call can be read as an assertion that the definition event has - // been waited on by 'stream', so further calls to GetDefinitionEvent(stream) - // do not need to also wait on the event. - void SetDefinedOn(se::Stream* stream); + // Adds synchronization events to 'stream' that wait for this tensor to be + // defined on 'stream'. Does nothing if the tensor is already defined on that + // stream. + void WaitForDefinitionEventOnStream(se::Stream* stream); + + // (Re)sets the definition event of the tensor to 'event', and promises that + // the tensor has already been defined on stream. Removes any previous + // definition event or any previous promises about the tensor being defined on + // streams. + // It is legal to reset the definition event of a tensor when overwriting the + // tensor's value (at which point, it is effectively a new tensor once again.) + void ResetDefinitionEvent(std::shared_ptr event, + se::Stream* stream); // Convert from a raw pointer to an XlaTensor, removing the pointer tag. static XlaTensor* FromOpaquePointer(void* ptr); diff --git a/tensorflow/compiler/plugin/README.md b/tensorflow/compiler/plugin/README.md index 9dd0d2bdab5e2c990fd547cef4b657253c545715..07465934aec0364eb03ddfb7f99ea54aaf084fff 100644 --- a/tensorflow/compiler/plugin/README.md +++ b/tensorflow/compiler/plugin/README.md @@ -1,5 +1,4 @@ -3rd party XLA devices ---------------------- +## 3rd party XLA devices This directory is intended as a place for 3rd party XLA devices which are _not_ integrated into the public repository. @@ -9,8 +8,5 @@ can be included as a dependency of the JIT subsystem. 
For integration into the unit test system, see the files: -- tensorflow/compiler/tests/plugin.bzl -- tensorflow/compiler/xla/tests/plugin.bzl - - -- +- tensorflow/compiler/tests/plugin.bzl +- tensorflow/compiler/xla/tests/plugin.bzl diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index ba2401ed2628beeba2be3bf59a067c3d87ca3f9f..093b61629cd0b04d5d8488139b8d7262b739f86d 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -294,33 +294,6 @@ tf_xla_py_test( ], ) -tf_xla_py_test( - name = "oom_test", - size = "medium", - srcs = ["oom_test.py"], - # TODO(b/80081500): Re-enable on GPU. Disabled on 2018-05-21. - disabled_backends = [ - "cpu", - "cpu_ondemand", - "gpu", - ], - tags = [ - # Allocates very large amounts of memory and does not work under TSAN. - "notsan", - "optonly", # Times out frequently in fastbuild. - ], - deps = [ - ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework", - "//tensorflow/python:gradient_checker", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - tf_xla_py_test( name = "conv2d_test", size = "medium", @@ -435,13 +408,6 @@ tf_xla_py_test( name = "eager_test", size = "large", srcs = ["eager_test.py"], - disabled_backends = [ - # TODO(b/78199195) Support XLA CPU devices in eager runtime - "cpu", - "cpu_ondemand", - # TODO(b/78468222) Enable GPU backend - "gpu", - ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -476,12 +442,11 @@ tf_xla_py_test( tags = ["optonly"], deps = [ ":xla_test", - "//tensorflow/contrib/signal:signal_py", "//tensorflow/python:array_ops", "//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:framework", "//tensorflow/python:platform_test", - "//tensorflow/python:spectral_ops", + "//tensorflow/python/ops/signal", ], ) @@ -516,8 +481,6 @@ tf_xla_py_test( name = "function_test", size = "small", srcs = 
["function_test.py"], - # Functions are not implemented in the on-demand compilation model yet. - disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -707,9 +670,6 @@ tf_xla_py_test( name = "random_ops_test", size = "small", srcs = ["random_ops_test.py"], - disabled_backends = [ - "cpu_ondemand", - ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -740,7 +700,6 @@ tf_xla_py_test( name = "reduce_window_test", size = "small", srcs = ["reduce_window_test.py"], - disabled_backends = ["cpu_ondemand"], deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", @@ -849,8 +808,6 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], - # Stack ops are not implemented in the on-demand compilation model yet. - disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -867,9 +824,9 @@ tf_xla_py_test( tags = ["optonly"], deps = [ ":xla_test", - "//tensorflow/contrib/stateless", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "//tensorflow/python:stateless_random_ops", ], ) @@ -878,7 +835,7 @@ tf_xla_py_test( size = "small", srcs = ["tensor_array_ops_test.py"], # TensorArray ops are not implemented in the on-demand compilation model yet. - disabled_backends = "cpu_ondemand", + disabled_backends = ["cpu_ondemand"], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -899,7 +856,7 @@ tf_xla_py_test( size = "small", srcs = ["tensor_list_ops_test.py"], # TensorList ops are not implemented in the on-demand compilation model yet. 
- disabled_backends = "cpu_ondemand", + disabled_backends = ["cpu_ondemand"], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -979,7 +936,6 @@ tf_xla_py_test( name = "while_test", size = "small", srcs = ["while_test.py"], - disabled_backends = ["cpu_ondemand"], deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", @@ -1089,6 +1045,7 @@ cuda_py_test( size = "medium", srcs = ["jit_test.py"], additional_deps = [ + ":test_utils", "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -1107,6 +1064,7 @@ cuda_py_test( size = "small", srcs = ["dense_layer_test.py"], additional_deps = [ + ":test_utils", "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -1134,6 +1092,7 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], @@ -1244,7 +1203,6 @@ tf_xla_py_test( name = "xla_ops_test", size = "medium", srcs = ["xla_ops_test.py"], - disabled_backends = ["cpu_ondemand"], deps = [ ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py index 69fb3ec2964a09508e612515b9e291fc14121d68..e9c2d363acab96c0fb968cb7f901ce105ea8703e 100644 --- a/tensorflow/compiler/tests/adagrad_da_test.py +++ b/tensorflow/compiler/tests/adagrad_da_test.py @@ -50,8 +50,8 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() @@ 
-63,9 +63,9 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): # For -0.1*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534 # similarly for others. self.assertAllCloseAccordingToType( - np.array([-0.904534, -1.603567]), var0.eval()) + np.array([-0.904534, -1.603567]), self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([-0.094821, -0.189358]), var1.eval()) + np.array([-0.094821, -0.189358]), self.evaluate(var1)) def testAdagradDAwithoutRegularizationBasic2(self): for dtype in self.float_types: @@ -87,16 +87,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() self.assertAllCloseAccordingToType( - np.array([-0.904534, -1.603567]), var0.eval()) + np.array([-0.904534, -1.603567]), self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([-0.094821, -0.189358]), var1.eval()) + np.array([-0.094821, -0.189358]), self.evaluate(var1)) def testAdagradDAWithL1(self): for dtype in self.float_types: @@ -118,16 +118,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() self.assertAllCloseAccordingToType( - np.array([-0.895489, -1.59555]), var0.eval()) + np.array([-0.895489, -1.59555]), self.evaluate(var0)) 
self.assertAllCloseAccordingToType( - np.array([-0.085339, -0.17989]), var1.eval()) + np.array([-0.085339, -0.17989]), self.evaluate(var1)) def testAdagradDAWithL1_L2(self): for dtype in self.float_types: @@ -149,16 +149,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1]), global_step=global_step) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run a step of AdagradDA update.run() self.assertAllCloseAccordingToType( - np.array([-0.046907, -0.093659]), var0.eval()) + np.array([-0.046907, -0.093659]), self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([-0.004275, -0.009023]), var1.eval()) + np.array([-0.004275, -0.009023]), self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index ab69319c59fb07e7ce56c3c287a50a6290effdfd..e26483303c3934fd51675cb1fbc998b276caf527 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -42,17 +42,19 @@ class AdagradOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of adagrad for _ in range(3): ada_update.run() # Validate updated params self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + np.array([-1.6026098728179932, -0.6026098728179932]), + self.evaluate(var0), float_rtol=1e-5) 
self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + np.array([2.715679168701172, 3.715679168701172]), + self.evaluate(var1), float_rtol=1e-5) def testTensorLearningRate(self): @@ -68,17 +70,19 @@ class AdagradOptimizerTest(xla_test.XLATestCase): zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of adagrad for _ in range(3): ada_update.run() # Validate updated params self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + np.array([-1.6026098728179932, -0.6026098728179932]), + self.evaluate(var0), float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + np.array([2.715679168701172, 3.715679168701172]), + self.evaluate(var1), float_rtol=1e-5) def testSharing(self): @@ -103,18 +107,20 @@ class AdagradOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values. - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Mix the first and the second adagrad for 3 steps. ada_update1.run() ada_update2.run() ada_update1.run() # Validate updated params (the same as with only 1 Adagrad). 
self.assertAllCloseAccordingToType( - np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), + np.array([-1.6026098728179932, -0.6026098728179932]), + self.evaluate(var0), float_rtol=1e-5) self.assertAllCloseAccordingToType( - np.array([2.715679168701172, 3.715679168701172]), var1.eval(), + np.array([2.715679168701172, 3.715679168701172]), + self.evaluate(var1), float_rtol=1e-5) diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 058576b3d4b695209952158769162bb24e7ccfce..8bcff9d379d34f8a6bb8b0fdc60b7588c6d80be9 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -75,23 +75,24 @@ class AdamOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power, beta2_power = opt._get_beta_accumulators() # Run 3 steps of Adam for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testTensorLearningRate(self): for dtype in self.float_types: @@ -117,23 +118,24 @@ class 
AdamOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power, beta2_power = opt._get_beta_accumulators() # Run 3 steps of Adam for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testSharing(self): for dtype in self.float_types: @@ -162,13 +164,14 @@ class AdamOptimizerTest(xla_test.XLATestCase): beta1_power, beta2_power = opt._get_beta_accumulators() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of intertwined Adam1 and Adam2. 
for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) - self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**t, + self.evaluate(beta2_power)) if t % 2 == 0: update1.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) else: @@ -178,8 +181,8 @@ class AdamOptimizerTest(xla_test.XLATestCase): var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py index 3ed1d41b7121f44dd7470f61180f7a7055369174..961b46375c941bdc3922e460a2f58345086dbceb 100644 --- a/tensorflow/compiler/tests/adamax_test.py +++ b/tensorflow/compiler/tests/adamax_test.py @@ -78,8 +78,8 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power = opt._get_beta_accumulators() @@ -87,14 +87,17 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): for t in range(1, 4): update.run() - self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval()) + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta1_power)) var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval(), 
rtol=1e-2) - self.assertAllCloseAccordingToType(var1_np, var1.eval(), rtol=1e-2) + self.assertAllCloseAccordingToType( + var0_np, self.evaluate(var0), rtol=1e-2) + self.assertAllCloseAccordingToType( + var1_np, self.evaluate(var1), rtol=1e-2) self.assertEqual("var0_%d/AdaMax:0" % (i,), opt.get_slot(var=var0, name="m").name) @@ -118,22 +121,23 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) beta1_power = opt._get_beta_accumulators() # Run 3 steps of AdaMax for t in range(1, 4): - self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power)) update.run() var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0) var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py index 1bc07ace23ccdc83103abe71ee11b72994c75a6d..a37c97e6d374440aeb860b9d02f2d5dd95c91f62 100644 --- a/tensorflow/compiler/tests/addsign_test.py +++ b/tensorflow/compiler/tests/addsign_test.py @@ -90,8 +90,8 @@ class AddSignTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + 
self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 7 steps of AddSign # first 4 steps with positive gradient @@ -125,8 +125,8 @@ class AddSignTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - var0_np, var0.eval(), half_rtol=1e-2) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + var0_np, self.evaluate(var0), half_rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testDense(self): decay_steps = 10 diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 1b39d53dc0908e1fa05f766ca1e601731b26846d..9a5423c1b2a5df7880453cbb328f6a8174066255 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import itertools + import numpy as np from tensorflow.compiler.tests import xla_test @@ -178,6 +180,13 @@ class BinaryOpsTest(xla_test.XLATestCase): [0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype), expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype)) + self._testBinary( + gen_nn_ops.leaky_relu_grad, + np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), + np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype), + expected=np.array([0.2, 0.4, 0.6, 0.8, 1, 6, 7, 8, 9, 10], + dtype=dtype)) + self._testBinary( gen_nn_ops.softmax_cross_entropy_with_logits, np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype), @@ -209,6 +218,21 @@ class BinaryOpsTest(xla_test.XLATestCase): ], equality_test=self.ListsAreClose) + # TF doesn't define these for bf16. 
+ if dtype != dtypes.bfloat16.as_numpy_dtype: + self._testBinary( + gen_math_ops.xdivy, + np.array([0, 4, 3, 2, 1, 0], dtype=dtype), + np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype), + expected=np.array([0, 0.8, 0.5, 0.285714, 0.125, 0], dtype=dtype)) + + self._testBinary( + gen_math_ops.xlogy, + np.array([0, 4, 3, 2, 1, 0], dtype=dtype), + np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype), + expected=np.array([0, 6.437752, 5.375278, 3.89182, 2.079442, 0], + dtype=dtype)) + def testIntOps(self): for dtype in self.signed_int_types: self._testBinary( @@ -960,7 +984,7 @@ class BinaryOpsTest(xla_test.XLATestCase): self._testBinary( array_ops.expand_dims, np.array([42], dtype=dtype), - np.int32(0), + np.array([0], dtype=np.int64), expected=np.array([[42]], dtype=dtype)) self._testBinary( array_ops.expand_dims, @@ -987,15 +1011,21 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([[[1, 2], [3, 4]]], dtype=dtype), np.int32(3), expected=np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype)) + self._testBinary( + array_ops.expand_dims, + np.array([[[1, 2], [3, 4]]], dtype=dtype), + np.array([2], dtype=np.int64), + expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype)) def testPad(self): - for dtype in self.numeric_types: + for dtype, pad_type in itertools.product( + self.numeric_types, [np.int32, np.int64]): self._testBinary( array_ops.pad, np.array( [[1, 2, 3], [4, 5, 6]], dtype=dtype), np.array( - [[1, 2], [2, 1]], dtype=np.int32), + [[1, 2], [2, 1]], dtype=pad_type), expected=np.array( [[0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 0], @@ -1009,7 +1039,7 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array( [[1, 2, 3], [4, 5, 6]], dtype=dtype), np.array( - [[0, 3], [2, 1]], dtype=np.int32), + [[0, 3], [2, 1]], dtype=pad_type), expected=np.array( [[7, 7, 1, 2, 3, 7], [7, 7, 4, 5, 6, 7], diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 1d3979b21bfd915a641fabe1ef40301b3e5a17b4..447a7de2cb6526a5dcf7789d4f2bffb5e733e8c0 
100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -50,6 +50,8 @@ def tf_xla_py_test( """ if disabled_backends == None: disabled_backends = [] + if type(disabled_backends) != "list": + fail("disabled_backends must be a list of strings", "disabled_backends") enabled_backends = [b for b in all_backends() if b not in disabled_backends] test_names = [] diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index a57d1dc81ea2c9c188b0a3005904738aa8156bf3..5d5e486f616937601214aa169a4c329ab78932c8 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import stateless_random_ops from tensorflow.python.platform import googletest @@ -56,11 +57,11 @@ class CategoricalTest(xla_test.XLATestCase): Returns: Frequencies from sampled classes; shape [batch_size, num_classes]. """ - with self.cached_session() as sess, self.test_scope(): + with self.cached_session(), self.test_scope(): random_seed.set_random_seed(1618) op = random_ops.multinomial(logits, num_samples, output_dtype=dtypes.int32) - d = sess.run(op) + d = self.evaluate(op) batch_size, num_classes = logits.shape freqs_mat = [] @@ -79,15 +80,15 @@ class CategoricalTest(xla_test.XLATestCase): def _testRngIsNotConstant(self, rng, dtype, output_dtype): # Tests that 'rng' does not always return the same value. - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): x = rng(dtype, output_dtype) # The random-number generator, if working correctly, should produce the # same output multiple times with low probability. 
- y = sess.run(x) - z = sess.run(x) - w = sess.run(x) + y = self.evaluate(x) + z = self.evaluate(x) + w = self.evaluate(x) # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. @@ -107,12 +108,12 @@ class CategoricalTest(xla_test.XLATestCase): def testCategoricalIsInRange(self): for dtype in self.float_types: for output_dtype in self.output_dtypes(): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): x = random_ops.multinomial( array_ops.ones(shape=[1, 20], dtype=dtype), 1000, output_dtype=output_dtype) - y = sess.run(x) + y = self.evaluate(x) self.assertTrue((y >= 0).sum() == 1000) self.assertTrue((y < 20).sum() == 1000) @@ -138,6 +139,57 @@ class CategoricalTest(xla_test.XLATestCase): chi2 = self._chi2(probs, freqs) self.assertLess(chi2, 1e-3) + def testStatelessMultinomialIsInRange(self): + for dtype in self.float_types: + for output_dtype in self.output_dtypes(): + with self.cached_session() as sess: + with self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless_random_ops.stateless_multinomial( + array_ops.ones(shape=[1, 20], dtype=dtype), + 1000, + seed_t, + output_dtype=output_dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertTrue((y >= 0).sum() == 1000) + self.assertTrue((y < 20).sum() == 1000) + + def testDeterminismMultinomial(self): + # Stateless values should be equal iff the seeds are equal (roughly) + num_samples = 10 + with self.cached_session(), self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + seeds = [(x, y) for x in range(5) for y in range(5)] * 3 + for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2], + [0.25, 0.75]]): + pure = stateless_random_ops.stateless_multinomial( + logits, num_samples, seed=seed_t) + values = [(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds] + for s0, v0 in values: + for s1, v1 
in values: + self.assertEqual(s0 == s1, np.all(v0 == v1)) + + def testEmpty(self): + with self.cached_session(): + with self.test_scope(): + x = random_ops.multinomial( + array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32) + y = self.evaluate(x) + self.assertEqual(y.shape, (42, 0)) + + def testEmptyStateless(self): + with self.cached_session() as sess: + with self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless_random_ops.stateless_multinomial( + array_ops.zeros([42, 40]), + 0, + seed=seed_t, + output_dtype=dtypes.int32) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertEqual(y.shape, (42, 0)) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py index 88bd58b2da6b2892f898ad10f3467d8ce39d6388..ef2d7af69deeebd5f4c4c7225d7027f8f76bf861 100644 --- a/tensorflow/compiler/tests/clustering_test.py +++ b/tensorflow/compiler/tests/clustering_test.py @@ -43,7 +43,7 @@ class ClusteringTest(xla_test.XLATestCase): input1 = constant_op.constant(val1, name="const1") input2 = constant_op.constant(val2, name="const2") output = math_ops.add(input1, input2) - result = output.eval() + result = self.evaluate(output) self.assertAllClose(result, expected, rtol=1e-3) def testAddFromCpuMultiple(self): @@ -57,7 +57,7 @@ class ClusteringTest(xla_test.XLATestCase): with self.test_scope(): output = math_ops.add(input1, input2) for _ in xrange(10): - result = output.eval() + result = self.evaluate(output) self.assertAllClose(result, expected, rtol=1e-3) def testDeadlock(self): diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 2d225ad226cac368042b95eae8fc29e6fd8e82e0..2187f57960f80300d631bdc7eb8fe5e9c8dddeea 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -72,7 +72,7 @@ class ConcatTest(xla_test.XLATestCase): x2 = 
constant_op.constant(p2) with self.test_scope(): c = array_ops.concat([x1, x2], 0) - result = c.eval() + result = self.evaluate(c) self.assertAllEqual(result[:2, :], p1) self.assertAllEqual(result[2:, :], p2) @@ -150,7 +150,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, 1) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) def testGradientsSimpleAll(self): @@ -177,7 +177,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, 0) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) @@ -205,7 +205,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, 2) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) @@ -242,7 +242,7 @@ class ConcatTest(xla_test.XLATestCase): [float(x) for x in grad_inp.flatten()], shape=output_shape) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) concated_grad = array_ops.concat(grad, concat_dim) - result = concated_grad.eval() + result = self.evaluate(concated_grad) self.assertAllEqual(result, grad_inp) @@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase): def DISABLED_testZeroSize(self): # Verify that concat doesn't crash and burn for zero size inputs np.random.seed(7) - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): for shape0 in (), (2,): axis = len(shape0) @@ -270,7 +270,7 @@ class ConcatTest(xla_test.XLATestCase): 
self.assertAllEqual(c.eval(), correct) # Check gradients dc = np.random.randn(*c.get_shape().as_list()) - dxs = sess.run(gradients_impl.gradients(c, xs, dc)) + dxs = self.evaluate(gradients_impl.gradients(c, xs, dc)) self.assertAllEqual(dc, np.concatenate(dxs, axis=axis)) def testConcatTuple(self): @@ -280,7 +280,7 @@ class ConcatTest(xla_test.XLATestCase): with self.test_scope(): concat_list_t = array_ops.concat([c1, c2], 0) concat_tuple_t = array_ops.concat((c1, c2), 0) - self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval()) + self.assertAllEqual(concat_list_t.eval(), self.evaluate(concat_tuple_t)) def testConcatNoScalars(self): with self.cached_session(): @@ -330,47 +330,47 @@ class ConcatTest(xla_test.XLATestCase): class ConcatOffsetTest(xla_test.XLATestCase): def testBasic(self): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): cdim = constant_op.constant(1, dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) s2 = constant_op.constant([2, 20, 5], dtypes.int32) off = gen_array_ops.concat_offset(cdim, [s0, s1, s2]) - ans = sess.run(off) + ans = self.evaluate(off) self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) class PackTest(xla_test.XLATestCase): def testBasic(self): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) s2 = constant_op.constant([2, 20, 5], dtypes.int32) packed = array_ops.stack([s0, s1, s2]) - ans = sess.run(packed) + ans = self.evaluate(packed) self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]]) def testScalars(self): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): s0 = constant_op.constant(2, dtypes.int32) s1 = constant_op.constant(3, dtypes.int32) s2 = constant_op.constant(5, dtypes.int32) packed = array_ops.stack([s0, s1, 
s2]) - ans = sess.run(packed) + ans = self.evaluate(packed) self.assertAllEqual(ans, [2, 3, 5]) def testEmpty(self): - with self.cached_session() as sess: + with self.cached_session(): with self.test_scope(): s0 = constant_op.constant([[]], dtypes.int32) s1 = constant_op.constant([[]], dtypes.int32) s2 = constant_op.constant([[]], dtypes.int32) packed = array_ops.stack([s0, s1, s2]) - ans = sess.run(packed) + ans = self.evaluate(packed) self.assertAllEqual(ans, [[[]], [[]], [[]]]) diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py index 33fd983b5485e503c2fcc96db2dfdecfc41e309f..01cc1b6392845be2418c50d55be97487eb290843 100644 --- a/tensorflow/compiler/tests/conv3d_test.py +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -85,7 +85,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( x, f, y_shape, strides=strides, padding="SAME") - value = output.eval() + value = self.evaluate(output) # We count the number of cells being added at the locations in the output. 
# At the center, #cells = kernel_depth * kernel_height * kernel_width @@ -135,7 +135,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( x, f, y_shape, strides=strides, padding="SAME") - value = output.eval() + value = self.evaluate(output) for n in xrange(x_shape[0]): for k in xrange(f_shape[3]): @@ -173,7 +173,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( x, f, y_shape, strides=strides, padding="VALID") - value = output.eval() + value = self.evaluate(output) cache_values = np.zeros(y_shape, dtype=np.float32) @@ -225,7 +225,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): err = gradient_checker.compute_gradient_error([x, f], [x_shape, f_shape], output, y_shape) print("conv3d_transpose gradient err = %g " % err) - err_tolerance = 0.0005 + err_tolerance = 0.001 self.assertLess(err, err_tolerance) diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index 9390870e07d6b5bd90dbc5c04bac0946595dcf7f..bf5ea7b1fb6fb3c774c4db20d059f131990d20d3 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os import numpy as np +from tensorflow.compiler.tests import test_utils from tensorflow.contrib.compiler import jit from tensorflow.core.protobuf import config_pb2 from tensorflow.python.layers import layers @@ -30,7 +31,6 @@ from tensorflow.python.platform import test jit_scope = jit.experimental_jit_scope - def GetRunMetadataLabels(run_metadata): """Returns all labels in run_metadata.""" labels = [] @@ -42,7 +42,7 @@ def GetRunMetadataLabels(run_metadata): def InLabels(labels, substr): """Returns true iff one of the labels contains substr.""" - return any([substr in x for x in labels]) + return any(substr in x for x 
in labels) class DenseLayerTest(test.TestCase): @@ -68,13 +68,14 @@ class DenseLayerTest(test.TestCase): config.graph_options.optimizer_options.global_jit_level = ( config_pb2.OptimizerOptions.ON_1) - with self.test_session(config=config) as sess: + with self.session(config=config) as sess: x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) y = layers.dense(x, 3) - sess.run(variables.initialize_all_variables()) + self.evaluate(variables.initialize_all_variables()) run_metadata = config_pb2.RunMetadata() - sess.run( + test_utils.RunWithWarmup( + sess, y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, run_metadata=run_metadata, options=config_pb2.RunOptions( @@ -96,9 +97,10 @@ class DenseLayerTest(test.TestCase): with jit_scope(): y = layers.dense(x, 3) - sess.run(variables.initialize_all_variables()) + self.evaluate(variables.initialize_all_variables()) run_metadata = config_pb2.RunMetadata() - sess.run( + test_utils.RunWithWarmup( + sess, y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, run_metadata=run_metadata, options=config_pb2.RunOptions( @@ -124,9 +126,10 @@ class DenseLayerTest(test.TestCase): with jit_scope(): y = layers.dense(x, 3) - sess.run(variables.initialize_all_variables()) + self.evaluate(variables.initialize_all_variables()) run_metadata = config_pb2.RunMetadata() - sess.run( + test_utils.RunWithWarmup( + sess, y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, run_metadata=run_metadata, options=config_pb2.RunOptions( @@ -138,4 +141,6 @@ class DenseLayerTest(test.TestCase): if __name__ == "__main__": + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " + + os.environ.get("TF_XLA_FLAGS", "")) test.main() diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 6ef8a68ca5d35d3d2f78f0cb491e7bb98ff97ac9..174bfa9efbcd7dcb4f895237eb01c17bc4a3a6b4 100644 --- 
a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -255,7 +255,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): t1, t2, strides=[1, stride, stride, 1], padding=padding) value = sess.run(conv, {t1: x1, t2: x2}) print("value = ", value) - self.assertArrayNear(expected, np.ravel(value), 1e-5) + self.assertArrayNear(expected, np.ravel(value), 1e-4) self.assertShapeEqual(value, conv) def testConv2D2x2Filter(self): diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py index 50b04daa6b9f4159a3c4bdeecaf900a5b35a833c..e89cf975f5d889091ce92a35165aef55ee5ad4b0 100644 --- a/tensorflow/compiler/tests/dynamic_stitch_test.py +++ b/tensorflow/compiler/tests/dynamic_stitch_test.py @@ -58,6 +58,15 @@ class DynamicStitchTest(xla_test.XLATestCase): [idx1, idx2], [val1, val2], expected=np.array([[], [], [], []], np.int32)) + def testEmptyIndex(self): + idx1 = np.array([], dtype=np.int32) + idx2 = np.array([[], []], dtype=np.int32) + val1 = np.ndarray(shape=(0, 9), dtype=np.int32) + val2 = np.ndarray(shape=(2, 0, 9), dtype=np.int32) + self._AssertDynamicStitchResultIs([idx1, idx2], [val1, val2], + expected=np.ndarray( + shape=(0, 9), dtype=np.int32)) + def testSimple1D(self): val1 = np.array([0, 4, 7], dtype=np.int32) val2 = np.array([1, 6, 2, 3, 5], dtype=np.int32) diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 63cee550fde9d9d4314b1541fba191df776a4da2..2af32b537ba53723370faf81aebf308a465718c7 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -101,12 +101,12 @@ class EagerTest(xla_test.XLATestCase): self.assertAllEqual(15, product) # Run some ops graphly - with context.graph_mode(), self.cached_session() as sess: + with context.graph_mode(), self.cached_session(): with self.test_scope(): three = constant_op.constant(3) five = constant_op.constant(5) product 
= three * five - self.assertAllEqual(15, sess.run(product)) + self.assertAllEqual(15, self.evaluate(product)) def testDegenerateSlices(self): with self.test_scope(): diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index b3e13fbaa6b33bdaa1be123be558059e96de282e..0edd0c35aa2d417a3ed24decbaa0b5d62d35bb62 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -24,11 +24,10 @@ import numpy as np import scipy.signal as sps from tensorflow.compiler.tests import xla_test -from tensorflow.contrib.signal.python.ops import spectral_ops as signal from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import spectral_ops +from tensorflow.python.ops.signal import signal from tensorflow.python.platform import googletest BATCH_DIMS = (3, 5) @@ -107,39 +106,39 @@ class FFTTest(xla_test.XLATestCase): def testFFT(self): self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft, - spectral_ops.fft) + signal.fft) def testFFT2D(self): self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.fft2, - spectral_ops.fft2d) + signal.fft2d) def testFFT3D(self): self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, lambda x: np.fft.fftn(x, axes=(-3, -2, -1)), - spectral_ops.fft3d) + signal.fft3d) def testIFFT(self): self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft, - spectral_ops.ifft) + signal.ifft) def testIFFT2D(self): self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.ifft2, - spectral_ops.ifft2d) + signal.ifft2d) def testIFFT3D(self): self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)), - spectral_ops.ifft3d) + signal.ifft3d) def testRFFT(self): self._VerifyFftMethod( INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]), - lambda x: spectral_ops.rfft(x, fft_length=[x.shape[-1].value])) + lambda x: signal.rfft(x, 
fft_length=[x.shape[-1].value])) def testRFFT2D(self): def _tf_fn(x): - return spectral_ops.rfft2d( + return signal.rfft2d( x, fft_length=[x.shape[-2].value, x.shape[-1].value]) self._VerifyFftMethod( @@ -153,16 +152,33 @@ class FFTTest(xla_test.XLATestCase): x, axes=(-3, -2, -1), s=[x.shape[-3], x.shape[-2], x.shape[-1]]) def _tf_fn(x): - return spectral_ops.rfft3d( + return signal.rfft3d( x, fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value]) self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn) + def testRFFT3DMismatchedSize(self): + + def _to_expected(x): + return np.fft.rfftn( + x, + axes=(-3, -2, -1), + s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2]) + + def _tf_fn(x): + return signal.rfft3d( + x, + fft_length=[ + x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2 + ]) + + self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn) + def testIRFFT(self): def _tf_fn(x): - return spectral_ops.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)]) + return signal.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)]) self._VerifyFftMethod( INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]), @@ -171,7 +187,7 @@ class FFTTest(xla_test.XLATestCase): def testIRFFT2D(self): def _tf_fn(x): - return spectral_ops.irfft2d( + return signal.irfft2d( x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)]) self._VerifyFftMethod( @@ -195,7 +211,7 @@ class FFTTest(xla_test.XLATestCase): s=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)]) def _tf_fn(x): - return spectral_ops.irfft3d( + return signal.irfft3d( x, fft_length=[ x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1) @@ -203,6 +219,30 @@ class FFTTest(xla_test.XLATestCase): self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn) + def testIRFFT3DMismatchedSize(self): + + def _to_input(x): + return np.fft.rfftn( + np.real(x), + axes=(-3, -2, -1), + s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2]) + + def 
_to_expected(x): + return np.fft.irfftn( + x, + axes=(-3, -2, -1), + s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2]) + + def _tf_fn(x): + return signal.irfft3d( + x, + fft_length=[ + x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2 + ]) + + self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py index 8c7edfd277c992c35a81dd5f261256a86352254e..91d77d2f791834346f43aecb60d116ddbf2faa6e 100644 --- a/tensorflow/compiler/tests/fifo_queue_test.py +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -129,7 +129,7 @@ class FIFOQueueTest(xla_test.XLATestCase): enqueue_op.run() for i in xrange(len(elems)): - vals = dequeued_t.eval() + vals = self.evaluate(dequeued_t) self.assertEqual([elems[i]], vals) def testEnqueueAndBlockingDequeue(self): @@ -192,9 +192,9 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([], size.get_shape()) enqueue_op.run() - self.assertEqual(1, size.eval()) + self.assertEqual(1, self.evaluate(size)) dequeued_t.op.run() - self.assertEqual(0, size.eval()) + self.assertEqual(0, self.evaluate(size)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index f1b87a5ffb73bed62a80abaa152d335f64d970c5..b078053cdbd6d129645734492d34dd25d28ab3ef 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -50,14 +50,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run Ftrl for a few steps for _ in 
range(steps): ftrl_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def equivAdagradTest_AdagradPart(self, steps, dtype): var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) @@ -65,14 +65,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): adagrad_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run Adagrad for a few steps for _ in range(steps): adagrad_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def equivGradientDescentTest_FtrlPart(self, steps, dtype): var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) @@ -85,14 +85,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run Ftrl for a few steps for _ in range(steps): ftrl_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def equivGradientDescentTest_GradientDescentPart(self, steps, dtype): var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) @@ -100,14 +100,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): sgd_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + 
self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run GradientDescent for a few steps for _ in range(steps): sgd_update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def testFtrlwithoutRegularization(self): for dtype in self.float_types: @@ -124,8 +124,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run 3 steps FTRL for _ in range(3): @@ -134,12 +134,12 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( np.array([-2.60260963, -4.29698515]), - var0.eval(), - float_rtol=1e-5, + self.evaluate(var0), + float_rtol=1e-4, half_rtol=1e-2) self.assertAllCloseAccordingToType( np.array([-0.28432083, -0.56694895]), - var1.eval(), + self.evaluate(var1), float_rtol=1e-5, half_rtol=1e-2) @@ -158,8 +158,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 3 steps FTRL for _ in range(3): @@ -167,9 +167,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5) + np.array([-2.55607247, -3.98729396]), + self.evaluate(var0), + 1e-5, + 1e-5, + float_rtol=1e-4) 
self.assertAllCloseAccordingToType( - np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5) + np.array([-0.28232238, -0.56096673]), self.evaluate(var1), 1e-5, + 1e-5) def testFtrlWithL1(self): for dtype in self.float_types: @@ -186,8 +191,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -196,12 +201,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( np.array([-7.66718769, -10.91273689]), - var0.eval(), + self.evaluate(var0), rtol=1e-4, bfloat16_rtol=1e-1, bfloat16_atol=1e-1) self.assertAllCloseAccordingToType( - np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4) + np.array([-0.93460727, -1.86147261]), + self.evaluate(var1), + rtol=1e-4) def testFtrlWithL1_L2(self): for dtype in self.float_types: @@ -218,8 +225,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -227,9 +234,13 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-0.24059935, -0.46829352]), var0.eval(), rtol=1e-5) + np.array([-0.24059935, -0.46829352]), + self.evaluate(var0), + rtol=1e-5) self.assertAllCloseAccordingToType( - 
np.array([-0.02406147, -0.04830509]), var1.eval(), rtol=1e-5) + np.array([-0.02406147, -0.04830509]), + self.evaluate(var1), + rtol=1e-5) def testFtrlWithL1_L2_L2Shrinkage(self): """Test the new FTRL op with support for l2 shrinkage. @@ -253,8 +264,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -262,9 +273,13 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4) + np.array([-0.22578996, -0.44345799]), + self.evaluate(var0), + rtol=1e-4) self.assertAllCloseAccordingToType( - np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4) + np.array([-0.14378493, -0.13229476]), + self.evaluate(var1), + rtol=1e-4) def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self): """Verifies that l2 shrinkage in FTRL does not change lr schedule.""" @@ -290,8 +305,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase): update1 = opt1.apply_gradients([(grads1, var1)]) variables.global_variables_initializer().run() - self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) - self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0)) + self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var1)) # Run 10 steps FTRL for _ in range(10): @@ -300,7 +315,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # var0 is experiencing L2 shrinkage so it should be smaller than var1 # in magnitude. 
- self.assertTrue((var0.eval()**2 < var1.eval()**2).all()) + self.assertTrue((var0.eval()**2 < self.evaluate(var1)**2).all()) accum0 = list(opt0._slots["accum"].values())[0].eval() accum1 = list(opt1._slots["accum"].values())[0].eval() # L2 shrinkage should not change how we update grad accumulator. diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index b1891b918c6584abce9da382088ed0037f5319fb..a61827c2ae44de117abad5b7db5c6bcd78fa171e 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) expected = APlus2B(aval, bval) - with self.cached_session() as sess: + with self.cached_session(): @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -50,7 +50,7 @@ class FunctionTest(xla_test.XLATestCase): b = constant_op.constant(bval, name="b") with self.test_scope(): call_f = Foo(a, b) - result = sess.run(call_f) + result = self.evaluate(call_f) self.assertAllClose(result, expected, rtol=1e-3) def testNestedFunctions(self): @@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) expected = APlus2B(aval, bval) - with self.cached_session() as sess: + with self.cached_session(): @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -76,7 +76,7 @@ class FunctionTest(xla_test.XLATestCase): b = constant_op.constant(bval, name="b") with self.test_scope(): call_g = Foo(a, b) - result = sess.run(call_g) + result = self.evaluate(call_g) self.assertAllClose(result, expected, rtol=1e-3) def testFunctionMultipleRetvals(self): @@ -90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) expected = Func(aval, bval) - with self.cached_session() as sess: + with self.cached_session(): 
@function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -100,7 +100,7 @@ class FunctionTest(xla_test.XLATestCase): b = constant_op.constant(bval, name="b") with self.test_scope(): call_f = Foo(a, b) - result = sess.run(call_f) + result = self.evaluate(call_f) self.assertAllClose(result, expected, rtol=1e-3) def testCompileTimeConstantsInDefun(self): diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 68fdb5caf4c2a496b5058cdda40ca650484a6e0e..0e2d840418156d825e2d141018e49f42374c8fee 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -26,7 +26,6 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests import xla_test -from tensorflow.python.compat import compat from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -449,8 +448,8 @@ class ResizeBilinearTest(xla_test.XLATestCase): for dtype in self.float_types: self._assertForwardOpMatchesExpected( np.array([[1, 2]], dtype=dtype), [3, 3], - expected=np.array( - [[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]], dtype=np.float32)) + expected=np.array([[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]], + dtype=np.float32)) def testAlignCorners1x2To3x2Grad(self): for dtype in self.float_types: @@ -478,8 +477,8 @@ class ResizeBilinearTest(xla_test.XLATestCase): for dtype in self.float_types: self._assertForwardOpMatchesExpected( np.array([[1, 2], [3, 4]], dtype=dtype), [3, 3], - expected=np.array( - [[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]], dtype=np.float32)) + expected=np.array([[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]], + dtype=np.float32)) def testAlignCorners2x2To3x3Grad(self): self._assertBackwardOpMatchesExpected( @@ -499,8 +498,8 @@ class ResizeBilinearTest(xla_test.XLATestCase): np.array([[7, 13], [22, 4]], dtype=np.float32), input_shape=[3, 3], dtype=dtype, - 
expected=np.array( - [[7, 0, 13], [0, 0, 0], [22, 0, 4]], dtype=np.float32)) + expected=np.array([[7, 0, 13], [0, 0, 0], [22, 0, 4]], + dtype=np.float32)) def testAlignCorners4x4To3x3(self): for dtype in self.float_types: @@ -508,8 +507,8 @@ class ResizeBilinearTest(xla_test.XLATestCase): np.array( [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=dtype), [3, 3], - expected=np.array( - [[1, 2.5, 4], [7, 8.5, 10], [13, 14.5, 16]], dtype=np.float32)) + expected=np.array([[1, 2.5, 4], [7, 8.5, 10], [13, 14.5, 16]], + dtype=np.float32)) def testAlignCorners4x4To3x3Grad(self): for dtype in self.float_types: @@ -517,41 +516,39 @@ class ResizeBilinearTest(xla_test.XLATestCase): np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), input_shape=[4, 4], dtype=dtype, - expected=np.array( - [[1, 1, 1, 3], [2, 1.25, 1.25, 3], [2, 1.25, 1.25, 3], - [7, 4, 4, 9]], - dtype=np.float32)) + expected=np.array([[1, 1, 1, 3], [2, 1.25, 1.25, 3], + [2, 1.25, 1.25, 3], [7, 4, 4, 9]], + dtype=np.float32)) def testAlignCorners3x3To9x9(self): for dtype in self.float_types: self._assertForwardOpMatchesExpected( np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), [9, 9], expected=np.array( - [[1.0, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], [ - 1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75 - ], [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], [ - 3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25 - ], [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], [ - 4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75 - ], [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], [ - 6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25 - ], [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]], + [[1.0, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], + [1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75], + [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], + [3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25], + [4.00, 4.25, 4.50, 
4.75, 5.00, 5.25, 5.50, 5.75, 6.00], + [4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75], + [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], + [6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25], + [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]], dtype=np.float32)) def testAlignCorners3x3To9x9Grad(self): for dtype in self.float_types: self._assertBackwardOpMatchesExpected( - np.array( - [[1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], [ - 1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75 - ], [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], [ - 3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25 - ], [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], [ - 4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75 - ], [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], [ - 6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25 - ], [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]], - dtype=np.float32), + np.array([[1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], + [1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75], + [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], + [3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25], + [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], + [4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75], + [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], + [6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25], + [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]], + dtype=np.float32), input_shape=[3, 3], dtype=dtype, expected=np.array( @@ -572,12 +569,12 @@ class ResizeBilinearTest(xla_test.XLATestCase): (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array( [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)) * 15.0, [16, 16], - expected=7 * (np.array( - [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], - dtype=np.float32) + np.array( - [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], - [12], [13], [14], [15]], - 
dtype=np.float32)), + expected=7 * + (np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], + dtype=np.float32) + + np.array([[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], + [12], [13], [14], [15]], + dtype=np.float32)), large_tolerance=True) def testNonAlignCorners3x2To6x4(self): @@ -601,172 +598,230 @@ class ResizeBilinearTest(xla_test.XLATestCase): expected=np.array(expected_data, dtype=dtype), align_corners=False) + def testNonAlignCorners3x2To6x4Batch2(self): + input_data = [[[64, 32], [32, 64], [50, 100]], [[32, 16], [16, 32], + [25, 50]]] + expected_data = [[[64.0, 48.0, 32.0, 32.0], [48.0, 48.0, 48.0, 48.0], + [32.0, 48.0, 64.0, 64.0], [41.0, 61.5, 82.0, 82.0], + [50.0, 75.0, 100.0, 100.0], [50.0, 75.0, 100.0, 100.0]], + [[32.0, 24.0, 16.0, 16.0], [24.0, 24.0, 24.0, 24.0], + [16.0, 24.0, 32.0, 32.0], [20.5, 30.75, 41.0, 41.0], + [25.0, 37.5, 50.0, 50.0], [25.0, 37.5, 50.0, 50.0]]] -class NonMaxSuppressionTest(xla_test.XLATestCase): + for dtype in self.float_types: + input_image = np.array(input_data, dtype=dtype) + expected = np.array(expected_data, dtype=dtype) + with self.cached_session() as sess, self.test_scope(): + image = array_ops.placeholder(input_image.dtype) + resized = gen_image_ops.resize_bilinear( + image, [6, 4], align_corners=False) + out = sess.run(resized, {image: input_image[:, :, :, np.newaxis]}) + self.assertAllClose(expected[:, :, :, np.newaxis], out) - def testNMS128From1024(self): - with compat.forward_compatibility_horizon(2018, 8, 8): - num_boxes = 1024 - boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") - scores_np = np.random.normal(0.5, 0.1, (num_boxes,)).astype("f4") - max_output_size = 128 - iou_threshold_np = np.array(0.5, dtype=np.float32) - score_threshold_np = np.array(0.0, dtype=np.float32) +class NonMaxSuppressionTest(xla_test.XLATestCase): - with self.cached_session() as sess: - boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) - scores = 
array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) - iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, - iou_threshold_np.shape) - score_threshold = array_ops.placeholder(score_threshold_np.dtype, - score_threshold_np.shape) - with self.test_scope(): - selected_indices = image_ops.non_max_suppression_padded( - boxes=boxes, - scores=scores, - max_output_size=max_output_size, - iou_threshold=iou_threshold, - score_threshold=score_threshold, - pad_to_max_output_size=True) - inputs_feed = { - boxes: boxes_np, - scores: scores_np, - score_threshold: score_threshold_np, - iou_threshold: iou_threshold_np - } - (indices_tf, _) = sess.run(selected_indices, feed_dict=inputs_feed) - - self.assertEqual(indices_tf.size, max_output_size) + def testNMS128From1024(self): + num_boxes = 1024 + boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") + scores_np = np.random.normal(0.5, 0.1, (num_boxes,)).astype("f4") + + max_output_size = 128 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, _) = sess.run(selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) def testNMS3From6Boxes(self): - with 
compat.forward_compatibility_horizon(2018, 8, 8): - # Three boxes are selected based on IOU. - boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], - [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] - boxes_np = np.array(boxes_data, dtype=np.float32) - - scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] - scores_np = np.array(scores_data, dtype=np.float32) - - max_output_size = 3 - iou_threshold_np = np.array(0.5, dtype=np.float32) - score_threshold_np = np.array(0.0, dtype=np.float32) - - with self.cached_session() as sess: - boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) - scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) - iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, - iou_threshold_np.shape) - score_threshold = array_ops.placeholder(score_threshold_np.dtype, - score_threshold_np.shape) - with self.test_scope(): - selected_indices = image_ops.non_max_suppression_padded( - boxes=boxes, - scores=scores, - max_output_size=max_output_size, - iou_threshold=iou_threshold, - score_threshold=score_threshold, - pad_to_max_output_size=True) - inputs_feed = { - boxes: boxes_np, - scores: scores_np, - score_threshold: score_threshold_np, - iou_threshold: iou_threshold_np - } - (indices_tf, num_valid) = sess.run( - selected_indices, feed_dict=inputs_feed) - - self.assertEqual(indices_tf.size, max_output_size) - self.assertEqual(num_valid, 3) - self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) + # Three boxes are selected based on IOU. 
+ boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.0, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + score_threshold: score_threshold_np, + iou_threshold: iou_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 3) + self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) def testNMS3Then2WithScoreThresh(self): # Three boxes are selected based on IOU. # One is filtered out by score threshold. 
- with compat.forward_compatibility_horizon(2018, 8, 8): - boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], - [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] - boxes_np = np.array(boxes_data, dtype=np.float32) - - scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] - scores_np = np.array(scores_data, dtype=np.float32) - max_output_size = 3 - iou_threshold_np = np.array(0.5, dtype=np.float32) - score_threshold_np = np.array(0.4, dtype=np.float32) - - with self.cached_session() as sess: - boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) - scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) - iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, - iou_threshold_np.shape) - score_threshold = array_ops.placeholder(score_threshold_np.dtype, - score_threshold_np.shape) - with self.test_scope(): - selected_indices = image_ops.non_max_suppression_padded( - boxes=boxes, - scores=scores, - max_output_size=max_output_size, - iou_threshold=iou_threshold, - score_threshold=score_threshold, - pad_to_max_output_size=True) - inputs_feed = { - boxes: boxes_np, - scores: scores_np, - iou_threshold: iou_threshold_np, - score_threshold: score_threshold_np - } - (indices_tf, num_valid) = sess.run( - selected_indices, feed_dict=inputs_feed) - - self.assertEqual(indices_tf.size, max_output_size) - self.assertEqual(num_valid, 2) - self.assertAllClose(indices_tf[:num_valid], [3, 0]) + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = 
array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 2) + self.assertAllClose(indices_tf[:num_valid], [3, 0]) def testNMS3Then1WithScoreMaxThresh(self): # Three boxes are selected based on IOU. # One is filtered out by score threshold. # One is filtered out by max_output_size. 
- with compat.forward_compatibility_horizon(2018, 8, 8): - boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], - [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] - boxes_np = np.array(boxes_data, dtype=np.float32) - - scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] - scores_np = np.array(scores_data, dtype=np.float32) - max_output_size = 1 - iou_threshold_np = np.array(0.5, dtype=np.float32) - score_threshold_np = np.array(0.4, dtype=np.float32) + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 1 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 1) + self.assertAllClose(indices_tf[:num_valid], [3]) + + def testSelectFromContinuousOverLap(self): + # Tests that a suppressed box does not itself suppress other boxes. 
+ + boxes_data = [[0, 0, 1, 1], [0, 0.2, 1, 1.2], [0, 0.4, 1, 1.4], + [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 3]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.5, 0.4, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 3 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.1, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 3) + self.assertAllClose(indices_tf[:num_valid], [0, 2, 4]) - with self.cached_session() as sess: - boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) - scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) - iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, - iou_threshold_np.shape) - score_threshold = array_ops.placeholder(score_threshold_np.dtype, - score_threshold_np.shape) - with self.test_scope(): - selected_indices = image_ops.non_max_suppression_padded( - boxes=boxes, - scores=scores, - max_output_size=max_output_size, - iou_threshold=iou_threshold, - score_threshold=score_threshold, - 
pad_to_max_output_size=True) - inputs_feed = { - boxes: boxes_np, - scores: scores_np, - iou_threshold: iou_threshold_np, - score_threshold: score_threshold_np - } - (indices_tf, num_valid) = sess.run( - selected_indices, feed_dict=inputs_feed) - - self.assertEqual(indices_tf.size, max_output_size) - self.assertEqual(num_valid, 1) - self.assertAllClose(indices_tf[:num_valid], [3]) if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index de68ff0e32cd59e65094c0b7319f8ab213eed4db..dbea9849e217519874352b789588a2af62f1c826 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os import numpy as np +from tensorflow.compiler.tests import test_utils from tensorflow.contrib.compiler import jit from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 @@ -36,8 +37,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test -jit_scope = jit.experimental_jit_scope +jit_scope = jit.experimental_jit_scope # Disable rewrites to make sure we don't end up having to update this test # whenever we implement new ones. @@ -74,14 +75,14 @@ def RunMetadataLabels(run_metadata): def InLabels(labels, substr): """Returns true iff one of the labels contains substr.""" - return any([substr in x for x in labels]) + return any(substr in x for x in labels) -def MetadataHasXlaOp(run_metadata): +def MetadataHasXlaRunOp(run_metadata): """Returns true if there are XlaRun kernels in run_metadata's timeline.""" # TODO(phawkins): find a less hacky way to test whether a kernel ran. 
- return InLabels(RunMetadataLabels(run_metadata), "XlaRun") + return InLabels(RunMetadataLabels(run_metadata), "_XlaRun") class JitLaunchTest(test.TestCase): @@ -108,15 +109,14 @@ class JitLaunchTest(test.TestCase): direct_op = fn(*placeholders) run_metadata = config_pb2.RunMetadata() - compiled = sess.run(compiled_op, - feeds, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) + compiled = test_utils.RunWithWarmup( + sess, compiled_op, feeds, + config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE), + run_metadata) print("Compiled Result {}".format(compiled)) if require_kernel_launch: - self.assert_(MetadataHasXlaOp(run_metadata)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) direct = sess.run(direct_op, feeds) print("Direct Result {}".format(direct)) @@ -137,7 +137,7 @@ class JitLaunchTest(test.TestCase): a = constant_op.constant(100) # pylint: disable=unused-variable call = KernelWithNoOutputs() # pylint: disable=assignment-from-no-return - sess.run(call, {}) + test_utils.RunWithWarmup(sess, call, {}) def testAliasing(self): """Regression test for compiled functions that return an aliased buffer. @@ -250,17 +250,21 @@ class JitLaunchTest(test.TestCase): dx = np.random.random_sample((batch_size, image_size)).astype(np.float32) with session_lib.Session() as sess: run_metadata = config_pb2.RunMetadata() - output = sess.run(y, {x: dx, - w: dw, - b: db}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) + output = test_utils.RunWithWarmup( + sess, + y, { + x: dx, + w: dw, + b: db + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) # TODO(phawkins): really we would like to test that there were exactly # two kernel launches. However, we have no reliable way to determine # that. 
- self.assert_(MetadataHasXlaOp(run_metadata)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) expected = np.square(np.dot(dx, dw) + db) self.assertAllClose(expected, output, rtol=1e-1) @@ -272,7 +276,7 @@ class XlaCompilationTest(test.TestCase): def testReshape(self): """Tests an operator with compile-time constant and non-constant inputs.""" - with self.test_session(config=NoRewriteSessionConfig()) as sess: + with self.session(config=NoRewriteSessionConfig()) as sess: x = array_ops.placeholder(dtypes.float32) y = array_ops.placeholder(dtypes.int32) with jit_scope(): @@ -284,19 +288,22 @@ class XlaCompilationTest(test.TestCase): # statically known as part of the JIT compilation's input graph. z = array_ops.reshape(x, y) run_metadata = config_pb2.RunMetadata() - out = sess.run(z, - {x: np.array([1, 2, 3, 4, 5, 6], np.float32), - y: [-1, 3]}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaOp(run_metadata)) + out = test_utils.RunWithWarmup( + sess, + z, { + x: np.array([1, 2, 3, 4, 5, 6], np.float32), + y: [-1, 3] + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out) def testIgnoredArguments(self): """Tests that JIT computations can ignore formal parameters.""" - with self.test_session(config=NoRewriteSessionConfig()) as sess: + with self.session(config=NoRewriteSessionConfig()) as sess: x = array_ops.placeholder(dtypes.int32) y = array_ops.placeholder(dtypes.int32) with jit_scope(): @@ -309,18 +316,22 @@ class XlaCompilationTest(test.TestCase): t = math_ops.add(z, z) run_metadata = config_pb2.RunMetadata() - out = sess.run(t, {x: np.int32(7), - y: np.int32(404)}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) - 
self.assert_(MetadataHasXlaOp(run_metadata)) + out = test_utils.RunWithWarmup( + sess, + t, { + x: np.int32(7), + y: np.int32(404) + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) self.assertAllClose(28, out) def testLoops(self): """Tests that compilation accepts computations containing loops.""" - with self.test_session(config=NoRewriteSessionConfig()) as session: + with self.session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) with jit_scope(): c = lambda i, _: math_ops.less(i, 5) @@ -332,13 +343,13 @@ class XlaCompilationTest(test.TestCase): run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaOp(run_metadata)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) self.assertAllClose(result, np.float32(95), rtol=1e-1) def testCond(self): """Tests that compilation handles switch operators.""" - with self.test_session(config=NoRewriteSessionConfig()) as session: + with self.session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) y = array_ops.placeholder(dtypes.float32) c = array_ops.placeholder(dtypes.bool) @@ -351,13 +362,17 @@ class XlaCompilationTest(test.TestCase): # deadlock. 
run_metadata = config_pb2.RunMetadata() - result = session.run(t, {x: np.float32(2), - y: np.float32(4), - c: True}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaOp(run_metadata)) + result = test_utils.RunWithWarmup( + session, + t, { + x: np.float32(2), + y: np.float32(4), + c: True + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assert_(MetadataHasXlaRunOp(run_metadata)) self.assertAllClose(result, np.float32(6), rtol=1e-1) def testNestedFunction(self): @@ -379,7 +394,7 @@ class XlaCompilationTest(test.TestCase): inp = array_ops.placeholder(dtypes.float32) out = Entry(inp) - with self.test_session( + with self.session( config=NoRewriteSessionConfig(), graph=g, use_gpu=True) as sess: run_metadata = config_pb2.RunMetadata() val = sess.run(out, @@ -392,7 +407,7 @@ class XlaCompilationTest(test.TestCase): def testLoopDeadlock(self): """Regression test for bug that caused deadlocks in graphs with loops.""" - with self.test_session(config=NoRewriteSessionConfig()) as session: + with self.session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) with jit_scope(): y = x + 1.0 @@ -425,11 +440,13 @@ class XlaCompilationTest(test.TestCase): cfg.graph_options.optimizer_options.do_function_inlining = True with session_lib.Session(graph=g, config=cfg) as sess: run_metadata = config_pb2.RunMetadata() - dx_val = sess.run(dx, - feed_dict={x: 100.}, - run_metadata=run_metadata, - options=config_pb2.RunOptions( - trace_level=config_pb2.RunOptions.FULL_TRACE)) + dx_val = test_utils.RunWithWarmup( + sess, + dx, + feed_dict={x: 100.}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) self.assertAllClose(dx_val, 0.01) return RunMetadataLabels(run_metadata) @@ -475,7 +492,8 @@ class ElementWiseFusionTest(test.TestCase): a7 
= a6 + a2 run_metadata = config_pb2.RunMetadata() - output = sess.run( + output = test_utils.RunWithWarmup( + sess, a7, { a1: arg0, a2: arg1 @@ -509,5 +527,135 @@ class ElementWiseFusionTest(test.TestCase): self.assertAllClose(tf_op, tfef_op, rtol=1e-1) +class LazyCompilationTest(test.TestCase): + + def testLazyCompilation(self): + + @function.Defun(compiled=True) + def CompiledFunction(x): + return math_ops.log(x) + + with session_lib.Session(config=NoRewriteSessionConfig()) as sess: + x = array_ops.placeholder(dtypes.float32) + y = CompiledFunction(x) + + # The very first run of the cluster is always compiled (non-lazily). + run_metadata_for_first_run = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [2., 10., 19., 77., 100.]}, + run_metadata=run_metadata_for_first_run, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue( + InLabels( + RunMetadataLabels(run_metadata_for_first_run), "_XlaCompile")) + self.assertTrue( + InLabels(RunMetadataLabels(run_metadata_for_first_run), "_XlaRun")) + + run_metadata_before_warmup = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [2., 10.]}, + run_metadata=run_metadata_before_warmup, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue( + InLabels( + RunMetadataLabels(run_metadata_before_warmup), "_XlaCompile")) + self.assertFalse( + InLabels(RunMetadataLabels(run_metadata_before_warmup), "_XlaRun")) + + # We compile when we see the same shape a second time. 
+ + run_metadata_after_warmup = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [2., 10.]}, + run_metadata=run_metadata_after_warmup, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue( + InLabels(RunMetadataLabels(run_metadata_after_warmup), "_XlaCompile")) + self.assertTrue( + InLabels(RunMetadataLabels(run_metadata_after_warmup), "_XlaRun")) + + run_metadata_for_new_shape = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [2., 10., 12.]}, + run_metadata=run_metadata_for_new_shape, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue( + InLabels( + RunMetadataLabels(run_metadata_for_new_shape), "_XlaCompile")) + self.assertFalse( + InLabels(RunMetadataLabels(run_metadata_for_new_shape), "_XlaRun")) + + def testIsMegamorphic(self): + + @function.Defun(compiled=True) + def CompiledFunction(x): + return math_ops.log(x) + + with session_lib.Session(config=NoRewriteSessionConfig()) as sess: + x = array_ops.placeholder(dtypes.float32) + y = CompiledFunction(x) + + # Make the cluster go megamorphic by running it with lots of shape + # signatures where the cluster is executed with each signature only a few + # times. Then check that we don't compile the cluster ever again. + + for shape in range(10, 50): + for _ in range(0, 49): + sess.run(y, feed_dict={x: [0.] * shape}) + + for _ in range(0, 50): + run_metadata = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [0.] 
* 60}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue( + InLabels(RunMetadataLabels(run_metadata), "_XlaCompile")) + self.assertFalse(InLabels(RunMetadataLabels(run_metadata), "_XlaRun")) + + def testIsNotMegamorphic(self): + + @function.Defun(compiled=True) + def CompiledFunction(x): + return math_ops.log(x) + + with session_lib.Session(config=NoRewriteSessionConfig()) as sess: + x = array_ops.placeholder(dtypes.float32) + y = CompiledFunction(x) + + # Run the cluster with lots of shape signatures, but in a way that it + # isn't megamorphic (i.e. each shape signature sees a lot of executions). + # Then check that the cluster has not been marked as megamorphic. + + for shape in range(10, 50): + for _ in range(0, 1000): + sess.run(y, feed_dict={x: [0.] * shape}) + + for _ in range(0, 10): + sess.run(y, feed_dict={x: [0.] * 60}) + + run_metadata = config_pb2.RunMetadata() + sess.run( + y, + feed_dict={x: [0.] 
* 60}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaCompile")) + self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaRun")) + + if __name__ == "__main__": + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " + + os.environ.get("TF_XLA_FLAGS", "")) test.main() diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py index 58622114e4f552fb71db9b040a39b57d7da0037c..0210201fa71a6e790e94667073ab4dba542537a5 100644 --- a/tensorflow/compiler/tests/listdiff_op_test.py +++ b/tensorflow/compiler/tests/listdiff_op_test.py @@ -33,13 +33,13 @@ class ListDiffTest(xla_test.XLATestCase): def _testListDiff(self, x, y, out, idx): for dtype in [dtypes.int32, dtypes.int64]: for index_dtype in [dtypes.int32, dtypes.int64]: - with self.cached_session() as sess: + with self.cached_session(): x_tensor = ops.convert_to_tensor(x, dtype=dtype) y_tensor = ops.convert_to_tensor(y, dtype=dtype) with self.test_scope(): out_tensor, idx_tensor = array_ops.listdiff( x_tensor, y_tensor, out_idx=index_dtype) - tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) + tf_out, tf_idx = self.evaluate([out_tensor, idx_tensor]) self.assertAllEqual(out, tf_out) self.assertAllEqual(idx, tf_idx) self.assertEqual(1, out_tensor.get_shape().ndims) diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index c6ad67993e8bc196a74c9a328df8c9200c92c575..5dddf6ae4e8c8a3d5e9eb7b2c62298df02a0093c 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -120,8 +120,8 @@ class LRNTest(xla_test.XLATestCase): with self.test_scope(): actual = gen_nn_ops.lrn_grad(out_grads, in_image, out_image, depth_radius, bias, alpha, beta) - expected_val = expected.eval() - actual_val = actual.eval() + expected_val = self.evaluate(expected) + 
actual_val = self.evaluate(actual) self.assertAllClose(actual_val, expected_val, rtol=1e-3) diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py index 265c0b6d1412de7be3a5bf5e79129cb330ceb162..776ed899e68ddd3893b8bb30b7c8034297aa6515 100644 --- a/tensorflow/compiler/tests/lstm_test.py +++ b/tensorflow/compiler/tests/lstm_test.py @@ -88,8 +88,8 @@ class LSTMTest(test.TestCase): (basename, m_prev_scalar, c_prev_scalar, pad_scalar)) # Initialize variables and run the unrolled LSTM step. - sess.run(variables.global_variables_initializer()) - return sess.run([m, c]) + self.evaluate(variables.global_variables_initializer()) + return self.evaluate([m, c]) def testLSTMCell(self): # Run with all-0 weights, no padding. @@ -173,8 +173,8 @@ class LSTMTest(test.TestCase): (basename, m_init_scalar, c_init_scalar, pad_scalar)) # Initialize variables and run the unrolled LSTM layer. - sess.run(variables.global_variables_initializer()) - return sess.run(out_seq) + self.evaluate(variables.global_variables_initializer()) + return self.evaluate(out_seq) def testLSTMLayer(self): # Run with all-0 weights, no padding. diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index f77521a7c49dba39849869ddceb7c0e885147722..3416f7dbd6bdd264bf79785084f981f5b07cb8a9 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -61,37 +61,43 @@ class MomentumOptimizerTest(xla_test.XLATestCase): self.assertFalse(slot1 in variables.trainable_variables()) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Step 1: the momentum accumulators where 0. 
So we should see a normal # update: v -= grad * learning_rate mom_update.run() # Check that the momentum accumulators have been updated. - self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval()) - self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval()) + self.assertAllCloseAccordingToType( + np.array([0.1, 0.1]), self.evaluate(slot0)) + self.assertAllCloseAccordingToType( + np.array([0.01, 0.01]), self.evaluate(slot1)) # Check that the parameters have been updated. self.assertAllCloseAccordingToType( - np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval()) + np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), + self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval()) + np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), + self.evaluate(var1)) # Step 2: the momentum accumulators contain the previous update. mom_update.run() # Check that the momentum accumulators have been updated. self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()) + np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), + self.evaluate(slot0)) self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval()) + np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), + self.evaluate(slot1)) # Check that the parameters have been updated. 
self.assertAllCloseAccordingToType( np.array([ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0) - ]), var0.eval()) + ]), self.evaluate(var0)) self.assertAllCloseAccordingToType( np.array([ - 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - ( - (0.9 * 0.01 + 0.01) * 2.0) - ]), var1.eval()) + 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), + 3.98 - ((0.9 * 0.01 + 0.01) * 2.0) + ]), self.evaluate(var1)) def testNesterovMomentum(self): for dtype in self.float_types: @@ -115,8 +121,8 @@ class MomentumOptimizerTest(xla_test.XLATestCase): var0_np, accum0_np, var0_np * 0.8, 0.1, 0.9) var1_np, accum1_np = self._update_nesterov_momentum_numpy( var1_np, accum1_np, 0.9, 0.1, 0.9) - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testTensorLearningRateAndMomentum(self): for dtype in self.float_types: @@ -141,37 +147,43 @@ class MomentumOptimizerTest(xla_test.XLATestCase): self.assertFalse(slot1 in variables.trainable_variables()) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Step 1: the momentum accumulators where 0. So we should see a normal # update: v -= grad * learning_rate mom_update.run() # Check that the momentum accumulators have been updated. - self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval()) - self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval()) + self.assertAllCloseAccordingToType( + np.array([0.1, 0.1]), self.evaluate(slot0)) + self.assertAllCloseAccordingToType( + np.array([0.01, 0.01]), self.evaluate(slot1)) # Check that the parameters have been updated. 
self.assertAllCloseAccordingToType( - np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval()) + np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), + self.evaluate(var0)) self.assertAllCloseAccordingToType( - np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval()) + np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), + self.evaluate(var1)) # Step 2: the momentum accumulators contain the previous update. mom_update.run() # Check that the momentum accumulators have been updated. self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()) + np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), + self.evaluate(slot0)) self.assertAllCloseAccordingToType( - np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval()) + np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), + self.evaluate(slot1)) # Check that the parameters have been updated. self.assertAllCloseAccordingToType( np.array([ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0) - ]), var0.eval()) + ]), self.evaluate(var0)) self.assertAllCloseAccordingToType( np.array([ - 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - ( - (0.9 * 0.01 + 0.01) * 2.0) - ]), var1.eval()) + 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), + 3.98 - ((0.9 * 0.01 + 0.01) * 2.0) + ]), self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/oom_test.py b/tensorflow/compiler/tests/oom_test.py deleted file mode 100644 index 7635f89249b7b71e5353e0b7cb1cea5c1f7bca1d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tests/oom_test.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functional tests for out-of-memory conditions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.compiler.tests import xla_test -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.platform import googletest - - -class OutOfMemoryTest(xla_test.XLATestCase): - - def testOutputOutOfMemory(self): - """Allocates tensors until out of memory. - - Generates a large rank-1 tensor. The tensor is an output of an XLA - computation, not constant. - - Check that a ResourceExhaustedError is raised and can be caught. - - We spin in a loop generating larger and larger tensors until an OOM event - happens. We may be running sandboxed, so have a small host memory limit, so - any hardcoded value is unlikely to land in the sweet spot between device - memory size and host memory size with stability. - """ - - def test_loop(): - size = int(2e8) - while True: - with self.cached_session(): - # Force the compiled code to not be constant by feeding in a - # parameter. - p = array_ops.placeholder(dtypes.float32, shape=[2, 1, 1]) - with self.test_scope(): - # Create a computation that produces a large R1 tensor as an - # intermediate result. 
Reduce it down so that if this file was - # compiled without --config=cuda, we don't force a D2H copy of a - # large tensor and potentially OOM the host. - # - # This is a bit tricky because XLA:GPU doesn't currently support RNG - # ops. Here we rely on the fact that XLA doesn't do algebraic - # simplifications on conv(, ). - c = math_ops.reduce_sum( - nn_ops.convolution( - array_ops.ones([1, size, 1]), - p, - padding='SAME', - data_format='NWC')) - - c.eval(feed_dict={p: [[[1.0]], [[2.0]]]}) - size *= 2 - - self.assertRaises(errors.ResourceExhaustedError, test_loop) - - -if __name__ == '__main__': - googletest.main() diff --git a/tensorflow/compiler/tests/permute_test.py b/tensorflow/compiler/tests/permute_test.py index dbb9274df4f579fbc6076bf55c9307e4d1cb7768..e2f6de821b5fd4709d305bcd17ee6ba40b1443fd 100644 --- a/tensorflow/compiler/tests/permute_test.py +++ b/tensorflow/compiler/tests/permute_test.py @@ -40,40 +40,48 @@ class XlaPermuteOpTest(xla_test.XLATestCase): self.assertAllEqual(result, expected) def testNHWCToNCHW(self): - x = np.array([7, 4, 9, 3], dtype=np.int32) - self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9]) + for dtype in {np.int32, np.int64}: + x = np.array([7, 4, 9, 3], dtype=dtype) + self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9]) def testNCHWToNHWC(self): - x = np.array([7, 4, 9, 3], dtype=np.int32) - self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4]) + for dtype in {np.int32, np.int64}: + x = np.array([7, 4, 9, 3], dtype=dtype) + self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4]) def testNHWCToHWNC(self): - x = np.array([7, 4, 9, 3], dtype=np.int32) - self._runPermuteAndCompare(x, "NHWC", "HWNC", [4, 9, 7, 3]) + for dtype in {np.int32, np.int64}: + x = np.array([7, 4, 9, 3], dtype=dtype) + self._runPermuteAndCompare(x, "NHWC", "HWNC", [4, 9, 7, 3]) def testHWNCToNHWC(self): - x = np.array([7, 4, 9, 3], dtype=np.int32) - self._runPermuteAndCompare(x, "HWNC", "NHWC", [9, 7, 4, 3]) + for dtype in 
{np.int32, np.int64}: + x = np.array([7, 4, 9, 3], dtype=dtype) + self._runPermuteAndCompare(x, "HWNC", "NHWC", [9, 7, 4, 3]) def testNHWCToNCHW2D(self): - x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) - self._runPermuteAndCompare(x, "NHWC", "NCHW", - [[7, 4], [5, 1], [9, 3], [4, 5]]) + for dtype in {np.int32, np.int64}: + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=dtype) + self._runPermuteAndCompare(x, "NHWC", "NCHW", + [[7, 4], [5, 1], [9, 3], [4, 5]]) def testNHWCToHWNC2D(self): - x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) - self._runPermuteAndCompare(x, "NHWC", "HWNC", - [[9, 3], [4, 5], [7, 4], [5, 1]]) + for dtype in {np.int32, np.int64}: + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=dtype) + self._runPermuteAndCompare(x, "NHWC", "HWNC", + [[9, 3], [4, 5], [7, 4], [5, 1]]) def testHWNCToNHWC2D(self): - x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) - self._runPermuteAndCompare(x, "HWNC", "NHWC", - [[4, 5], [7, 4], [9, 3], [5, 1]]) + for dtype in {np.int32, np.int64}: + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=dtype) + self._runPermuteAndCompare(x, "HWNC", "NHWC", + [[4, 5], [7, 4], [9, 3], [5, 1]]) def testNCHWToNHWC2D(self): - x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) - self._runPermuteAndCompare(x, "NCHW", "NHWC", - [[7, 4], [4, 5], [5, 1], [9, 3]]) + for dtype in {np.int32, np.int64}: + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=dtype) + self._runPermuteAndCompare(x, "NCHW", "NHWC", + [[7, 4], [4, 5], [5, 1], [9, 3]]) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py index 77bb839409f0c323ff6ed2c8d6bd105d3003b398..9671ae0ae973ff82d22744a1feb9b4293d94bbdd 100644 --- a/tensorflow/compiler/tests/placeholder_test.py +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -33,7 +33,7 @@ class PlaceholderTest(xla_test.XLATestCase): ph = 
array_ops.placeholder_with_default(v, shape=[]) out = ph * 2 sess.run(variables.variables_initializer([v])) - self.assertEqual(8.0, sess.run(out)) + self.assertEqual(8.0, self.evaluate(out)) def test_placeholder_with_default_fed(self): with self.cached_session() as sess, self.test_scope(): diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py index 86536da7fed0e2309beb32fee9c7c605491592ed..5b35c20027700b34500a31e174061d7087094b61 100644 --- a/tensorflow/compiler/tests/powersign_test.py +++ b/tensorflow/compiler/tests/powersign_test.py @@ -91,8 +91,8 @@ class PowerSignTest(xla_test.XLATestCase): variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 7 steps of powersign # first 4 steps with positive gradient @@ -125,8 +125,8 @@ class PowerSignTest(xla_test.XLATestCase): ) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testDense(self): decay_steps = 10 diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py index c41b4171e26af4f7ad0237d7407a5b3691299595..63cc51a470164915b2614a06d18ca1850bb64a3c 100644 --- a/tensorflow/compiler/tests/proximal_adagrad_test.py +++ b/tensorflow/compiler/tests/proximal_adagrad_test.py @@ -45,15 +45,17 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], 
var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run 3 steps Proximal Adagrad. for _ in range(3): update.run() - self.assertAllClose(np.array([-2.60260963, -4.29698515]), var0.eval()) - self.assertAllClose(np.array([-0.28432083, -0.56694895]), var1.eval()) + self.assertAllClose( + np.array([-2.60260963, -4.29698515]), self.evaluate(var0)) + self.assertAllClose( + np.array([-0.28432083, -0.56694895]), self.evaluate(var1)) opt_vars = opt.variables() self.assertStartsWith(opt_vars[0].name, var0._shared_name) self.assertStartsWith(opt_vars[1].name, var1._shared_name) @@ -74,14 +76,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 3 steps Proximal Adagrad. 
for _ in range(3): update.run() - self.assertAllClose(np.array([-1.60261, -2.296985]), var0.eval()) - self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval()) + self.assertAllClose(np.array([-1.60261, -2.296985]), self.evaluate(var0)) + self.assertAllClose(np.array([3.715679, 2.433051]), self.evaluate(var1)) def testProximalAdagradWithL1(self): with self.cached_session(), self.test_scope(): @@ -98,14 +100,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps Proximal Adagrad for _ in range(10): update.run() - self.assertAllClose(np.array([-6.663634, -9.190331]), var0.eval()) - self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval()) + self.assertAllClose(np.array([-6.663634, -9.190331]), self.evaluate(var0)) + self.assertAllClose(np.array([2.959304, 1.029232]), self.evaluate(var1)) def testProximalAdagradWithL1_L2(self): with self.cached_session(), self.test_scope(): @@ -122,15 +124,15 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps Proximal Adagrad. 
for _ in range(10): update.run() - self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval()) - self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval()) + self.assertAllClose(np.array([-0.0495, -0.0995]), self.evaluate(var0)) + self.assertAllClose(np.array([-0.0045, -0.0095]), self.evaluate(var1)) def applyOptimizer(self, opt, steps=5): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) @@ -141,14 +143,14 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run ProximalAdagrad for a few steps for _ in range(steps): update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def testEquivAdagradwithoutRegularization(self): with self.cached_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py index 3d808e6b8a71ef9fa60b671d07bfd907e9f58efc..5aec433be765dd0a04bd7ab10d5c39a5a7f48c5c 100644 --- a/tensorflow/compiler/tests/proximal_gradient_descent_test.py +++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py @@ -42,15 +42,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([0.0, 0.0], var0.eval()) - self.assertAllClose([0.0, 0.0], var1.eval()) + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + self.assertAllClose([0.0, 0.0], self.evaluate(var1)) # Run 3 steps Proximal Gradient Descent. 
for _ in range(3): update.run() - self.assertAllClose(np.array([-0.9, -1.8]), var0.eval()) - self.assertAllClose(np.array([-0.09, -0.18]), var1.eval()) + self.assertAllClose(np.array([-0.9, -1.8]), self.evaluate(var0)) + self.assertAllClose(np.array([-0.09, -0.18]), self.evaluate(var1)) def testProximalGradientDescentwithoutRegularization2(self): with self.cached_session(), self.test_scope(): @@ -64,15 +64,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 3 steps Proximal Gradient Descent for _ in range(3): update.run() - self.assertAllClose(np.array([0.1, 0.2]), var0.eval()) - self.assertAllClose(np.array([3.91, 2.82]), var1.eval()) + self.assertAllClose(np.array([0.1, 0.2]), self.evaluate(var0)) + self.assertAllClose(np.array([3.91, 2.82]), self.evaluate(var1)) def testProximalGradientDescentWithL1(self): with self.cached_session(), self.test_scope(): @@ -86,15 +86,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps proximal gradient descent. 
for _ in range(10): update.run() - self.assertAllClose(np.array([-1.988, -3.988001]), var0.eval()) - self.assertAllClose(np.array([3.67, 2.37]), var1.eval()) + self.assertAllClose(np.array([-1.988, -3.988001]), self.evaluate(var0)) + self.assertAllClose(np.array([3.67, 2.37]), self.evaluate(var1)) def testProximalGradientDescentWithL1_L2(self): with self.cached_session(), self.test_scope(): @@ -108,15 +108,15 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([4.0, 3.0], self.evaluate(var1)) # Run 10 steps Proximal Gradient Descent for _ in range(10): update.run() - self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval()) - self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval()) + self.assertAllClose(np.array([-0.0495, -0.0995]), self.evaluate(var0)) + self.assertAllClose(np.array([-0.0045, -0.0095]), self.evaluate(var1)) def applyOptimizer(self, opt, steps=5): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) @@ -127,14 +127,14 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run ProximalAdagrad for a few steps for _ in range(steps): update.run() - return var0.eval(), var1.eval() + return self.evaluate(var0), self.evaluate(var1) def testEquivGradientDescentwithoutRegularization(self): with self.cached_session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/qr_op_test.py 
b/tensorflow/compiler/tests/qr_op_test.py index 236b1b881dcaffc1a5b0c6395f0605c1d7ef0269..b4d4193e35f9e0e3b23d0242ed076dd811f4ee2b 100644 --- a/tensorflow/compiler/tests/qr_op_test.py +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -63,7 +63,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity. xx = math_ops.matmul(x, x, adjoint_a=True) identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0) - precision = self.AdjustedNorm(xx.eval() - identity.eval()) + precision = self.AdjustedNorm(xx.eval() - self.evaluate(identity)) self.assertTrue(np.all(precision < 5.0)) def _test(self, dtype, shape, full_matrices): diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 36ef6ed5fee78bad10bb1ee0bf3eb7824d05c206..97ffad34c00b8ec16eb1ec109ba5d980e0ce673d 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -46,9 +46,9 @@ class RandomOpsTest(xla_test.XLATestCase): # The random-number generator, if working correctly, should produce the # same output multiple times with low probability. - y = sess.run(x) - z = sess.run(x) - w = sess.run(x) + y = self.evaluate(x) + z = self.evaluate(x) + w = self.evaluate(x) # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. 
@@ -83,7 +83,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.test_scope(): x = random_ops.random_uniform( shape=[1000], dtype=dtype, minval=-2, maxval=33) - y = sess.run(x) + y = self.evaluate(x) self.assertTrue((y >= -2).sum() == 1000) self.assertTrue((y < 33).sum() == 1000) @@ -102,7 +102,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.cached_session() as sess: with self.test_scope(): x = random_ops.truncated_normal(shape=[count], dtype=dtype) - y = sess.run(x) + y = self.evaluate(x) def normal_cdf(x): return .5 * math.erfc(-x / math.sqrt(2)) @@ -111,7 +111,7 @@ class RandomOpsTest(xla_test.XLATestCase): return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) def probit(x, sess=sess): - return sess.run(special_math.ndtri(x)) + return self.evaluate(special_math.ndtri(x)) a = -2. b = 2. @@ -148,7 +148,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.test_scope(): x = math_ops.range(1 << 16) shuffle = random_ops.random_shuffle(x) - result = sess.run(shuffle) + result = self.evaluate(shuffle) expected = range(1 << 16) # Compare sets to avoid randomness behavior changes but make sure still # have all the values. @@ -159,7 +159,7 @@ class RandomOpsTest(xla_test.XLATestCase): with self.test_scope(): x = array_ops.diag(math_ops.range(20)) shuffle = random_ops.random_shuffle(x) - result = sess.run(shuffle) + result = self.evaluate(shuffle) expected = np.diag(range(20)).flatten() # Compare sets to avoid randomness behavior changes but make sure still # have all the values. diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index dc119fb0f8a41a3772a8c9508bf2db657f57de88..d23fd125163d1afe8c7fd5e008d4b617ff4b2874 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -45,6 +45,7 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -2465,20 +2466,21 @@ TEST_F(OpTest, Pack) { }); } -// TODO(b/31741898): crashes on GPU. TEST_F(OpTest, Pad) { Repeatedly([this]() { auto type = Choose(kAllXlaTypes); std::vector t_dims = RandomDims(); - // TODO(b/31741996): re-enable DT_INT64 when bug is fixed. - // DataType tpaddings = Choose({DT_INT32, DT_INT64}); - DataType tpaddings = DT_INT32; + DataType tpaddings = Choose({DT_INT32, DT_INT64}); std::vector paddings_vec; - std::uniform_int_distribution distribution(0, 7); for (int i = 0; i < t_dims.size(); ++i) { - paddings_vec.push_back(distribution(generator())); - paddings_vec.push_back(distribution(generator())); + std::uniform_int_distribution pad_distribution(0, t_dims[i]); + int pad_size = pad_distribution(generator()); + std::uniform_int_distribution lower_distribution(0, pad_size); + int low_pad_size = lower_distribution(generator()); + paddings_vec.push_back(low_pad_size); + paddings_vec.push_back(pad_size - low_pad_size); + t_dims[i] -= pad_size; } Tensor paddings; CHECK( @@ -2687,6 +2689,37 @@ TEST_F(OpTest, Reverse) { }); } +TEST_F(OpTest, ReverseSequence) { + Repeatedly([this]() { + std::vector dims = RandomDims(/*min_rank=*/2); + auto type = Choose(kAllXlaTypes); + int64 rank = dims.size(); + + // Choose random batch and sequence dimensions. 
+ std::vector shuffled_dim_ids(rank); + absl::c_iota(shuffled_dim_ids, 0); + absl::c_shuffle(shuffled_dim_ids, generator()); + shuffled_dim_ids.resize(2); + int batch_dim = shuffled_dim_ids[0]; + int seq_dim = shuffled_dim_ids[1]; + + int batch_size = dims[batch_dim]; + int max_seq_len = dims[seq_dim]; + std::vector seq_lens(batch_size); + std::uniform_int_distribution d(0, max_seq_len); + absl::c_generate(seq_lens, [&]() { return d(generator()); }); + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ReverseSequence") + .RandomInput(type, dims) + .Input(test::AsTensor(seq_lens)) + .Attr("seq_dim", seq_dim) + .Attr("batch_dim", batch_dim) + .Attr("T", type) + .Attr("Tlen", DT_INT32)); + }); +} + TEST_F(OpTest, ReverseV2) { Repeatedly([this]() { auto type = Choose(kAllXlaTypes); @@ -3349,10 +3382,10 @@ int main(int argc, char** argv) { } // XLA devices register kernels at construction time; create all known devices // to make sure the kernels are registered. - std::vector devices; + std::vector> devices; TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices( tensorflow::SessionOptions(), "", &devices)); - tensorflow::DeviceMgr device_mgr(devices); + tensorflow::DeviceMgr device_mgr(std::move(devices)); tensorflow::Device* ignored; TF_QCHECK_OK( diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index 132c59c32c9db0c8759bdbb31f8613c3ef88b485..e8fc81bbb5472669c408b8bbdbcdfcdcf461131f 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -91,6 +91,7 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): np.array([], dtype=np.bool).reshape(0, 3), np.array([[False, True, False], [True, True, False]]), ] + ONES = [np.ones([34000, 2])] def testReduceSumF32(self, index_dtype): self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA, @@ -149,6 +150,11 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): 
self._testReduction(math_ops.reduce_mean, np.mean, np.float32, self.NONEMPTY_REAL_DATA, index_dtype) + def testReduceMeanF16(self, index_dtype): + if np.float16 in self.all_types: + self._testReduction(math_ops.reduce_mean, np.mean, np.float16, self.ONES, + index_dtype) + def testReduceMeanC64(self, index_dtype): self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, self.NONEMPTY_COMPLEX_DATA, index_dtype) diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py index 8840a1329a907bddc6ef1cb6dd1c2a6d234def5c..dc3e90b4afa41c08d899ee195d42fb91678bad1c 100644 --- a/tensorflow/compiler/tests/rmsprop_test.py +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -76,7 +76,7 @@ class RmspropTest(xla_test.XLATestCase): rms_opt = rmsprop.RMSPropOptimizer(learning_rate, centered=centered) rms_update = rms_opt.apply_gradients( zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() + self.evaluate(variables.global_variables_initializer()) mg0 = rms_opt.get_slot(var0, "mg") self.assertEqual(mg0 is not None, centered) @@ -92,12 +92,12 @@ class RmspropTest(xla_test.XLATestCase): self.assertTrue(mom1 is not None) # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([3.0, 4.0], var1.eval()) + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) # Run 3 steps of RMSProp for _ in range(3): - rms_update.run() + self.evaluate(rms_update) var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( var0_np, @@ -118,14 +118,14 @@ class RmspropTest(xla_test.XLATestCase): # Validate updated params if centered: - self.assertAllCloseAccordingToType(mg0_np, mg0.eval()) - self.assertAllCloseAccordingToType(mg1_np, mg1.eval()) - self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) - self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) - self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) 
- self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) - self.assertAllCloseAccordingToType(var0_np, var0.eval()) - self.assertAllCloseAccordingToType(var1_np, var1.eval()) + self.assertAllCloseAccordingToType(mg0_np, self.evaluate(mg0)) + self.assertAllCloseAccordingToType(mg1_np, self.evaluate(mg1)) + self.assertAllCloseAccordingToType(rms0_np, self.evaluate(rms0)) + self.assertAllCloseAccordingToType(rms1_np, self.evaluate(rms1)) + self.assertAllCloseAccordingToType(mom0_np, self.evaluate(mom0)) + self.assertAllCloseAccordingToType(mom1_np, self.evaluate(mom1)) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 897db384b7e8067b0460b5f344201f101a4d8479..17639bd8a755b9e9f5acc77979ac7a4149f112db 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -71,7 +71,7 @@ def handle_options(func, x, axis, exclusive, reverse): class CumsumTest(xla_test.XLATestCase): - valid_dtypes = [np.float32] + valid_dtypes = [np.float32, np.int32] def axis_dtypes(self): return set(self.int_types).intersection([np.int32, np.int64]) @@ -149,7 +149,7 @@ class CumsumTest(xla_test.XLATestCase): class CumprodTest(xla_test.XLATestCase): - valid_dtypes = [np.float32] + valid_dtypes = [np.float32, np.int32] def axis_dtypes(self): return set(self.int_types).intersection([np.int32, np.int64]) diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index dbf4beb693ec1766e6b7b5daaed4be4e1d874fba..3e499c2fb176a6d63fe3590e18a4a90e461e096a 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -48,13 +48,32 @@ class XlaSortOpTest(xla_test.XLATestCase): self.assertAllClose(v, result, rtol=1e-3) def testSort(self): - supported_types = 
set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): x = np.arange(101, dtype=dtype) np.random.shuffle(x) self._assertOpOutputMatchesExpected( xla.sort, [x], expected=[np.arange(101, dtype=dtype)]) + def testKeyValueSort(self): + supported_key_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + supported_value_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32, + dtypes.int64.as_numpy_dtype, dtypes.uint64.as_numpy_dtype]) + for key_type in supported_key_types.intersection(self.numeric_types): + for value_type in supported_value_types.intersection(self.numeric_types): + x = np.arange(101, dtype=key_type) + np.random.shuffle(x) + y = (-x).astype(value_type) + self._assertOpOutputMatchesExpected( + xla.key_value_sort, [x, y], + expected=[ + np.arange(101, dtype=key_type), + -np.arange(101, dtype=value_type) + ]) + def testTopK(self): supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index e8741bc468585ff9fb049dcd87700f8048d74026..ee7ca7e6f196e114ff18e2597145e5c198980b08 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -23,9 +23,9 @@ import math import numpy as np from tensorflow.compiler.tests import xla_test -from tensorflow.contrib import stateless from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import stateless_random_ops as stateless from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import test @@ -33,8 +33,11 @@ from tensorflow.python.platform import test class 
StatelessRandomOpsTest(xla_test.XLATestCase): """Test cases for stateless random-number generator operators.""" - def _random_types(self): - return self.float_types & {dtypes.float32, dtypes.float64} + def _random_types(self, include_int=False): + allowed_types = {dtypes.float32, dtypes.float64, dtypes.bfloat16} + if include_int: + allowed_types.update({dtypes.int32, dtypes.int64}) + return self.all_tf_types & allowed_types def testDeterminism(self): # Stateless values should be equal iff the seeds are equal (roughly) @@ -46,6 +49,11 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): ]: for shape in (), (3,), (2, 5): for dtype in self._random_types(): + # Skip bfloat16. The result of bfloat16 is truncated from 32-bit + # result. With different seeds, the 32-bit results are different, + # but the truncated 16-bit results might be the same. + if dtype == dtypes.bfloat16: + continue pure = stateless_op(shape, seed=seed_t, dtype=dtype) values = [(seed, pure.eval(feed_dict={ seed_t: seed @@ -56,13 +64,16 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): def testRandomUniformIsInRange(self): with self.cached_session() as sess, self.test_scope(): - for dtype in self._random_types(): + for dtype in self._random_types(include_int=True): + maxval = 1 + if dtype.is_integer: + maxval = 100 seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) x = stateless.stateless_random_uniform( - shape=[1000], seed=seed_t, dtype=dtype) + shape=[1000], seed=seed_t, maxval=maxval, dtype=dtype) y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) self.assertTrue(np.all(y >= 0)) - self.assertTrue(np.all(y < 1)) + self.assertTrue(np.all(y < maxval)) def _chi_squared(self, x, bins): """Pearson's Chi-squared test.""" @@ -75,12 +86,18 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): def testDistributionOfStatelessRandomUniform(self): """Use Pearson's Chi-squared test to test for uniformity.""" with self.cached_session() as sess, self.test_scope(): - for dtype in 
self._random_types(): + for dtype in self._random_types(include_int=True): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 1000 + maxval = 1 + if dtype.is_integer: + maxval = 100 x = stateless.stateless_random_uniform( - shape=[n], seed=seed_t, dtype=dtype) + shape=[n], seed=seed_t, maxval=maxval, dtype=dtype) y = sess.run(x, {seed_t: [565656, 121212]}) + if maxval > 1: + # Normalize y to range [0, 1). + y = y.astype(float) / maxval # Tests that the values are distributed amongst 10 bins with equal # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with # p=0.05. This test is probabilistic and would be flaky if the random @@ -121,7 +138,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): # The constant 2.492 is the 5% critical value for the Anderson-Darling # test where the mean and variance are known. This test is probabilistic # so to avoid flakiness the seed is fixed. - self.assertTrue(self._anderson_darling(y) < 2.492) + self.assertTrue(self._anderson_darling(y.astype(float)) < 2.492) def testTruncatedNormalIsInRange(self): for dtype in self._random_types(): @@ -139,7 +156,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) def probit(x, sess=sess): - return sess.run(special_math.ndtri(x)) + return self.evaluate(special_math.ndtri(x)) a = -2. b = 2. @@ -157,6 +174,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): # Burkardt, John. "The Truncated Normal Distribution". # Department of Scientific Computing website. Florida State University. 
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma + y = y.astype(float) actual_mean = np.mean(y) self.assertAllClose(actual_mean, expected_mean, atol=5e-4) @@ -169,8 +187,8 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) actual_variance = np.var(y) - self.assertAllClose(actual_variance, expected_variance, rtol=1e-3) - + self.assertAllClose(actual_variance, expected_variance, + rtol=5e-3 if dtype == dtypes.bfloat16 else 1e-3) if __name__ == '__main__': diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 78244d0b366d9128a4c59f786e4c5ac12e743b75..d7e26d79c4c054860ade5c8960a3bca984e020b0 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -79,7 +79,8 @@ class TensorArrayTest(xla_test.XLATestCase): c0 = w2.stack() self.assertAllEqual( - convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), c0.eval()) + convert([[[4.0, 5.0]], [[6.0, 7.0]], [[8.0, 9.0]]]), + self.evaluate(c0)) def testTensorArrayWritePack(self): for dtype in self.numeric_tf_types: @@ -97,7 +98,7 @@ class TensorArrayTest(xla_test.XLATestCase): c0 = w2.stack() - self.assertAllEqual([3, 0, 1], c0.eval().shape) + self.assertAllEqual([3, 0, 1], self.evaluate(c0).shape) def _testTensorArrayWriteConcat(self, tf_dtype): with self.cached_session(), self.test_scope(): @@ -113,8 +114,8 @@ class TensorArrayTest(xla_test.XLATestCase): c0 = w2.concat() self.assertAllEqual( - convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0], - [106.0, 107.0], [8.0, 9.0], [204.0, 205.0]]), c0.eval()) + convert([[4.0, 5.0], [104.0, 105.0], [6.0, 7.0], [106.0, 107.0], + [8.0, 9.0], [204.0, 205.0]]), self.evaluate(c0)) def testTensorArrayWriteConcat(self): for dtype in self.numeric_tf_types: @@ -341,7 +342,7 @@ class TensorArrayTest(xla_test.XLATestCase): r0_bad = 
gen_data_flow_ops.tensor_array_read_v3( handle=w0.handle, index=0, dtype=dtype2, flow_in=w0.flow) with self.assertRaisesOpError("TensorArray dtype is "): - r0_bad.eval() + self.evaluate(r0_bad) # Test reading from a different index than the one we wrote to w0.read(1) @@ -422,7 +423,7 @@ class TensorArrayTest(xla_test.XLATestCase): w2 = h2.write(0, 5.0) r2 = w2.read(0) r = r1 + r2 - self.assertAllClose(9.0, r.eval()) + self.assertAllClose(9.0, self.evaluate(r)) def _testTensorArrayGradientWriteReadType(self, dtype): with self.cached_session() as session, self.test_scope(): @@ -504,7 +505,7 @@ class TensorArrayTest(xla_test.XLATestCase): [-0.5, 1.5], # read(0) gradient [20.0, 30.0, 40.0, 50.0], # concat gradient ]) - grad_vals = sess.run(grad_r) # 2 + 2 entries + grad_vals = self.evaluate(grad_r) # 2 + 2 entries self.assertAllClose([2.0 - 0.5 + 20.0, 3.0 + 1.5 + 30.0], grad_vals[0]) self.assertAllEqual([4.0 + 40.0, 5.0 + 50.0], grad_vals[1]) @@ -526,7 +527,7 @@ class TensorArrayTest(xla_test.XLATestCase): with ops.control_dependencies([r0_readtwice]): r1_readtwice = w_readtwice.read(0) - self.assertAllEqual([1.0, -1.0], r1_readtwice.eval()) + self.assertAllEqual([1.0, -1.0], self.evaluate(r1_readtwice)) def _testTensorArrayGradientUnpackRead(self): with self.cached_session() as session, self.test_scope(): @@ -592,7 +593,7 @@ class TensorArrayTest(xla_test.XLATestCase): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) s = ta.size() - self.assertAllEqual(3, s.eval()) + self.assertAllEqual(3, self.evaluate(s)) def testWriteCloseTensorArray(self): with self.cached_session(), self.test_scope(): @@ -722,7 +723,7 @@ class TensorArrayTest(xla_test.XLATestCase): # r = acc2.stack() # grad = gradients_impl.gradients(r, [x])[0] - # self.assertAllClose(31.0, grad.eval()) + # self.assertAllClose(31.0, self.evaluate(grad)) def testSumOfTwoReadVariablesWithoutRepeatGrad(self): with self.cached_session() as session, self.test_scope(): @@ 
-912,7 +913,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertEqual(0, ta.size().eval()) ta = ta.unstack(array_ops.zeros([0, 3, 5])) packed = ta.stack() - self.assertAllEqual([0, 3, 5], packed.eval().shape) + self.assertAllEqual([0, 3, 5], self.evaluate(packed).shape) # Concatenating zero tensors along their first dimension gives a # first dimension of zero self.assertAllEqual([0, 5], ta.concat().eval().shape) @@ -920,6 +921,34 @@ class TensorArrayTest(xla_test.XLATestCase): def testTensorArrayEvalEmptyWithDefault(self): self._testTensorArrayEvalEmptyWithDefault() + def _testTensorArrayScatterRead(self, tf_dtype): + with self.cached_session() as session, self.test_scope(): + convert = _make_converter(tf_dtype) + + ta = tensor_array_ops.TensorArray( + dtype=tf_dtype, + tensor_array_name="foo", + size=10) + + indices = constant_op.constant([1, 8]) + value = constant_op.constant(convert([[1.0, -1.0], [10.0, -10.0]])) + id0 = array_ops.placeholder(dtypes.int32) + id1 = array_ops.placeholder(dtypes.int32) + + w = ta.scatter(indices, value) + r0 = w.read(id0) + r1 = w.read(id1) + + # Test aggregation of read + read_vals = session.run([r0, r1], feed_dict={id0: 1, id1: 8}) + self.assertAllEqual(convert([1.0, -1.0]), read_vals[0]) + self.assertAllEqual(convert([10.0, -10.0]), read_vals[1]) + + def testTensorArrayScatterRead(self): + for dtype in self.numeric_tf_types: + self._testTensorArrayScatterRead(dtype) + self._testTensorArrayScatterRead(dtypes.bool) + def testTensorArrayScatterReadAndGradients(self): with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( @@ -929,15 +958,18 @@ class TensorArrayTest(xla_test.XLATestCase): indices = constant_op.constant([1, 8]) value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) + id0 = array_ops.placeholder(dtypes.int32) + id1 = array_ops.placeholder(dtypes.int32) w = ta.scatter(indices, value) - r0 = w.read(1) - r1 = w.read(8) + r0 = w.read(id0) + r1 = w.read(id1) # Test combined 
gradients + aggregation of read(0). grad = gradients_impl.gradients( ys=[r0, r1], xs=[value], grad_ys=[[2.0, 3.0], [4.0, 5.0]]) - read_vals, grad_vals = session.run([[r0, r1], grad]) + read_vals, grad_vals = session.run([[r0, r1], grad], + feed_dict={id0: 1, id1: 8}) self.assertEqual(len(read_vals), 2) self.assertEqual(len(grad_vals), 1) @@ -1010,8 +1042,8 @@ class TensorArrayTest(xla_test.XLATestCase): (read0, read1, size0, size1)) # Tests that the control dependencies was added and executed. - self.assertEqual(1, v0.eval()) - self.assertEqual(1, v1.eval()) + self.assertEqual(1, self.evaluate(v0)) + self.assertEqual(1, self.evaluate(v1)) # Tests correct TensorArray. self.assertEqual(read0_v, 0) diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index b556723eec77246c87cf88a48c17a307c35fd857..5c079d595c440cac644f5461154509abe7b1d1ed 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -20,22 +20,13 @@ from __future__ import division from __future__ import print_function import numpy as np from tensorflow.compiler.tests import xla_test -from tensorflow.python.client import session -from tensorflow.python.eager import backprop -from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import list_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import test -from tensorflow.python.training import server_lib def scalar_shape(): diff --git 
a/tensorflow/compiler/tests/test_utils.py b/tensorflow/compiler/tests/test_utils.py index 6abde18ea91f16d153a154b94effab037a911c6c..0e77dbf1a79d3dbacb77bab8b8e3df9bcc6287e1 100644 --- a/tensorflow/compiler/tests/test_utils.py +++ b/tensorflow/compiler/tests/test_utils.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin def ConvertBetweenDataFormats(x, data_format_src, data_format_dst): @@ -61,3 +62,14 @@ def PermuteDimsBetweenDataFormats(dims, data_format_src, data_format_dst): dim_map = {d: i for i, d in enumerate(data_format_src)} permuted_dims = [dims[dim_map[d]] for d in data_format_dst] return permuted_dims + + +_JIT_WARMUP_ITERATIONS = 10 + + +def RunWithWarmup(sess, op_to_run, feed_dict, options=None, run_metadata=None): + """Runs a graph a few times to ensure that its clusters are compiled.""" + for _ in xrange(0, _JIT_WARMUP_ITERATIONS): + sess.run(op_to_run, feed_dict, options=options) + return sess.run( + op_to_run, feed_dict, options=options, run_metadata=run_metadata) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 77f6eee0cf8ddc9b76f150e1038bf66da34c5218..95c9e7ffd4651642781143c2c1940b0e51e1e470 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -358,6 +358,11 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([[-0.05, 6.05, 5]], dtype=dtype), expected=np.array([[0, 6, 5]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + nn_ops.leaky_relu, + np.array([[-2, -1, 0, 1, 2]], dtype=dtype), + expected=np.array([[-0.4, -0.2, 0.0, 1.0, 2.0]], dtype=dtype)) + self._assertOpOutputMatchesExpected( nn_ops.softmax, np.array([1, 2, 3, 4], dtype=dtype), @@ -476,6 +481,72 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([-1, -0.5, 0, 0.3], dtype=dtype), expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) + 
def quantize_and_dequantize_v2_round_half_up(x): + return array_ops.quantize_and_dequantize_v2( + x, + -1, + 1.0, + signed_input=True, + num_bits=8, + range_given=True, + round_mode="HALF_UP") + + self._assertOpOutputMatchesExpected( + quantize_and_dequantize_v2_round_half_up, + np.array([-0.8, -0.5, 0, 0.3, 0.8, -2, 33], dtype=dtype), + expected=np.array([ + -102.0 / 127, + -63.0 / 127, + 0, + 38.0 / 127, + 102.0 / 127, + -128.0 / 127, + 1, + ], + dtype=dtype)) + + def quantize_and_dequantize_v2_round_half_to_even(x): + return array_ops.quantize_and_dequantize_v2( + x, + -1.0, + 1.0, + signed_input=True, + num_bits=8, + range_given=True, + round_mode="HALF_TO_EVEN") + + self._assertOpOutputMatchesExpected( + quantize_and_dequantize_v2_round_half_to_even, + np.array( + [ + -0.8, + # The -0.5 should become -63.5 after scaling and with + # rounding this should become -64. But with the test + # unary_ops_test_cpu_ondemand, this fails as the result + # before scaling becomes -63.499996 and gets rounded to -63. + # TODO(sreenik): Some one more familiar with this test needs + # to take a look and resolve this. This works on all other + # variations of the platform like cpu, and gpu. 
+ # -0.5, + 0, + 0.3, + 0.8, + -2, + 33 + ], + dtype=dtype), + expected=np.array( + [ + -102.0 / 127, + # -64.0 / 127, + 0, + 38.0 / 127, + 102.0 / 127, + -128.0 / 127, + 1, + ], + dtype=dtype)) + def quantize_and_dequantize_v3(x): return array_ops.quantize_and_dequantize_v3( x, -127, 127, num_bits=8, signed_input=True, range_given=False) @@ -724,6 +795,15 @@ class UnaryOpsTest(xla_test.XLATestCase): lambda x: array_ops.bitcast(x, dtypes.int32), np.array([1e-45, 1.0], np.float32), expected=np.array([1, 0x3f800000], np.int32)) + if np.int64 in self.numeric_types: + self._assertOpOutputMatchesExpected( + lambda x: array_ops.bitcast(x, dtypes.int64), + np.array([1, 0x100000003f800000], np.uint64), + expected=np.array([1, 0x100000003f800000], np.int64)) + self._assertOpOutputMatchesExpected( + lambda x: array_ops.bitcast(x, dtypes.uint64), + np.array([1, 0x100000003f800000], np.int64), + expected=np.array([1, 0x100000003f800000], np.uint64)) def testInvertPermutation(self): self._assertOpOutputMatchesExpected( diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index dd2c252d383bca9c59033ac07e442b487e4975a6..fcd7ac5ba1ca5049246e93e6f5f76746fb28c6b8 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -40,6 +40,19 @@ from tensorflow.python.training.gradient_descent import GradientDescentOptimizer class VariableOpsTest(xla_test.XLATestCase): """Test cases for resource variable operators.""" + def testWriteEmptyShape(self): + # Verifies that we can pass an uninitialized variable with an empty shape, + # assign it a value, and successfully return it. 
+ for dtype in self.numeric_types: + with self.test_session() as sess, self.test_scope(): + zeros = np.zeros([3, 0], dtype=dtype) + v = resource_variable_ops.ResourceVariable(zeros) + p = array_ops.placeholder(dtype) + x = v.assign(p) + with ops.control_dependencies([x]): + y = v.read_value() + self.assertAllClose(zeros, sess.run(y, {p: zeros})) + def testOneWriteOneOutput(self): # Regression test for a bug where computations with one non-constant # output and one variable update were mishandled. @@ -64,7 +77,7 @@ class VariableOpsTest(xla_test.XLATestCase): sess.run(variables.variables_initializer([v])) x = v.sparse_read(2) self.assertAllClose( - np.array([8j, 9, 10, 11]).astype(dtype), sess.run(x)) + np.array([8j, 9, 10, 11]).astype(dtype), self.evaluate(x)) def testSparseRead1DIndices(self): for dtype in self.numeric_types: @@ -76,7 +89,7 @@ class VariableOpsTest(xla_test.XLATestCase): x = v.sparse_read([2, 1]) self.assertAllClose( np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype), - sess.run(x)) + self.evaluate(x)) def testSparseRead2DIndices(self): for dtype in self.numeric_types: @@ -89,7 +102,7 @@ class VariableOpsTest(xla_test.XLATestCase): self.assertAllClose( np.array([[[8, 9, 10, 11], [4, 5, 6, 7]], [[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype), - sess.run(x)) + self.evaluate(x)) def testSparseRead2DIndices3DTensor(self): for dtype in self.numeric_types: @@ -102,9 +115,9 @@ class VariableOpsTest(xla_test.XLATestCase): x = v.sparse_read([[2, 1], [3, 0]]) self.assertAllClose( np.array( - [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]] - ], [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]] - ],).astype(dtype), sess.run(x)) + [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]], + [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]] + ],).astype(dtype), self.evaluate(x)) def testShape(self): for dtype in self.numeric_types: @@ -216,7 +229,7 @@ class VariableOpsTest(xla_test.XLATestCase): 
resource_variable_ops.resource_scatter_add( handle, [0], constant_op.constant([[2]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertAllEqual(sess.run(read), [[3], [7]]) + self.assertAllEqual(self.evaluate(read), [[3], [7]]) def testScatterSub(self): with self.test_session() as sess, self.test_scope(): @@ -229,7 +242,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_sub( handle, [1], constant_op.constant([[2]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertAllEqual(sess.run(read), [[4], [-1]]) + self.assertAllEqual(self.evaluate(read), [[4], [-1]]) def testScatterMul(self): with self.test_session() as sess, self.test_scope(): @@ -242,7 +255,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_mul( handle, [0], constant_op.constant([[5]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[5]]) + self.assertEqual(self.evaluate(read), [[5]]) def testScatterDiv(self): with self.test_session() as sess, self.test_scope(): @@ -255,7 +268,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_div( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertAllEqual(sess.run(read), [[2]]) + self.assertAllEqual(self.evaluate(read), [[2]]) def testScatterMin(self): with self.test_session() as sess, self.test_scope(): @@ -268,7 +281,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_min( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterMax(self): with 
self.test_session() as sess, self.test_scope(): @@ -281,7 +294,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_max( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[6]]) + self.assertEqual(self.evaluate(read), [[6]]) def testScatterUpdate(self): with self.test_session() as sess, self.test_scope(): @@ -294,7 +307,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_update( handle, [0], constant_op.constant([[3]], dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterAddScalar(self): with self.test_session() as sess, self.test_scope(): @@ -307,7 +320,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_add( handle, [0], constant_op.constant(2, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterSubScalar(self): with self.test_session() as sess, self.test_scope(): @@ -320,7 +333,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_sub( handle, [0], constant_op.constant(2, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[-1]]) + self.assertEqual(self.evaluate(read), [[-1]]) def testScatterMulScalar(self): with self.test_session() as sess, self.test_scope(): @@ -333,7 +346,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_mul( handle, [0], constant_op.constant(5, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), 
[[5]]) + self.assertEqual(self.evaluate(read), [[5]]) def testScatterDivScalar(self): with self.test_session() as sess, self.test_scope(): @@ -346,7 +359,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_div( handle, [0], constant_op.constant(3, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[2]]) + self.assertEqual(self.evaluate(read), [[2]]) def testScatterMinScalar(self): with self.test_session() as sess, self.test_scope(): @@ -359,7 +372,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_min( handle, [0], constant_op.constant(3, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[3]]) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterMaxScalar(self): with self.test_session() as sess, self.test_scope(): @@ -372,7 +385,7 @@ class VariableOpsTest(xla_test.XLATestCase): resource_variable_ops.resource_scatter_max( handle, [0], constant_op.constant(3, dtype=dtypes.int32))) read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(sess.run(read), [[6]]) + self.assertEqual(self.evaluate(read), [[6]]) def testScatterNdAddOps(self): with self.test_session() as sess, self.test_scope(): @@ -387,7 +400,7 @@ class VariableOpsTest(xla_test.XLATestCase): sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates)) read = resource_variable_ops.read_variable_op( handle, dtype=dtypes.float32) - self.assertAllClose(expected, sess.run(read)) + self.assertAllClose(expected, self.evaluate(read)) def testScatterNdUpdateAddOps(self): with self.test_session() as sess, self.test_scope(): @@ -403,7 +416,7 @@ class VariableOpsTest(xla_test.XLATestCase): gen_state_ops.resource_scatter_nd_update(handle, indices, updates)) read = resource_variable_ops.read_variable_op( handle, 
dtype=dtypes.float32) - self.assertAllClose(expected, sess.run(read)) + self.assertAllClose(expected, self.evaluate(read)) class StridedSliceAssignChecker(object): diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index 28d61fb07dcb665fa0dbe3f3e566e291e24fa662..ef55292b1be91a731ec556d7efa9cdf1a696e5cc 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -81,7 +81,7 @@ class XlaDeviceTest(xla_test.XLATestCase): with self.cached_session() as sess: with self.test_scope(): x = gen_control_flow_ops.control_trigger() - sess.run(x) + self.evaluate(x) if __name__ == "__main__": diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 3f631f91ec442c149b3ea4df3826d98b0419a76f..5a0d9b9af9d55a8dee809d3cf909bce39c3b8b6c 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -9,6 +9,7 @@ package_group( "//tensorflow/compiler/jit/...", "//tensorflow/compiler/tests/...", "//tensorflow/compiler/tf2xla/...", + "//tensorflow/contrib/compiler/...", ], ) @@ -166,6 +167,7 @@ cc_library( "xla_compilation_device.cc", "xla_compiler.cc", "xla_context.cc", + "xla_expression.cc", "xla_helpers.cc", "xla_op_kernel.cc", "xla_op_registry.cc", @@ -180,6 +182,7 @@ cc_library( "xla_compilation_device.h", "xla_compiler.h", "xla_context.h", + "xla_expression.h", "xla_helpers.h", "xla_op_kernel.h", "xla_op_registry.h", @@ -193,6 +196,8 @@ cc_library( ":sharding_util", ":side_effect_util", ":tf2xla_util", + "//tensorflow/compiler/jit:flags", + "//tensorflow/compiler/jit:xla_cluster_util", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -200,13 +205,13 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client", 
"//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -214,7 +219,10 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], alwayslink = 1, @@ -308,6 +316,7 @@ tf_cc_test( "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:math_ops_op_lib", "//tensorflow/core:protos_all_cc", @@ -359,8 +368,12 @@ tf_cc_test( tf_cc_test( name = "xla_compiler_test", - srcs = ["xla_compiler_test.cc"], + srcs = [ + "xla_compiler_test.cc", + "xla_expression_test.cc", + ], deps = [ + ":common", ":side_effect_util", ":xla_compiler", "//tensorflow/cc:cc_ops", @@ -383,6 +396,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) @@ -412,6 +426,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/compiler/jit:xla_cluster_util", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:ops", @@ -424,21 +439,15 @@ cc_library( name = "dump_graph", srcs = [ "dump_graph.cc", - "dump_graph_flags.cc", - "dump_graph_flags.h", ], hdrs = [ "dump_graph.h", ], deps = [ - 
"//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", + "//tensorflow/compiler/jit:flags", "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", + "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", ], ) @@ -661,5 +670,6 @@ cc_library( hdrs = ["side_effect_util.h"], deps = [ "//tensorflow/core:core_cpu", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 027ca6d2d2f616177d91d9d57d1ff373bab2a754..a57095f91e43f6b31b58e5a5f36331241451b545 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" @@ -67,25 +68,18 @@ Status BackwardsConstAnalysis(const Graph& g, } // Mark any compile-time constant operator arguments as const. 
- const std::unordered_set* const_inputs = - XlaOpRegistry::CompileTimeConstantInputs(node->type_string()); - if (!const_inputs || const_inputs->empty()) return; + std::vector const_input_idxs; + status = XlaOpRegistry::CompileTimeConstantInputs( + node->def(), node->op_def(), &const_input_idxs); - NameRangeMap input_name_ranges; - status = - NameRangesForNode(*node, node->op_def(), &input_name_ranges, nullptr); - if (!status.ok()) return; - - for (const string& input : *const_inputs) { - auto name_range = input_name_ranges.find(input); - if (name_range == input_name_ranges.end()) continue; + if (!status.ok()) { + return; + } - for (Edge const* edge : node->in_edges()) { - if (edge->dst_input() >= name_range->second.first && - edge->dst_input() < name_range->second.second && - edge_filter(*edge)) { - (*compile_time_const_nodes)[edge->src()->id()] = true; - } + for (Edge const* edge : node->in_edges()) { + if (absl::c_binary_search(const_input_idxs, edge->dst_input()) && + edge_filter(*edge)) { + (*compile_time_const_nodes)[edge->src()->id()] = true; } } }; diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc index 56065be894697bc72ecc0089c665c19aafee7bf8..40c6d0e01701d9104a200d9ea27706a0a7c12146 100644 --- a/tensorflow/compiler/tf2xla/const_analysis_test.cc +++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -107,5 +108,54 @@ TEST(ConstAnalysisTest, DontFollowControlDependencies) { EXPECT_EQ(const_args, std::vector({false, true})); } +TEST(ConstAnalysisTest, RespectExplicitAttr_0) { + Scope root = Scope::NewRootScope(); + + Output arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0); + Output arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1); + Output c1 = + ops::Const(root.WithOpName("c1").WithControlDependencies(arg0), 1, {1}); + Output add = ops::Add(root, arg1, c1); + + // Force const analysis to pretend that the shape argument to `reshape` does + // not need to be a constant. + Output reshape = ops::Reshape(root, arg1, add); + reshape.node()->AddAttr(kXlaCompileTimeConstantInputsAttr, + std::vector()); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + + std::vector const_args(2, false); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, + /*compile_time_const_nodes=*/nullptr)); + + EXPECT_EQ(const_args, std::vector({false, false})); +} + +TEST(ConstAnalysisTest, RespectExplicitAttr_1) { + Scope root = Scope::NewRootScope(); + + Output arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0); + Output c1 = + ops::Const(root.WithOpName("c1").WithControlDependencies(arg0), 1, {1}); + Output add = ops::Add(root, arg0, c1); + + // Force const analysis to pretend that the first argument to `add` needs to + // be a constant. 
+ std::vector add_constant_inputs; + add_constant_inputs.push_back("x"); + add.node()->AddAttr(kXlaCompileTimeConstantInputsAttr, add_constant_inputs); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + + std::vector const_args(1, false); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, + /*compile_time_const_nodes=*/nullptr)); + + EXPECT_EQ(const_args, std::vector({true})); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 380c6a7e23da92d949b26876836b999bf6406c6c..64fdbbebc65bff4ed0b965fcdd534cc9696472b6 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -18,87 +18,26 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/tf2xla/dump_graph_flags.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/core/util/dump_graph.h" namespace tensorflow { namespace dump_graph { -namespace { - -struct NameCounts { - mutex counts_mutex; - std::unordered_map counts; -}; - -string MakeUniqueFilename(string name) { - static NameCounts& instance = *new NameCounts; - - // Remove illegal characters from `name`. 
- for (int i = 0; i < name.size(); ++i) { - char ch = name[i]; - if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?') { - name[i] = '_'; - } - } - - int count; - { - mutex_lock lock(instance.counts_mutex); - count = instance.counts[name]++; - } - - string filename = name; - if (count > 0) { - absl::StrAppend(&filename, "_", count); - } - absl::StrAppend(&filename, ".pbtxt"); - return filename; -} - -string WriteTextProtoToUniqueFile( - Env* env, const string& name, const char* proto_type, - const ::tensorflow::protobuf::Message& proto) { - const string& dirname = - legacy_flags::GetDumpGraphFlags()->tf_dump_graph_prefix; - Status status = env->RecursivelyCreateDir(dirname); - if (!status.ok()) { - LOG(WARNING) << "Failed to create " << dirname << " for dumping " - << proto_type << ": " << status; - return "(unavailable)"; - } - string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name)); - status = WriteTextProto(Env::Default(), filepath, proto); - if (!status.ok()) { - LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath - << " : " << status; - return "(unavailable)"; - } - LOG(INFO) << "Dumped " << proto_type << " to " << filepath; - return filepath; -} - -} // anonymous namespace - string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { - return WriteTextProtoToUniqueFile(Env::Default(), name, "GraphDef", - graph_def); + return tensorflow::DumpGraphDefToFile( + name, graph_def, GetDumpGraphFlags()->tf_dump_graph_prefix); } string DumpGraphToFile(const string& name, Graph const& graph, const FunctionLibraryDefinition* flib_def) { - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - if (flib_def) { - *graph_def.mutable_library() = flib_def->ToProto(); - } - return DumpGraphDefToFile(name, graph_def); + return tensorflow::DumpGraphToFile(name, graph, flib_def, + GetDumpGraphFlags()->tf_dump_graph_prefix); } string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { - return 
WriteTextProtoToUniqueFile(Env::Default(), name, "FunctionDef", fdef); + return tensorflow::DumpFunctionDefToFile( + name, fdef, GetDumpGraphFlags()->tf_dump_graph_prefix); } } // namespace dump_graph diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.cc b/tensorflow/compiler/tf2xla/dump_graph_flags.cc deleted file mode 100644 index a6c908ba011afb90fabacc855df8c6afbb35d254..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/dump_graph_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's dump_graph module. - -#include -#include - -#include "tensorflow/compiler/tf2xla/dump_graph_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static DumpGraphFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). 
-static void AllocateFlags() { - flags = new DumpGraphFlags; - flags->tf_dump_graph_prefix = "/tmp/"; - flag_list = new std::vector({ - Flag("tf_dump_graph_prefix", &flags->tf_dump_graph_prefix, - "Path prefix to which graphs dumped during debugging should be " - "written."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// dump_graph module. -void AppendDumpGraphFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the DumpGraphFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -DumpGraphFlags* GetDumpGraphFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.h b/tensorflow/compiler/tf2xla/dump_graph_flags.h deleted file mode 100644 index 80a3307d920f2cc3d668d507786a02e43589f86f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/dump_graph_flags.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ -#define TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ - -// Legacy flags for the XLA bridge's dump_graph module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// dump_graph module. -void AppendDumpGraphFlags(std::vector* flag_list); - -// The values of flags associated with the XLA bridge's -// dump_graph module. -typedef struct { - string tf_dump_graph_prefix; // Path prefix to which graphs dumped during - // debugging should be written. -} DumpGraphFlags; - -// Return a pointer to the DumpGraphFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -DumpGraphFlags* GetDumpGraphFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index db256e577a1f3dd38e04d102f60182023b9d43b2..c693e42d26712d55852f45c806215fc1f1b9a030 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -339,6 +339,7 @@ Status Conditional::AddSwitch(Node* s) { DebugString(switch_predicate_), " vs ", DebugString(predicate), ")."); } switches_.insert(s); + parent_->AddSwitchId(s->id()); return Status::OK(); } @@ -695,6 +696,12 @@ Status Conditional::BuildIfNode(Graph* graph, VLOG(3) << "Build output type: " << DataTypeVectorString(out_type); builder.Attr("Tcond", DT_BOOL); + string outside_compilation; + if (GetNodeAttr(predicate_.node->def(), kXlaOutsideCompilationAttrName, + &outside_compilation) + .ok()) { + 
builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation); + } builder.Device(predicate_.node->assigned_device_name()); // Conditional should be the first input ... builder.Input(NodeDefBuilder::NodeOut(predicate_.node->name(), @@ -1179,7 +1186,7 @@ Status FunctionalizeCond::DetermineAncestorState(Node* dst) { } void FunctionalizeCond::DeleteReachableAndDeadNodes( - const std::vector& switch_ids, const std::vector& merge_order) { + const std::vector& merge_order) { // Delete all nodes that have been extracted or are reachable from // deleted/dead nodes. The input and outgoing edges should have already been // removed. @@ -1191,7 +1198,7 @@ void FunctionalizeCond::DeleteReachableAndDeadNodes( // All remaining Switch nodes are not reachable from a Merge node and // removed. This is to account for dead Switch nodes. - for (int s_id : switch_ids) { + for (int s_id : switch_ids_) { Node* s = graph_->FindNodeId(s_id); if (s == nullptr) continue; for (const Edge* e : s->out_edges()) { @@ -1282,11 +1289,10 @@ Status FunctionalizeCond::FunctionalizeInternal() { // reverse topological sorting); // * Record reverse topological for merge and switch nodes; std::vector rev_topo_order; - std::vector switch_ids; std::vector merge_order; DFS(*graph_, nullptr, [&](Node* n) { if (IsSwitch(n)) { - switch_ids.push_back(n->id()); + AddSwitchId(n->id()); } if (IsMerge(n)) { merge_order.push_back(n); @@ -1300,9 +1306,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { if (merge_order.empty()) { // No merges mean no switch values consumed (as only considering values // fetchable as output of merge); - for (auto it = switch_ids.begin(); it != switch_ids.end(); ++it) { - graph_->RemoveNode(graph_->FindNodeId(*it)); - } + DeleteReachableAndDeadNodes(merge_order); return Status::OK(); } @@ -1345,7 +1349,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract"); } - DeleteReachableAndDeadNodes(switch_ids, merge_order); + 
DeleteReachableAndDeadNodes(merge_order); return Status::OK(); } @@ -1365,6 +1369,10 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) { library_); } +void FunctionalizeCond::AddSwitchId(int switch_id) { + switch_ids_.push_back(switch_id); +} + Status FunctionalizeCond::Functionalize(Graph* graph, FunctionLibraryDefinition* library) { VLOG(1) << "FunctionalizeCond::Functionalize"; diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index 189980894073b1da1a12d1c284536336eb920900..8525d7af61b4471e53a9ae16b081060bfd234c9c 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -166,6 +166,9 @@ class FunctionalizeCond { // Dump graph with the CondState annotated. void DumpGraphWithCondState(const string& name); + // Adds `switch_id` to the list of Switch node ids. + void AddSwitchId(int switch_id); + private: FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); @@ -219,8 +222,7 @@ class FunctionalizeCond { // Deletes all nodes in/consumers reachable from switch/merge nodes that were // extracted. 
- void DeleteReachableAndDeadNodes(const std::vector& switch_ids, - const std::vector& merge_order); + void DeleteReachableAndDeadNodes(const std::vector& merge_order); // Member used to unique the CondState to a unique CondId (AncestorState to a // unique AncestorId) and keep track of CondState/CondId @@ -234,6 +236,8 @@ class FunctionalizeCond { Graph* graph_; friend class FunctionalizeCondTest; + + std::vector switch_ids_; }; } // namespace functionalize_cond diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 28e09d7b79a70bba7e05e9eccc26a65cc40324c6..3dfd3f854c8646ebbf06d3378201d22e8741b7eb 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -75,6 +75,25 @@ Status FunctionalizeControlFlow(Graph* graph, return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); } +Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, + FunctionLibraryDefinition* library) { + return FunctionalizeControlFlowForGraphDef(/*lookup_library=*/nullptr, + graph_def, library); +} + +Status FunctionalizeControlFlowForGraphDef( + const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def, + FunctionLibraryDefinition* library) { + FunctionDefLibrary function_lib = graph_def->library(); + Graph graph(OpRegistry::Global()); + + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph)); + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(lookup_library, &graph, library)); + graph.ToGraphDef(graph_def); + std::swap(*graph_def->mutable_library(), function_lib); + return Status::OK(); +} + Status FunctionalizeControlFlowForFunction( const string& func_name, const string& new_func_name, const protobuf::Map& attrs, @@ -94,8 +113,9 @@ Status FunctionalizeControlFlowForFunction( } }); const FunctionBody* body = flr->GetFunctionBody(handle); + Graph* g = body->graph; - // Check if the graph has 
Switch or Merge node before optimizing the graph. + // Check if the graph has Switch or Merge node. bool has_switch_or_merge = false; for (Node* n : body->graph->nodes()) { if (n->type_string() == "Switch" || n->type_string() == "Merge") { @@ -108,59 +128,14 @@ Status FunctionalizeControlFlowForFunction( // in function body. We still need to rewrite those functions and modify // corresponding nodes. - // Call graph optimizer. The most important optimization we need is constant - // folding, which will replace ops like Shape/BroadcastGradientArgs with - // constant shape input. Without this optimization, those ops might become - // dynamic input for then/else body function and XLA will complain that input - // is not compile time constant. We enable function inlining as well, because - // otherwise we won't be able to infer shape for any node depending on - // function call nodes. - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_before_opt_", func_name), - *body->graph, fld); - } - // Optimizer accepts std::unique_ptr* as input and might change - // underlying pointer, thus we create a new Graph and copy from body->graph. - std::unique_ptr optimized_graph(new Graph(fld)); - CopyGraph(*body->graph, optimized_graph.get()); - OptimizerOptions opts; - opts.set_opt_level(OptimizerOptions::L0); - opts.set_do_function_inlining(true); - opts.set_do_constant_folding(true); - GraphOptimizer optimizer(opts); - auto cf_consider_fn = [](const Node* n) { - // Skip SymbolicGradient op when doing constant folding. - // Enabling SymbolicGradient op in constant folding requires - // flr->device() to be non-null, and here we have not constructed - // proper Device object yet (it will be constructed in XlaCompiler). 
- return n->type_string() != FunctionLibraryDefinition::kGradientOp; - }; - optimizer.Optimize(flr, flr->env(), - /*device=*/nullptr, &optimized_graph, - /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr, - cf_consider_fn); - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_opt_", func_name), - *optimized_graph, fld); - } - // Some inlined functions might have Switch/Merge nodes. - for (Node* n : optimized_graph->nodes()) { - if (n->type_string() == "Switch" || n->type_string() == "Merge") { - has_switch_or_merge = true; - break; - } - } - // If any node has associated functions, functionalize them first. // Gather nodes with associated functions first, because rewriting those nodes // might involve node deletion/addition. Avoid modifying nodes while iterating // it. std::vector>> nodes_to_associated_functions; - for (auto* n : optimized_graph->nodes()) { - auto associated_functions = GetAssociatedFunctions(*n, flr); + for (auto* n : g->nodes()) { + auto associated_functions = GetAssociatedFunctions(*n, fld); if (!associated_functions.empty()) { nodes_to_associated_functions.push_back({n, associated_functions}); } @@ -215,7 +190,7 @@ Status FunctionalizeControlFlowForFunction( // pointer. That's fine because in that case, associated_functions will // only have one member and the loop will only run once. 
TF_RETURN_IF_ERROR(RewriteAssociatedFunction( - optimized_graph.get(), n, fld, associated_function, new_name)); + g, n, fld, associated_function, new_name)); } } } @@ -227,21 +202,21 @@ Status FunctionalizeControlFlowForFunction( if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( absl::StrCat("functionalize_control_flow_before_fdef_", func_name), - *optimized_graph, fld); + *g, fld); } - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld)); + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld)); if (VLOG_IS_ON(4)) { dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_fdef_", func_name), - *optimized_graph, fld); + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, + fld); } } if (*modified) { // Add rewritten FunctionDef into library. FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name, - &functionalized_fdef)); + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*g, new_func_name, &functionalized_fdef)); if (func_name == new_func_name) { VLOG(2) << "Replacing function " << func_name; TF_RETURN_IF_ERROR( @@ -270,9 +245,13 @@ Status FunctionalizeControlFlowPass::Run( pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); // Find XLA compile ops and its corresponding FunctionDef. + // TPUCompile op is not in the map because graph rewriting might happen + // multiple times, and we want to avoid functionalize it again. static std::map* kNodeTypeToFunctionAttrMapping = new std::map{ - {"TPUCompile", "function"}, + // TPUReplicate ops are generated by EncapsulateTPUComputationsPass. + {"TPUReplicate", "computation"}, + // XlaLaunch ops are generated by EncapsulateXlaComputationsPass. 
{"XlaLaunch", "function"}, }; std::map> canonicalized_name_to_new_name; @@ -282,23 +261,20 @@ Status FunctionalizeControlFlowPass::Run( continue; } const string func_attr = it->second; - if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) != - kNodeTypeToFunctionAttrMapping->end()) { - NameAttrList func; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); - VLOG(2) << "Graph has node " << n->type_string() - << ". Corresponding function: " << func.name(); - string new_func_name = options.flib_def->UniqueFunctionName( - absl::StrCat(func.name(), "_f15n_")); - bool modified; - TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( - func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name, &modified)); - if (modified) { - n->ClearAttr(func_attr); - func.set_name(new_func_name); - n->AddAttr(func_attr, func); - } + NameAttrList func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); + VLOG(2) << "Graph has node " << n->type_string() + << ". 
Corresponding function: " << func.name(); + string new_func_name = options.flib_def->UniqueFunctionName( + absl::StrCat(func.name(), "_f15n_")); + bool modified; + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + func.name(), new_func_name, func.attr(), options.flib_def, flr, + &canonicalized_name_to_new_name, &modified)); + if (modified) { + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); } } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index ba99205640ccdc83a3a4d50e3ec474907894a835..91d33fa405834d7f1f8f66180583580f4f2e448a 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -33,6 +33,12 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library); +Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, + FunctionLibraryDefinition* library); +Status FunctionalizeControlFlowForGraphDef( + const FunctionLibraryDefinition* lookup_library, GraphDef* graph_def, + FunctionLibraryDefinition* library); + // This pass looks at the graph and all associated FunctionDefs, and turns // traditional control flow structure (Switch/Merge/etc.) into functional // control flow structure (If/While). 
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index c3841f996f801e855da75b23f01d41674ec51c4d..9784985af83a18619d837528f99a60b98a501ec5 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -95,77 +95,87 @@ TEST(FunctionalizeControlFlow, Conditional) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + string op_name; + NameAttrList then_fn; + NameAttrList else_fn; + TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); + InstantiationResultForTest else_result; + TF_EXPECT_OK( + InstantiateFunctionForTest(else_fn.name(), library, &else_result)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto if_op = ops::If(scope.WithOpName(op_name), less, + std::initializer_list{less, y, x}, {DT_INT32}, + then_fn, else_fn); + auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - string op_name; - NameAttrList then_fn; - NameAttrList else_fn; - TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); - InstantiationResultForTest else_result; - TF_EXPECT_OK( - 
InstantiateFunctionForTest(else_fn.name(), library, &else_result)); - - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::If(scope.WithOpName(op_name), less, - std::initializer_list{less, y, x}, {DT_INT32}, - then_fn, else_fn); - auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // then body. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); - auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); - auto cond = ops::Const( - scope.WithOpName("cond").WithControlDependencies(identity), 17); - auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + // then body. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); + auto cond = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity), 17); + auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(then_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), + result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - // else body. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); - auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); - auto cond_1 = ops::Const( - scope.WithOpName("cond_1").WithControlDependencies(identity), 23); - auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // else body. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); + auto cond_1 = ops::Const( + scope.WithOpName("cond_1").WithControlDependencies(identity), 23); + auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(else_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), + result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } @@ -239,75 +249,77 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, 
graph_def); + } - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - - NameAttrList cond_fn, body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); - - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); - auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Condition graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto ten = ops::Const( - scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); - auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Body graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); - auto one = ops::Const( - scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); - auto add = ops::Add(scope.WithOpName("while/add"), identity, one); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + // Condition graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); + auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } -// @function.Defun(noinline=True) -// def increment_fn(x): -// return [x + 1] -// Define the above function, and add it to the given graph. It's used as the -// while loop body in NoinlineLoopBody test. -Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { +FunctionDef GetNoinlineFunctionDef() { FunctionDef fdef = FunctionDefHelper::Create( "increment_fn", {"x:int32"}, {"add:int32"}, {}, { @@ -316,8 +328,17 @@ Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { }, {{"add", "add_0:z:0"}}); (*fdef.mutable_attr())["_noinline"].set_b(true); + return fdef; +} + +// @function.Defun(noinline=True) +// def increment_fn(x): +// return [x + 1] +// Define the above function, and add it to the given graph. It's used as the +// while loop body in NoinlineLoopBody test. 
+Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { FunctionDefLibrary fdef_lib; - *(fdef_lib.add_function()) = fdef; + *(fdef_lib.add_function()) = GetNoinlineFunctionDef(); TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib)); NodeDef increment_fn; increment_fn.set_name(node_name); @@ -376,55 +397,88 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) { FunctionLibraryDefinition lookup_lib(graph.flib_def()); FunctionLibraryDefinition library(OpRegistry::Global(), {}); // Function increment_fn will be copied from lookup_lib to library. - TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library)); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); - GraphDef graph_def; - graph.ToGraphDef(&graph_def); + *(optimized_graph_def.mutable_library()->add_function()) = + GetNoinlineFunctionDef(); - NameAttrList cond_fn, body_fn; - TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef( + &lookup_lib, &optimized_graph_def, &library)); + TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - 
std::initializer_list{source}, cond_fn, body_fn); - GraphDef expected; - TF_ASSERT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + TF_ASSERT_OK( + AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + NodeDef retval; + retval.set_name("_retval0_RetVal"); + retval.set_op(FunctionLibraryDefinition::kRetOp); + *retval.add_input() = noinline_node_name; + (*retval.mutable_attr())["T"].set_type(DT_INT32); + (*retval.mutable_attr())["index"].set_i(0); + Status status; + scope.graph()->AddNode(retval, &status); + TF_ASSERT_OK(status); + + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + // Verify that increment_fn has been copied to library. + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + // Ignore the function library when comparing the graphs. + expected.clear_library(); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } +} - // Body graph. 
+TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) { + const string& noinline_node_name = "while/increment_fn"; + Graph graph(OpRegistry::Global()); { Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), source); TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); - auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); - NodeDef retval; - retval.set_name("_retval0_RetVal"); - retval.set_op(FunctionLibraryDefinition::kRetOp); - *retval.add_input() = noinline_node_name; - (*retval.mutable_attr())["T"].set_type(DT_INT32); - (*retval.mutable_attr())["index"].set_i(0); - Status status; - scope.graph()->AddNode(retval, &status); - TF_ASSERT_OK(status); - - GraphDef expected; - TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_ASSERT_OK(scope.ToGraph(&graph)); + } - InstantiationResultForTest result; - // Verify that increment_fn has been copied to library. - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + FunctionLibraryDefinition lookup_lib(graph.flib_def()); + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + graph_def.clear_library(); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - // Ignore the function library when comparing the graphs. 
- expected.clear_library(); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } + Status status = + FunctionalizeControlFlowForGraphDef(&lookup_lib, &graph_def, &library); + EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code()); } // Tests functionalizing OneLoopVar where the loop value is not used post the @@ -467,65 +521,72 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - - NameAttrList cond_fn, body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); - - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Condition graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto ten 
= ops::Const( - scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); - auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Body graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); - auto one = ops::Const( - scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); - auto add = ops::Add(scope.WithOpName("while/add"), identity, one); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + // Condition graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); + auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - EXPECT_EQ(DataTypeVector{DT_INT32}, 
result.arg_types); - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } @@ -608,86 +669,95 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); + auto while_op = + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{x, y}, cond_fn, body_fn); + auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); + auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - - NameAttrList cond_fn, body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); - - // Outer graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); - auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); - auto while_op = - ops::While(scope.WithOpName("while/LoopCond"), - std::initializer_list{x, y}, cond_fn, body_fn); - auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); - auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Condition graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto three = ops::Const(scope.WithOpName("while/cond/three") + // Condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto three = ops::Const(scope.WithOpName("while/cond/three") + .WithControlDependencies(arg0.output), + 3); + auto cond_add = + ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three); + auto ten = ops::Const(scope.WithOpName("while/cond/ten") .WithControlDependencies(arg0.output), - 3); - auto cond_add = - ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three); - auto ten = ops::Const( - scope.WithOpName("while/cond/ten").WithControlDependencies(arg0.output), - 10); - auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Body graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - - auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), arg0); - auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1); - - auto one = ops::Const( - scope.WithOpName("while/add/one").WithControlDependencies(identity_x), - 1); - auto two = ops::Const( - scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), - 2); + 10); + auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); - auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); - auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(cond_fn.name(), library, &result)); - InstantiationResultForTest result; - TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + + auto identity_x = + ops::Identity(scope.WithOpName("while/Identity/x"), arg0); + auto identity_y = + ops::Identity(scope.WithOpName("while/Identity/y"), arg1); + + auto one = ops::Const( + scope.WithOpName("while/add/one").WithControlDependencies(identity_x), + 1); + auto two = ops::Const( + scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), + 2); + + auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); + auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } @@ -841,177 +911,192 @@ TEST(FunctionalizeControlFlow, Complex) { } FunctionLibraryDefinition library(OpRegistry::Global(), {}); + GraphDef optimized_graph_def; + graph.ToGraphDef(&optimized_graph_def); + TF_ASSERT_OK( + FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - - NameAttrList outer_cond_fn, outer_body_fn; - TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); - - // Outer graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto three = ops::Const(scope.WithOpName("three"), 3); - auto y = ops::Add(scope.WithOpName("y"), x, three); - - auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, - TensorShape({})); - - auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); - - auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), - std::initializer_list{zero, y, x, var}, - outer_cond_fn, outer_body_fn); - auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // Outer condition graph. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - - auto ten = ops::Const( - scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), - 10); - auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Outer body graph. - NameAttrList inner_cond_fn, inner_body_fn; - { - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); - - // Find the inner condition and body names. 
- TF_EXPECT_OK( - FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); - - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - - auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); - auto one_j = ops::Const( - scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); - auto while_op = - ops::While(scope.WithOpName("outer/LoopCond_1"), - std::initializer_list{one_j, arg1, arg2, arg3}, - inner_cond_fn, inner_body_fn); - - auto one_outer = ops::Const( - scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); - auto add_i = - ops::Add(scope.WithOpName("outer/add") - .WithControlDependencies(absl::Span{ - while_op[0].op(), while_op[1].op()}), - identity_i, one_outer); - - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0); - auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1); - auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Inner condition graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - - auto five = ops::Const( - scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5); - auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); - auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; + for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { + NameAttrList outer_cond_fn, outer_body_fn; TF_EXPECT_OK( - InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); - - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // Inner body graph. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); - auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); - auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); - auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); - - auto identity_j = - ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); - auto identity_k = - ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); - - auto mul_jk = - ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); - auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); - auto assign = ops::AssignAddVariableOp( - scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); - - auto one = ops::Const( - scope.WithOpName("outer/inner/One") - .WithControlDependencies( - absl::Span{assign.operation}), - 1); - auto add_j = - ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); + + // Outer graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto three = ops::Const(scope.WithOpName("three"), 3); + auto y = ops::Add(scope.WithOpName("y"), x, three); + + auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, + TensorShape({})); + + auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); + + auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), + std::initializer_list{zero, y, x, var}, + outer_cond_fn, outer_body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } - auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0); - auto retval1 = - ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1); - auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + // Outer condition graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto ten = ops::Const( + scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), + 10); + auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - GraphDef expected; - 
TF_EXPECT_OK(scope.ToGraphDef(&expected)); + // Outer body graph. + NameAttrList inner_cond_fn, inner_body_fn; + { + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); + + // Find the inner condition and body names. + TF_EXPECT_OK( + FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); + auto one_j = ops::Const( + scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); + auto while_op = + ops::While(scope.WithOpName("outer/LoopCond_1"), + std::initializer_list{one_j, arg1, arg2, arg3}, + inner_cond_fn, inner_body_fn); + + auto one_outer = ops::Const( + scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), + 1); + auto add_i = + ops::Add(scope.WithOpName("outer/add") + .WithControlDependencies(absl::Span{ + while_op[0].op(), while_op[1].op()}), + identity_i, one_outer); + + auto retval0 = + ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0); + auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1); + auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); + // Inner condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto five = ops::Const( + scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), + 5); + auto less_j = + ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); + auto retval = + ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), - result.arg_types); - EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Inner body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto identity_j = + ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); + auto identity_k = + ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); + + auto mul_jk = + ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); + auto add_jkx = + ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); + auto assign = ops::AssignAddVariableOp( + scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); + + auto one = ops::Const( + scope.WithOpName("outer/inner/One") + .WithControlDependencies( + absl::Span{assign.operation}), + 1); + auto add_j = + ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + + auto retval0 = + ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0); + auto retval1 = + ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1); + auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), + result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } } } diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 7c3ad448ef546dd1ab2640a57d7d1d73ca3768ad..d87436a7b4ac37c74d0f0df921779c8716290013 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -523,6 +523,12 @@ 
Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, builder.Attr("T", arg_types); builder.Attr("cond", cond_name); builder.Attr("body", body_name); + string outside_compilation; + if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName, + &outside_compilation) + .ok()) { + builder.Attr(kXlaOutsideCompilationAttrName, outside_compilation); + } std::vector inputs; for (int i = 0; i < frame->args.size(); ++i) { const Arg& arg = frame->args[i]; diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index c019a28e892ff89f559ddbec2360d6caa9c1808f..efb75749722893100494e089c0beb96944e9f1d4 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -23,9 +23,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -40,6 +40,7 @@ limitations under the License. 
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/validate.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -51,12 +52,11 @@ namespace { Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, const std::vector& expressions, std::vector* args) { - auto builder = ctx->builder(); auto client = ctx->compiler()->client(); - std::vector compile_time_constant_flags(expressions.size()); + std::vector arg_must_be_compile_time_constant(expressions.size()); TF_RETURN_IF_ERROR( - BackwardsConstAnalysis(*graph, &compile_time_constant_flags, + BackwardsConstAnalysis(*graph, &arg_must_be_compile_time_constant, /*compile_time_const_nodes=*/nullptr)); args->resize(expressions.size()); @@ -65,24 +65,31 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, arg.type = ctx->input_type(i); arg.shape = ctx->InputShape(i); - if (arg.type == DT_RESOURCE) { - return errors::InvalidArgument( - "Resource as function argument is not yet implemented."); - } else if (expressions[i]->has_constant_value()) { - arg.kind = XlaCompiler::Argument::kConstant; - arg.constant_value = expressions[i]->constant_value(); - } else if (compile_time_constant_flags[i]) { - arg.kind = XlaCompiler::Argument::kConstant; - TF_RET_CHECK(expressions[i]->resource() == nullptr) - << "Input with resource is not yet implemented."; - TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph( - expressions[i]->handle())); - TF_ASSIGN_OR_RETURN(auto literal, - client->ComputeConstant(constant_graph)); - TF_RETURN_IF_ERROR( - LiteralToHostTensor(literal, arg.type, &arg.constant_value)); - } else { - arg.kind = XlaCompiler::Argument::kParameter; + switch (expressions[i]->kind()) { + case XlaExpression::Kind::kConstant: + arg.kind = XlaCompiler::Argument::kConstant; + 
arg.constant_value = expressions[i]->constant_value(); + break; + case XlaExpression::Kind::kXlaOp: + if (arg_must_be_compile_time_constant[i]) { + TF_ASSIGN_OR_RETURN(absl::optional value, + expressions[i]->ResolveConstant(client)); + if (!value.has_value()) { + return errors::InvalidArgument( + "Argument to function must be a compile-time constant, but " + "unable to resolve argument value to a constant."); + } + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = *value; + } else { + arg.kind = XlaCompiler::Argument::kParameter; + } + break; + case XlaExpression::Kind::kResource: + return errors::Unimplemented( + "Resource as function argument is not yet implemented."); + case XlaExpression::Kind::kInvalid: + return errors::InvalidArgument("Invalid function argument"); } } return Status::OK(); @@ -164,7 +171,7 @@ Status GraphCompiler::Compile() { outputs[o] = op_context.release_output(o); if (outputs[o].tensor == nullptr) { return errors::Internal("Missing xla_context ", o, "-th output from ", - SummarizeNode(*n)); + FormatNodeForError(*n)); } } } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 95a010a119d13d4fdd35690d2b8ea708eafb221f..8bc329229648c5aced8d06c99b170803bb3a90f8 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -40,6 +40,7 @@ tf_kernel_library( "dynamic_stitch_op.cc", "elu_op.cc", "extract_image_patches_op.cc", + "fake_param_op.cc", "fake_quantize_ops.cc", "fft_ops.cc", "fill_op.cc", @@ -120,12 +121,11 @@ tf_kernel_library( ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/lib:batch_dot", + "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/lib:cholesky", "//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", - 
"//tensorflow/compiler/tf2xla/lib:triangular_solve", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:xla_ops", @@ -142,10 +142,11 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/client/lib:pooling", "//tensorflow/compiler/xla/client/lib:prng", "//tensorflow/compiler/xla/client/lib:sorting", + "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", @@ -176,6 +177,31 @@ tf_kernel_library( ], ) +# A separate cc_library for resampler_ops is needed because resampler is in +# contrib/, and thus the declaration of resampler cannot be pulled into the deps +# of xla_ops. Therefore, resampler_ops is its own cc_library target, and its +# corresponding tf_kernel_library is defined in contrib/resampler/BUILD. 
+cc_library( + name = "resampler_ops", + srcs = ["resampler_ops.cc"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + cc_library( name = "conv_op_helpers", srcs = ["conv_op_helpers.cc"], @@ -188,7 +214,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", "//tensorflow/core:framework", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:conv_ops", diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 276d744c096f8996c774964204feaa3762bdb844..795ea09831e183a26fb3498b9bbaf9c3adaef9ed 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -14,11 +14,13 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -48,14 +50,10 @@ class XlaArgOp : public XlaOpKernel { return; } - const XlaExpression& arg = XlaContext::Get(ctx).args()[index_]; - if (arg.resource() != nullptr) { - ctx->SetResourceOutput(0, arg.resource()); - } else if (arg.has_constant_value()) { - ctx->SetConstantOutput(0, arg.constant_value()); - } else { - ctx->SetOutput(0, arg.handle()); - } + const XlaExpression& arg = ctx->xla_context()->args()[index_]; + OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid, + errors::InvalidArgument("Invalid/missing argument expression")); + ctx->SetOutputExpression(0, arg); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 4cfe946b2e6146f034867c06e996ffae42b90705..1b254e328a8c71bd81a0ec700e2af1d81a5fa67a 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" namespace tensorflow { namespace { @@ -28,9 +30,11 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = BatchDot(ctx->Input(0), ctx->Input(1), - /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_, - /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_); + auto result = + xla::BatchDot(MaybeTransposeInMinorDims( + MaybeConjugate(ctx->Input(0), adj_x_), adj_x_), + MaybeTransposeInMinorDims( + MaybeConjugate(ctx->Input(1), adj_y_), adj_y_)); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index a267c0c72fce67d7c22c55a57f8d5ac4ffd2b7e2..0e2f335f3354e3ae6008bdc0ac0b80683fe479c1 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -115,9 +115,9 @@ class FusedBatchNormGradOp : public XlaOpKernel { // operators. For now, cast everything to the statistics type (which // may be more precise than the input type). 
auto grad_backprop = - XlaHelpers::ConvertElementType(b, ctx->Input(0), scale_dtype); + XlaHelpers::ConvertElementType(ctx->Input(0), scale_dtype); auto activations = - XlaHelpers::ConvertElementType(b, ctx->Input(1), scale_dtype); + XlaHelpers::ConvertElementType(ctx->Input(1), scale_dtype); auto scale = ctx->Input(2); auto mean = ctx->Input(3); auto var = ctx->Input(4); @@ -151,11 +151,11 @@ class FusedBatchNormGradOp : public XlaOpKernel { const DataType accumulation_type = XlaHelpers::SumAccumulationType(scale_dtype); auto converted = - XlaHelpers::ConvertElementType(b, grad_backprop, accumulation_type); + XlaHelpers::ConvertElementType(grad_backprop, accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); - offset_backprop = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); + offset_backprop = XlaHelpers::ConvertElementType(reduce, scale_dtype); // scratch1 = rsqrt(pop_var + epsilon) auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); @@ -165,19 +165,18 @@ class FusedBatchNormGradOp : public XlaOpKernel { // scratch2 = sum(y_backprop * (x - mean)) auto mul = xla::Mul(grad_backprop, xla::Sub(activations, mean, {feature_index})); - converted = XlaHelpers::ConvertElementType(b, mul, accumulation_type); + converted = XlaHelpers::ConvertElementType(mul, accumulation_type); reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); - auto scratch2 = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); + auto scratch2 = XlaHelpers::ConvertElementType(reduce, scale_dtype); x_backprop = xla::Mul(grad_backprop, xla::Mul(scratch1, scale), {feature_index}); scale_backprop = xla::Mul(scratch1, scratch2); } - ctx->SetOutput(0, - XlaHelpers::ConvertElementType(b, x_backprop, input_dtype)); + ctx->SetOutput(0, XlaHelpers::ConvertElementType(x_backprop, input_dtype)); ctx->SetOutput(1, 
scale_backprop); ctx->SetOutput(2, offset_backprop); ctx->SetConstantOutput(3, Tensor()); diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index a18e04995b5e1e0b0374f7b0edd6f5e114cf994a..46e5d68c78fd9ff26a88dc2a1484c3a67b76f4f3 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -159,8 +159,8 @@ class BatchToSpaceNDOp : public XlaOpKernel { } }; REGISTER_XLA_OP(Name("BatchToSpaceND") - .CompileTimeConstInput("block_shape") - .CompileTimeConstInput("crops"), + .CompileTimeConstantInput("block_shape") + .CompileTimeConstantInput("crops"), BatchToSpaceNDOp); class BatchToSpaceOp : public XlaOpKernel { @@ -183,7 +183,7 @@ class BatchToSpaceOp : public XlaOpKernel { private: int block_size_; }; -REGISTER_XLA_OP(Name("BatchToSpace").CompileTimeConstInput("crops"), +REGISTER_XLA_OP(Name("BatchToSpace").CompileTimeConstantInput("crops"), BatchToSpaceOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index 182f7c99344845964f7010127718f876ab6e8a44..c022284fec6bc91951170e243ea3609c8d5d0c43 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -67,8 +67,8 @@ class BCastArgsOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(BCastArgsOp); }; REGISTER_XLA_OP(Name("BroadcastArgs") - .CompileTimeConstInput("s0") - .CompileTimeConstInput("s1"), + .CompileTimeConstantInput("s0") + .CompileTimeConstantInput("s1"), BCastArgsOp); // Given shapes of two tensors, computes the reduction indices for the @@ -94,14 +94,10 @@ class BCastGradArgsOp : public XlaOpKernel { OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape), errors::InvalidArgument("In[", i, "] must be a vector.", in_shape.DebugString())); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(i, &literal)); - - BCast::Vec 
vec; - for (int64 i = 0; i < in_shape.num_elements(); ++i) { - vec.push_back(literal.Get({i})); - } - shapes.push_back(vec); + std::vector vec; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(i, &vec)); + + shapes.push_back(BCast::Vec(vec.begin(), vec.end())); } BCast bcast(shapes[0], shapes[1]); OP_REQUIRES(ctx, bcast.IsValid(), @@ -126,8 +122,8 @@ class BCastGradArgsOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("BroadcastGradientArgs") - .CompileTimeConstInput("s0") - .CompileTimeConstInput("s1"), + .CompileTimeConstantInput("s0") + .CompileTimeConstantInput("s1"), BCastGradArgsOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index 41f540506ba41fbe7f91393e7b8e26a89e72ef0a..e7f369b761f36a717ea5fb536780af91a8955b1e 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -107,11 +107,11 @@ class BiasAddGradOp : public XlaOpKernel { const DataType accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = - XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); + XlaHelpers::ConvertElementType(ctx->Input(0), accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); - ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, reduce, input_type(0))); + ctx->SetOutput(0, XlaHelpers::ConvertElementType(reduce, input_type(0))); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index a988d3c33ed808b022f67882c8ae5100b7e7a305..5e9280c1fe692037b0a842a92ef5a8c28b854a54 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -43,6 +43,9 @@ namespace { const std::vector& extend_dimensions) override { \ xla::XlaBuilder* b = ctx->builder(); \ (void)b; \ + (void)lhs_shape; \ 
+ (void)rhs_shape; \ + (void)extend_dimensions; \ return HLO; \ } \ }; \ @@ -64,7 +67,7 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); // } static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto y_equals_0 = xla::Eq(y, zero); auto zeros = xla::ZerosLike(x); @@ -84,7 +87,7 @@ XLA_MAKE_BINARY(DivNoNan, // } static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); if (DataTypeIsUnsigned(dtype)) { return xla::Div(x, y); } @@ -103,30 +106,30 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, XLA_MAKE_BINARY(FloorDiv, FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper)); -static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, - xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); - auto zero = XlaHelpers::Zero(b, dtype); +xla::XlaOp XlogyImpl(xla::XlaOp x, xla::XlaOp y, + const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + auto zero = xla::ZerosLike(x); auto is_zero = xla::Eq(x, zero); return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y))); } -XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(Xlogy, XlogyImpl(lhs, rhs, broadcast_helper)); -static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, - xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); - auto zero = XlaHelpers::Zero(b, 
dtype); +xla::XlaOp XdivyImpl(xla::XlaOp x, xla::XlaOp y, + const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + auto zero = xla::ZerosLike(x); auto is_zero = xla::Eq(x, zero); return xla::Select(is_zero, zero, xla::Div(x, y)); } -XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper)); +XLA_MAKE_BINARY(Xdivy, XdivyImpl(lhs, rhs, broadcast_helper)); // Implementation of FloorMod. Pseudo-code: // T trunc_mod = std::fmod(x, y); // return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y); static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero)); auto trunc_mod = xla::Rem(x, y); diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc index 696c1c39befd5aa2972afb6cfa64905b57a5ab72..d7a8e67dd33aab5c32b7465ce505b745b5c1ca2f 100644 --- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -13,16 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "absl/algorithm/container.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/bcast.h" namespace tensorflow { namespace { @@ -37,63 +32,13 @@ class BroadcastToOp : public XlaOpKernel { TensorShape output_shape; OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); - OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(), - errors::InvalidArgument( - "Input rank (", input_shape.dims(), - ") must be less than or equal to the output rank (", - output_shape.dims(), ")")); - - auto input_dims = input_shape.dim_sizes(); - auto output_dims = output_shape.dim_sizes(); - - // Broadcasting is done right-to-left on right-aligned dimensions; reverse - // the two vectors so elements to be broadcast are aligned. 
- absl::c_reverse(input_dims); - absl::c_reverse(output_dims); - - std::vector broadcast_dims; - std::vector broadcast_shape; - for (int i = 0; i < output_shape.dims(); ++i) { - if (i < input_shape.dims()) { - OP_REQUIRES( - context, - (output_dims[i] == 0 && input_dims[i] == 0) || - (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0), - errors::InvalidArgument("invalid shape to broadcast from ", - input_shape.DebugString(), " to ", - output_shape.DebugString())); - - broadcast_dims.push_back(broadcast_shape.size()); - if (output_dims[i] == input_dims[i]) { - broadcast_shape.push_back(output_dims[i]); - } else if (output_dims[i] != input_dims[i]) { - // Add dimensions [I, O/I], which we will later flatten to just - // [O]. We must do this in two phases since XLA broadcasting does not - // support tiling. - broadcast_shape.push_back(input_dims[i]); - broadcast_shape.push_back(output_dims[i] / input_dims[i]); - } - } else { - broadcast_shape.push_back(output_dims[i]); - } - } - absl::c_reverse(broadcast_dims); - int broadcast_shape_size = broadcast_shape.size(); - for (int64& broadcast_dim : broadcast_dims) { - broadcast_dim = broadcast_shape_size - broadcast_dim - 1; - } - absl::c_reverse(broadcast_shape); - xla::XlaOp output = xla::Reshape( - xla::BroadcastInDim(context->Input(0), - xla::ShapeUtil::MakeShape( - context->input_xla_type(0), broadcast_shape), - broadcast_dims), - output_shape.dim_sizes()); - context->SetOutput(0, output); + auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes()); + OP_REQUIRES_OK(context, output.status()); + context->SetOutput(0, output.ValueOrDie()); } }; -REGISTER_XLA_OP(Name("BroadcastTo").CompileTimeConstInput("shape"), +REGISTER_XLA_OP(Name("BroadcastTo").CompileTimeConstantInput("shape"), BroadcastToOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 
e7fef77edcba0ea5a521956a704225ac4f7fcb22..7199b9b6feb36dd45ef51f4c38463bc715fcc38a 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -21,10 +21,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { @@ -57,42 +60,114 @@ class CategoricalOp : public XlaOpKernel { const int64 batch_size = logits_shape.dim_size(0); const int64 num_classes = logits_shape.dim_size(1); - xla::XlaBuilder* builder = ctx->builder(); - - std::array uniform_shape_array = { - {batch_size, num_samples, num_classes}}; - xla::PrimitiveType uniform_xla_type; - OP_REQUIRES_OK(ctx, - DataTypeToPrimitiveType(input_type(0), &uniform_xla_type)); - xla::Shape uniform_shape = - xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); - auto uniforms = - xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), - XlaHelpers::One(builder, input_type(0)), uniform_shape); + xla::Shape uniform_shape; + int class_dimension; + if (num_samples != 1) { + std::array uniform_shape_array = { + {batch_size, num_samples, num_classes}}; + xla::PrimitiveType uniform_xla_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(input_type(0), &uniform_xla_type)); + uniform_shape = + xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); + class_dimension = 2; + } else { + // Have a special case for when we only need one sample, because + // dimensions may be padded on architectures with tiled 
memory layouts, so + // if the num_classes or batch size is large then this can lead to + // expensive wasted memory. + std::array uniform_shape_array = {{batch_size, num_classes}}; + xla::PrimitiveType uniform_xla_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(input_type(0), &uniform_xla_type)); + uniform_shape = + xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); + class_dimension = 1; + } + xla::PrimitiveType type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &type)); + xla::XlaOp log_uniforms = GetLogUniforms(uniform_shape, type, ctx); // Use Gumbel softmax trick to generate categorical samples. // See: // https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/ // TODO(b/68769470): Switch to using a cumulative sum approach. - auto softmax_entries = xla::Sub(logits, xla::Log(-xla::Log(uniforms)), - /*broadcast_dimensions=*/{0, 2}); + auto softmax_entries = + xla::Sub(logits, log_uniforms, + /*broadcast_dimensions=*/{0, class_dimension}); xla::PrimitiveType xla_output_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_type(0), &xla_output_type)); - xla::XlaOp argmax = - XlaHelpers::ArgMax(softmax_entries, xla_output_type, /*axis=*/2); + xla::XlaOp argmax = XlaHelpers::ArgMax(softmax_entries, xla_output_type, + /*axis=*/class_dimension); + if (num_samples == 1) { + argmax = xla::Reshape(argmax, {batch_size, 1}); + } ctx->SetOutput(0, argmax); } + virtual xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, + xla::PrimitiveType type, + XlaOpKernelContext* ctx) { + xla::XlaBuilder* builder = ctx->builder(); + auto uniforms = + xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)), + XlaHelpers::One(builder, input_type(0)), uniform_shape); + return xla::Log(-xla::Log(uniforms)); + } + private: TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp); }; // TODO(b/68769717): Rename this sampler to Categorical. 
-REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstInput("num_samples"), +REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstantInput("num_samples"), CategoricalOp); +class StatelessCategoricalOp : public CategoricalOp { + public: + explicit StatelessCategoricalOp(OpKernelConstruction* ctx) + : CategoricalOp(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type, + XlaOpKernelContext* ctx) override { + xla::XlaOp seed = ctx->Input(2); + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + xla::XlaBuilder* builder = ctx->builder(); + if (uniform_shape.element_type() == xla::BF16) { + uniform_shape.set_element_type(xla::F32); + } + auto uniforms = xla::StatelessRngUniform( + {seed0, seed1}, uniform_shape, XlaHelpers::Zero(builder, DT_FLOAT), + XlaHelpers::One(builder, DT_FLOAT)); + return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape seed_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2, + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + CategoricalOp::Compile(ctx); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessCategoricalOp); +}; + +REGISTER_XLA_OP(Name("StatelessMultinomial") + .CompileTimeConstantInput("num_samples") + .TypeConstraint("T", {DT_FLOAT, DT_BFLOAT16}) + .TypeConstraint("Tseed", DT_INT32), + StatelessCategoricalOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 0ae23aa6dfe49048ac5cb8ae00c12432b2e2a2fe..cd7c7f4a82df7a65829787efcb1fd2f77870e945 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc 
@@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -37,16 +38,6 @@ limitations under the License. namespace tensorflow { namespace { -// Used to determine the number of Tensors allowed in a Concat op to prevent -// going over the max gpu parameter memory size. This is an issue because concat -// is variadic and can have an unlimited number of arguments when called. -// Concat ops with more Tensors than this will be split into multiple concat -// ops. -// -// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass -// along with boxing large numbers of parameters. -constexpr int64 kMaxConcatArgsPerOp = 500; - // -------------------------------------------------------------------------- class ConcatBaseOp : public XlaOpKernel { public: @@ -55,15 +46,13 @@ class ConcatBaseOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape concat_dim_tensor_shape = ctx->InputShape(axis_index_); - OP_REQUIRES( - ctx, IsLegacyScalar(concat_dim_tensor_shape), - errors::InvalidArgument( - "Concat dim tensor should be a scalar integer, but got shape ", - concat_dim_tensor_shape.DebugString())); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal)); - // TODO(annarev): add a helper to support int64 input. 
- const int32 concat_dim = literal.Get({}); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_tensor_shape), + errors::InvalidArgument( + "Concat dim tensor should be a scalar, but got shape ", + concat_dim_tensor_shape.DebugString())); + int64 concat_dim; + OP_REQUIRES_OK(ctx, + ctx->ConstantInputAsIntScalar(axis_index_, &concat_dim)); std::vector values; std::vector shapes; @@ -73,9 +62,7 @@ class ConcatBaseOp : public XlaOpKernel { const TensorShape& input_shape = shapes[0]; int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; - OP_REQUIRES(ctx, - (0 <= axis && axis < input_dims) || - (allow_legacy_scalars() && concat_dim == 0), + OP_REQUIRES(ctx, 0 <= axis && axis < input_dims, errors::InvalidArgument( "ConcatOp : Expected concatenating dimensions in the range " "[", @@ -84,16 +71,12 @@ class ConcatBaseOp : public XlaOpKernel { // Make a vector holding the XlaOp for each of the inputs that has non-zero // elements. std::vector input_data; - std::vector partial_concats; int output_concat_dim = 0; - const bool input_is_scalar = IsLegacyScalar(input_shape); for (int i = 0; i < N; ++i) { xla::XlaOp handle = values[i]; const TensorShape& in_shape = shapes[i]; - const bool in_is_scalar = IsLegacyScalar(in_shape); OP_REQUIRES( - ctx, - in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar), + ctx, in_shape.dims() == input_dims, errors::InvalidArgument( "ConcatOp : Ranks of all input tensors should match: shape[0] = ", input_shape.DebugString(), " vs. shape[", i, @@ -105,30 +88,10 @@ class ConcatBaseOp : public XlaOpKernel { input_data.push_back(handle); } output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1; - - // Concat is associative, so it can be split into many operations when too - // many arguments are in a single op. This is a temporary workaround for - // b/112613927 where too many parameters in an XlaLaunchOp later result in - // too many parameters to a single GPU kernel. 
- if (i && i % kMaxConcatArgsPerOp == 0) { - partial_concats.push_back( - xla::ConcatInDim(ctx->builder(), input_data, axis)); - input_data.clear(); - } } - // Add any inputs that have not been put into another concat yet. - partial_concats.insert(partial_concats.end(), input_data.begin(), - input_data.end()); VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; - // Don't add an additional "identity" concatenate for better readibility of - // IR. - if (partial_concats.size() == 1) { - ctx->SetOutput(0, partial_concats.front()); - } else { - ctx->SetOutput(0, - xla::ConcatInDim(ctx->builder(), partial_concats, axis)); - } + ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); } private: @@ -149,10 +112,11 @@ class ConcatV2Op : public ConcatBaseOp { : ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {} }; -REGISTER_XLA_OP(Name("Concat").CompileTimeConstInput("concat_dim"), ConcatOp); +REGISTER_XLA_OP(Name("Concat").CompileTimeConstantInput("concat_dim"), + ConcatOp); REGISTER_XLA_OP(Name("ConcatV2") .TypeConstraint("Tidx", DT_INT32) - .CompileTimeConstInput("axis"), + .CompileTimeConstantInput("axis"), ConcatV2Op); class ConcatOffsetOp : public XlaOpKernel { @@ -161,11 +125,10 @@ class ConcatOffsetOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape concat_dim_shape = ctx->InputShape(0); - OP_REQUIRES( - ctx, IsLegacyScalar(concat_dim_shape), - errors::InvalidArgument( - "Concat dim tensor should be a scalar integer, but got shape ", - concat_dim_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_shape), + errors::InvalidArgument( + "Concat dim tensor should be a scalar, but got shape ", + concat_dim_shape.DebugString())); for (int i = 1; i < ctx->num_inputs(); ++i) { OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->InputShape(i)), errors::InvalidArgument("input ", i, @@ -192,39 +155,38 @@ class ConcatOffsetOp : public XlaOpKernel { // [0, 5, 0, 0] const int32 
N = ctx->num_inputs() - 1; const TensorShape inp0_shape = ctx->InputShape(1); - xla::Literal inp0_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &inp0_literal)); - const int64 dims = inp0_shape.num_elements(); + std::vector inp0_dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &inp0_dims)); + const int64 inp0_rank = inp0_shape.num_elements(); - xla::Literal concat_dim_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal)); - const int64 cdim = concat_dim_literal.Get({}); + int64 cdim; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &cdim)); - VLOG(1) << "ConcatOffset " << cdim << "," << dims; - int32 axis = cdim < 0 ? cdim + dims : cdim; - OP_REQUIRES(ctx, FastBoundsCheck(axis, dims), + VLOG(1) << "ConcatOffset " << cdim << "," << inp0_rank; + int32 axis = cdim < 0 ? cdim + inp0_rank : cdim; + OP_REQUIRES(ctx, FastBoundsCheck(axis, inp0_rank), errors::InvalidArgument("Concat dim is out of range: ", axis, - " vs. ", dims)); + " vs. ", inp0_rank)); int32 offset = 0; for (int i = 0; i < N; ++i) { const TensorShape inp_shape = ctx->InputShape(1 + i); - OP_REQUIRES(ctx, dims == inp_shape.num_elements(), - errors::InvalidArgument("input ", i, " should contain ", dims, - " elements, but got ", + OP_REQUIRES(ctx, inp0_rank == inp_shape.num_elements(), + errors::InvalidArgument("input ", i, " should contain ", + inp0_rank, " elements, but got ", inp_shape.num_elements())); - xla::Literal inp_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1 + i, &inp_literal)); + std::vector inp_dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1 + i, &inp_dims)); - Tensor out_constant(DT_INT32, TensorShape({dims})); + Tensor out_constant(DT_INT32, TensorShape({inp0_rank})); auto out_vec = out_constant.vec(); - for (int64 j = 0; j < dims; ++j) { + for (int64 j = 0; j < inp0_rank; ++j) { if (j == axis) { out_vec(j) = offset; - offset += inp_literal.Get({j}); + offset += inp_dims[j]; } else { - const int32 inp0_element = 
inp0_literal.Get({j}); - const int32 inp_element = inp_literal.Get({j}); - OP_REQUIRES(ctx, (inp0_element == inp_element), + const int32 inp0_element = inp0_dims[j]; + const int32 inp_element = inp_dims[j]; + OP_REQUIRES(ctx, inp0_element == inp_element, errors::InvalidArgument("input[", i, ",", j, "] mismatch: ", inp0_element, " vs. ", inp_element)); @@ -238,8 +200,8 @@ class ConcatOffsetOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("ConcatOffset") - .CompileTimeConstInput("concat_dim") - .CompileTimeConstInput("shape"), + .CompileTimeConstantInput("concat_dim") + .CompileTimeConstantInput("shape"), ConcatOffsetOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 2628ef8e2454976aeff3859fa5dc1d8e106f32e1..dff8af800229b9605bb93e0498bc5e5cf012f244 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -42,11 +42,6 @@ class ConstOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape shape(proto_.tensor_shape()); - if (proto_.dtype() == DT_STRING) { - LOG(WARNING) << "Not computing Const of type DT_STRING"; - ctx->SetInvalidOutput(0); - return; - } xla::XlaBuilder* b = ctx->builder(); // To avoid blowups for large constants filled with the same value, diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index c9a1be494066e4f935a1d818bc86c86333e34fae..641fefafb357f6ad10483c454600f3dadd4f8cb7 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/node_def_util.h" @@ -65,60 +64,63 @@ xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) { // 0 0 1 1 0 0 0 0 1 1 0 0 // 0 0 0 0 1 1 0 0 0 0 1 1 // -// The first step is to create a one tensor, A, that is [3] -// 0 1 2 +// The first step is to create a iota A with iota_dimension = 2 +// 0 0 0 0 0 0 0 0 0 0 0 0 +// 1 1 1 1 1 1 1 1 1 1 1 1 +// 2 2 2 2 2 2 2 2 2 2 2 2 // -// and another tensor, B, that is [3 * 2] -// 0 1 2 3 4 5 +// 0 0 0 0 0 0 0 0 0 0 0 0 +// 1 1 1 1 1 1 1 1 1 1 1 1 +// 2 2 2 2 2 2 2 2 2 2 2 2 // -// and divide B it by 2 to get -// 0 0 1 1 2 2 +// and another iota B with iota_dimension = 3 +// 0 1 2 3 4 5 0 1 2 3 4 5 +// 0 1 2 3 4 5 0 1 2 3 4 5 +// 0 1 2 3 4 5 0 1 2 3 4 5 // -// then we broadcast the B to [2, 2, 3, 3 * 2] -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 1 2 3 4 5 0 1 2 3 4 5 +// 0 1 2 3 4 5 0 1 2 3 4 5 +// 0 1 2 3 4 5 0 1 2 3 4 5 // -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 +// and divide B by 2 to get +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 // -// Finally compare A and broadcasted B in dimension 2 amd return the result at -// the beginning of the comment. +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// Finally compare A and B and return the result at the beginning of the +// comment. 
xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape, xla::XlaBuilder* builder) { xla::Shape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); int64 depthwise_multiplier = filter_shape.dimensions(filter_shape.dimensions_size() - 1); - int64 input_feature = - filter_shape.dimensions(filter_shape.dimensions_size() - 2); - - // Create a M sized linspace and an M*N sized linspace that will be - // broadcasted into perpendicular dimensions and compared. - xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); - xla::XlaOp expanded_feature_iota = - xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); - // Divide the M*N sized linspace by the depthwise_multiplier to create - // [0 0 1 1 2 2] in the example in the function comment. + // Create two iotas with the shape of the expanded filter, one of them with + // the iota dimension chosen as the feature dimension, and the other a iota + // with the iota dimension chosen as the expanded output feature dimension. + std::vector iota_dimensions(expanded_filter_shape.dimensions().begin(), + expanded_filter_shape.dimensions().end()); + xla::Shape iota_shape = xla::ShapeUtil::MakeShape(xla::S32, iota_dimensions); + xla::XlaOp input_feature_iota = xla::Iota( + builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 2); + xla::XlaOp expanded_feature_iota = xla::Iota( + builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 1); + + // Divide 'expanded_feature_iota' by the depthwise_multiplier to create + // [0 0 1 1 2 2] ... in the example in the function comment. expanded_feature_iota = xla::Div(expanded_feature_iota, XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, depthwise_multiplier)); - // Broadcast the N*M linspace to [H, W, ..., M, M*N]. 
- std::vector expanded_feature_broadcast_dims( - expanded_filter_shape.dimensions().begin(), - expanded_filter_shape.dimensions().end()); - expanded_feature_broadcast_dims.pop_back(); - auto broadcasted_expanded_feature_iota = - xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); - - // Compare the broadcasted linspace to the input feature linspace in the - // input feature dimension to create a diagonal predicate. - return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, - {expanded_filter_shape.dimensions_size() - 2}); + // Compare 'input_feature_iota' with 'expanded_feature_iota' to create a + // diagonal predicate. + return xla::Eq(expanded_feature_iota, input_feature_iota); } // Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index cd7c820be0b6029514ff74288e7bdd3f75b5d6b1..eafdba876ae9e2c38694f065cf83bb3725b8460e 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/node_def_util.h" @@ -124,7 +124,7 @@ class Conv2DBackpropInputOp : public ConvBackpropInputOp { : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {} }; REGISTER_XLA_OP( - Name("Conv2DBackpropInput").CompileTimeConstInput("input_sizes"), + Name("Conv2DBackpropInput").CompileTimeConstantInput("input_sizes"), Conv2DBackpropInputOp); class Conv3DBackpropInputOp : public ConvBackpropInputOp { @@ -133,7 +133,7 @@ class Conv3DBackpropInputOp : public ConvBackpropInputOp { : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {} }; REGISTER_XLA_OP( - Name("Conv3DBackpropInputV2").CompileTimeConstInput("input_sizes"), + Name("Conv3DBackpropInputV2").CompileTimeConstantInput("input_sizes"), Conv3DBackpropInputOp); class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp { @@ -142,7 +142,7 @@ class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp { : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} }; REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput") - .CompileTimeConstInput("input_sizes"), + .CompileTimeConstantInput("input_sizes"), DepthwiseConv2DBackpropInputOp); class ConvBackpropFilterOp : public XlaOpKernel { @@ -183,7 +183,7 @@ class Conv2DBackpropFilterOp : public ConvBackpropFilterOp { } }; REGISTER_XLA_OP( - Name("Conv2DBackpropFilter").CompileTimeConstInput("filter_sizes"), + Name("Conv2DBackpropFilter").CompileTimeConstantInput("filter_sizes"), Conv2DBackpropFilterOp); class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { @@ -193,7 +193,7 @@ class 
Conv3DBackpropFilterOp : public ConvBackpropFilterOp { } }; REGISTER_XLA_OP( - Name("Conv3DBackpropFilterV2").CompileTimeConstInput("filter_sizes"), + Name("Conv3DBackpropFilterV2").CompileTimeConstantInput("filter_sizes"), Conv3DBackpropFilterOp); class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp { @@ -202,7 +202,7 @@ class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp { : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} }; REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter") - .CompileTimeConstInput("filter_sizes"), + .CompileTimeConstantInput("filter_sizes"), DepthwiseConv2DBackpropFilterOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index ef1015552d181a183d412f9c269dd5ec608b388f..234f7b4a019c9aac4bac4f906ddbae166ecd9a80 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -39,7 +40,8 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { // compute valid broadcast shapes, but rely below on XLA to // automatically perform the broadcast assuming its valid shapes are // a superset of TensorFlow's valid shapes. - BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape)); + BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape), + /*fewer_dims_optimization=*/false); if (!bcast.IsValid()) { ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ", lhs_shape.DebugString(), " vs. 
", @@ -86,51 +88,18 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { } /* static */ std::pair XlaBinaryOp::Broadcast( - xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, - const BCast& broadcast_helper) { - // Manually construct the broadcasting since MapN does not do - // automatic broadcasting. The bcast helper ensures that - // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and - // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have - // the same shape, so can be operated on by MapN. - - // First reshape the inputs, which should be a metadata-only - // operation since we are flattening the dimensions in order. - auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape()); - auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape()); - - // Next broadcast the necessary input dimensions. We rely on the - // XLA optimizer to be smart about the fact that we are asking - // it to broadcast size 1 on some of these dimensions, to avoid - // adding complexity to this code. - auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast()); - int lhs_size = broadcast_helper.x_bcast().size(); - auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast()); - int rhs_size = broadcast_helper.y_bcast().size(); - - // Now reshape them to the correct output shape. After the - // broadcast each side is twice as wide as it should be, since the - // broadcast dimensions were prepended to the shape. Reshape - // flattening each original dimension with the prepended broadcast - // dimension. E.g. if we started out with lhs_shaped with shape - // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have - // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21]. 
- std::vector lhs_reorder; - for (int i = 0; i < lhs_size; ++i) { - lhs_reorder.push_back(i); - lhs_reorder.push_back(i + lhs_size); + xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper) { + auto lhs_output = BroadcastTo(lhs, broadcast_helper.output_shape()); + if (!lhs_output.ok()) { + xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status()); + return {error, error}; } - auto lhs_output = - xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape()); - std::vector rhs_reorder; - for (int i = 0; i < rhs_size; ++i) { - rhs_reorder.push_back(i); - rhs_reorder.push_back(i + rhs_size); + auto rhs_output = BroadcastTo(rhs, broadcast_helper.output_shape()); + if (!rhs_output.ok()) { + xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status()); + return {error, error}; } - auto rhs_output = - xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape()); - - return {lhs_output, rhs_output}; + return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()}; } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index 6653944a911588b7bc88d67b8cdd2c17850530f0..516ead4bfe89b4ddeee11dcc6410a838d04f28a9 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -67,8 +67,7 @@ class XlaBinaryOp : public XlaOpKernel { // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same // shape. 
static std::pair Broadcast( - xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, - const BCast& broadcast_helper); + xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper); }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 49c12fc232092873b69961644a059abc6035f64f..ee79cbc70da269be7586c47b4fd33c901f4fd581 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index 4af1e8b44cbbd02d8e3ea5e42d841c92288b5d56..bb2c0d9ddb8504a1156a74b6ece5d41b620803c7 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -102,8 +102,9 @@ class DynamicSliceOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("XlaDynamicSlice").CompileTimeConstInput("size_indices"), - DynamicSliceOp); +REGISTER_XLA_OP( + Name("XlaDynamicSlice").CompileTimeConstantInput("size_indices"), + DynamicSliceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index cb73053666d4c32bc0a2ef19b174aee1a29f101e..6e6ba21daf5bf3eab5bfc15378e77b6dd253da7c 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -113,8 +113,20 @@ class DynamicStitchOp : public XlaOpKernel { } } int number_of_indices = max_index + 1; - OP_REQUIRES(ctx, number_of_indices > 0, - errors::InvalidArgument("no indices supplied")); + int64 result_rank = 1 + data0_shape.dims() - indices0_shape.dims(); + if (number_of_indices == 0) { + std::vector result_shape(result_rank); + for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { + result_shape[d - indices0_shape.dims() + 1] = data0_shape.dim_size(d); + } + xla::PrimitiveType element_type = + ctx->input_xla_type(ctx->num_inputs() - 1); + xla::Literal empty_literal = xla::Literal::CreateFromShape( + xla::ShapeUtil::MakeShape(element_type, result_shape)); + ctx->SetOutput(0, xla::ConstantLiteral(ctx->builder(), empty_literal)); + return; + } + // Construct the reverse mapping, for each index, of which slice of which // input it comes from. std::vector src_input_vector(number_of_indices); @@ -157,12 +169,9 @@ class DynamicStitchOp : public XlaOpKernel { // Set up the vectors for slicing: the first dimension will vary // slice by slice, and the rest take the full common extra shape. 
- std::vector slice_start(1 + data0_shape.dims() - - indices0_shape.dims()); - std::vector slice_limit(1 + data0_shape.dims() - - indices0_shape.dims()); - std::vector stride(1 + data0_shape.dims() - indices0_shape.dims(), - 1); + std::vector slice_start(result_rank); + std::vector slice_limit(result_rank); + std::vector stride(result_rank, 1); for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d); } @@ -200,10 +209,11 @@ class DynamicStitchOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("DynamicStitch").CompileTimeConstInput("indices"), - DynamicStitchOp); -REGISTER_XLA_OP(Name("ParallelDynamicStitch").CompileTimeConstInput("indices"), +REGISTER_XLA_OP(Name("DynamicStitch").CompileTimeConstantInput("indices"), DynamicStitchOp); +REGISTER_XLA_OP( + Name("ParallelDynamicStitch").CompileTimeConstantInput("indices"), + DynamicStitchOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index c68b0bfd7961892294c2931e5c4c44de534a7740..29687c7b82f92d9f336854c4575746589c63b64f 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc b/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ec3463bd58f55c1fc6a8f7c074c8e487d266d7b6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { + +// This OpKernel implements the FakeParam Op for XLA JIT devices. Create zeros +// with the appropriate shape for FakeParam op. 
+class XlaFakeParamOp : public XlaOpKernel { + public: + explicit XlaFakeParamOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + DataType dtype; + TensorShape tensor_shape; + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tensor_shape)); + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, tensor_shape, &shape_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + ctx->SetOutput(0, xla::Zeros(b, shape_)); + } + + private: + xla::Shape shape_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaFakeParamOp); +}; + +REGISTER_XLA_OP(Name("FakeParam"), XlaFakeParamOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index cdba6680dee3fade5bdf0c453ed672b653072b0d..142be030f737f105980ab9c80a5a849e1ca6eb47 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -260,19 +260,19 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { xla::XlaOp below_min = xla::Lt(input, nudged_input_min); xla::XlaOp select1 = xla::Select(below_min, gradient, zeroes); xla::XlaOp reduce1 = xla::ReduceAll( - XlaHelpers::ConvertElementType(b, select1, accumulation_type), + XlaHelpers::ConvertElementType(select1, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); - xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type); + xla::XlaOp output1 = XlaHelpers::ConvertElementType(reduce1, data_type); ctx->SetOutput(1, output1); xla::XlaOp above_max = xla::Gt(input, nudged_input_max); xla::XlaOp select2 = xla::Select(above_max, gradient, zeroes); xla::XlaOp reduce2 = xla::ReduceAll( - XlaHelpers::ConvertElementType(b, select2, accumulation_type), + XlaHelpers::ConvertElementType(select2, accumulation_type), XlaHelpers::Zero(b, accumulation_type), 
*ctx->GetOrCreateAdd(accumulation_type)); - xla::XlaOp output2 = XlaHelpers::ConvertElementType(b, reduce2, data_type); + xla::XlaOp output2 = XlaHelpers::ConvertElementType(reduce2, data_type); ctx->SetOutput(2, output2); } diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index 80bcef966360ec9a1ca63a02741108ce41b31846..6df8b5367d2390e65995beb1583b225755e6ee9f 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -50,11 +51,36 @@ class GenericFftOp : public XlaOpKernel { errors::InvalidArgument("input must be at least 1 dimensional")); std::vector fft_length; + xla::XlaOp input = ctx->Input(0); if (fft_type_ == FftType::RFFT || fft_type_ == FftType::IRFFT) { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &fft_length)); OP_REQUIRES(ctx, fft_length.size() == fft_rank_, errors::InvalidArgument("fft_length must be length ", fft_rank_, " vector")); + + // Zero pad or truncate the axes we're doing FFT on. + absl::InlinedVector slice_sizes = input_shape.dim_sizes(); + std::vector> padding_sizes(slice_sizes.size()); + std::vector expected_sizes = fft_length; + // IRFFT wants the innermost axis to be n / 2 + 1. 
+ if (fft_type_ == FftType::IRFFT) { + expected_sizes[fft_rank_ - 1] = fft_length[fft_rank_ - 1] / 2 + 1; + } + for (int i = 0; i < fft_rank_; i++) { + int index = input_shape.dims() - fft_rank_ + i; + if (input_shape.dim_size(index) > expected_sizes[i]) { + slice_sizes[index] = expected_sizes[i]; + } else { + padding_sizes[index].second = + expected_sizes[i] - input_shape.dim_size(index); + } + } + + std::vector start_indices(input_shape.dims(), 0); + std::vector strides(input_shape.dims(), 1); + input = xla::Pad(xla::Slice(input, start_indices, slice_sizes, strides), + XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), + xla::MakeEdgePaddingConfig(padding_sizes)); } else { // Innermost axis provides the FFT length. for (int i = 0; i < fft_rank_; i++) { @@ -63,7 +89,7 @@ class GenericFftOp : public XlaOpKernel { } } - xla::XlaOp fft = xla::Fft(ctx->Input(0), fft_type_, fft_length); + xla::XlaOp fft = xla::Fft(input, fft_type_, fft_length); ctx->SetOutput(0, fft); } @@ -106,9 +132,11 @@ class RFFTOp : public GenericFftOp { explicit RFFTOp(OpKernelConstruction* ctx) : GenericFftOp(ctx, /*fft_type=*/FftType::RFFT, /*fft_rank=*/FFTRank) {} }; -REGISTER_XLA_OP(Name("RFFT").CompileTimeConstInput("fft_length"), RFFTOp<1>); -REGISTER_XLA_OP(Name("RFFT2D").CompileTimeConstInput("fft_length"), RFFTOp<2>); -REGISTER_XLA_OP(Name("RFFT3D").CompileTimeConstInput("fft_length"), RFFTOp<3>); +REGISTER_XLA_OP(Name("RFFT").CompileTimeConstantInput("fft_length"), RFFTOp<1>); +REGISTER_XLA_OP(Name("RFFT2D").CompileTimeConstantInput("fft_length"), + RFFTOp<2>); +REGISTER_XLA_OP(Name("RFFT3D").CompileTimeConstantInput("fft_length"), + RFFTOp<3>); template class IRFFTOp : public GenericFftOp { @@ -116,10 +144,11 @@ class IRFFTOp : public GenericFftOp { explicit IRFFTOp(OpKernelConstruction* ctx) : GenericFftOp(ctx, /*fft_type=*/FftType::IRFFT, /*fft_rank=*/FFTRank) {} }; -REGISTER_XLA_OP(Name("IRFFT").CompileTimeConstInput("fft_length"), IRFFTOp<1>); 
-REGISTER_XLA_OP(Name("IRFFT2D").CompileTimeConstInput("fft_length"), +REGISTER_XLA_OP(Name("IRFFT").CompileTimeConstantInput("fft_length"), + IRFFTOp<1>); +REGISTER_XLA_OP(Name("IRFFT2D").CompileTimeConstantInput("fft_length"), IRFFTOp<2>); -REGISTER_XLA_OP(Name("IRFFT3D").CompileTimeConstInput("fft_length"), +REGISTER_XLA_OP(Name("IRFFT3D").CompileTimeConstantInput("fft_length"), IRFFTOp<3>); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 54b21a278229024e3e54e9135548be6b69b077e1..35e0625dbb0d4c696d36cce642d6f50f1d220c45 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -33,44 +34,25 @@ class FillOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { // The output of this Op is a tensor of shape 'dims_shape' with each // element set to the scalar 'dims_literal'. - const TensorShape dims_shape = ctx->InputShape(0); - const TensorShape value_shape = ctx->InputShape(1); + const TensorShape dims_shape = ctx->InputShape("dims"); + const TensorShape value_shape = ctx->InputShape("value"); OP_REQUIRES( - ctx, IsLegacyVector(dims_shape), + ctx, TensorShapeUtils::IsVector(dims_shape), errors::InvalidArgument("dims must be a vector of int32, got shape ", dims_shape.DebugString())); - OP_REQUIRES(ctx, IsLegacyScalar(value_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(value_shape), errors::InvalidArgument("value must be a scalar, got shape ", value_shape.DebugString())); - // Evaluate the 'dims' constant input, reshaping to a vector if it - // was a 'legacy' vector (secretly a scalar). 
- xla::Literal dims_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped( - 0, {dims_shape.num_elements()}, &dims_literal)); - // Convert the dims literal into a vector that we can pass to - // XlaBuilder. - std::vector broadcast; - broadcast.reserve(dims_literal.shape().dimensions(0)); - for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { - broadcast.push_back(dims_literal.Get({i})); - } - // Look up the value input, reshaping to a scalar if it was a - // 'legacy' scalar (secretly a vector). - xla::XlaOp data = ctx->Input(1); - if (value_shape.dims() > 0) { - CHECK_EQ(value_shape.dims(), 1); - data = xla::Reshape(data, {}); - } - // Emit the actual computation, which broadcasts the scalar to the - // desired shape. - auto result = xla::Broadcast(data, broadcast); + std::vector dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("dims", &dims)); + auto result = xla::Broadcast(ctx->Input("value"), dims); ctx->SetOutput(0, result); } }; -REGISTER_XLA_OP(Name("Fill").CompileTimeConstInput("dims"), FillOp); +REGISTER_XLA_OP(Name("Fill").CompileTimeConstantInput("dims"), FillOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 44140304fdf5cdf60d8ad8b85c532fcadff8ba86..20b0de193dc060197f3062d3be0b8d45f7dcb9b1 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -194,7 +194,7 @@ class GatherOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("Gather"), GatherOp); -REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstInput("axis"), GatherOp); +REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstantInput("axis"), GatherOp); class GatherNdOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 56da50f140893c68c8a1556853884720b21c7229..b5e083912555c865b5eadc7697075c9ca4451ca9 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -72,7 +72,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.shape = resource->shape(); OP_REQUIRES(ctx, arg.initialized, errors::Unimplemented("Uninitialized arguments: ", arg.name)); - arg.tensor_array_size = resource->tensor_array_size(); + arg.max_array_size = resource->max_array_size(); for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 921b4340c0ac674a5ad7d17aaf54f1cf36975151..e9bb0a77e99d144863b027bd214081316d61c314 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -189,12 +191,11 @@ class AdjustContrastOpV2 : public XlaOpKernel { DataType type = context->input_type(0); const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); - auto converted = - XlaHelpers::ConvertElementType(b, input, accumulation_type); + auto converted = XlaHelpers::ConvertElementType(input, accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *context->GetOrCreateAdd(accumulation_type), {height_dim, width_dim}); - auto output = XlaHelpers::ConvertElementType(b, reduce, type); + auto output = XlaHelpers::ConvertElementType(reduce, type); output = 
xla::Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); @@ -316,6 +317,70 @@ class AdjustHueOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); +struct WhileCondFn { + const int64 num_boxes; + const int64 output_size; + + explicit WhileCondFn(int64 num_boxes, int64 output_size) + : num_boxes(num_boxes), output_size(output_size) {} + + xla::StatusOr operator()(absl::Span values, + xla::XlaBuilder* cond_builder) const { + xla::XlaOp row_idx = values[0]; + xla::XlaOp row_in_bounds = + xla::Lt(row_idx, xla::ConstantR0(cond_builder, num_boxes)); + xla::XlaOp num_outputs_so_far = values[1]; + xla::XlaOp results_not_full = xla::Lt( + num_outputs_so_far, xla::ConstantR0(cond_builder, output_size)); + return xla::And(row_in_bounds, results_not_full); + } +}; + +// Process the boxes one-by-one using the iou matrix mask. +// This implementation uses a correct, but greedy, sequential algorithm +// to ensure that suppressed boxes cannot themselves suppress other +// boxes. +struct SuppressBodyFn { + const int64 num_boxes; + + explicit SuppressBodyFn(int64 num_boxes) : num_boxes(num_boxes) {} + + xla::StatusOr> operator()( + absl::Span values, xla::XlaBuilder* builder) const { + auto row_idx = values[0]; + auto num_outputs_so_far = values[1]; + auto iou_mask = values[2]; + auto included_iou = values[3]; + auto zero_r1 = xla::ConstantR1(builder, {0}); + // Determine if current elem is active using a slice. + auto row_idx_r1 = xla::Reshape(row_idx, {1}); + auto active_elem = xla::DynamicSlice(included_iou, row_idx_r1, {1}); + active_elem = xla::Reshape(active_elem, {}); + // Increment output count iff current elem is not suppressed. + num_outputs_so_far = xla::Select( + active_elem, num_outputs_so_far + xla::ConstantR0(builder, 1), + num_outputs_so_far); + // Slice out the row_idx. 
+ auto starts = xla::ConcatInDim(builder, {row_idx_r1, zero_r1}, 0); + auto row_iou = xla::DynamicSlice(iou_mask, starts, {1, num_boxes}); + // Remove the diagonal from consideration. An elem cannot suppress + // itself. + auto update_starts = xla::ConcatInDim(builder, {zero_r1, row_idx_r1}, 0); + row_iou = xla::DynamicUpdateSlice( + row_iou, xla::ConstantR2FromArray2D(builder, {{false}}), + update_starts); + // Create a suppression by inverting polarity. + row_iou = xla::Reshape(row_iou, {num_boxes}); + auto supp_mask = xla::Not(row_iou); + // Update mask iff current elem is not suppressed. + included_iou = xla::Select(xla::Broadcast(active_elem, {num_boxes}), + xla::And(included_iou, supp_mask), included_iou); + row_idx = row_idx + xla::ConstantR0(builder, 1); + return std::vector{row_idx, num_outputs_so_far, iou_mask, + included_iou}; + } +}; + class NonMaxSuppressionOp : public XlaOpKernel { public: explicit NonMaxSuppressionOp(OpKernelConstruction* context) @@ -326,14 +391,12 @@ class NonMaxSuppressionOp : public XlaOpKernel { void Compile(XlaOpKernelContext* context) override { // TODO(b/111646731): Improve scalability of this op, using blocking. 
- int num_boxes_dim = 0; - int coords_dim = 1; const TensorShape& boxes_shape = context->InputShape("boxes"); OP_REQUIRES(context, TensorShapeUtils::IsMatrix(boxes_shape), errors::InvalidArgument("boxes must be 2-D, currently: ", boxes_shape.DebugString())); - const int64 num_boxes = boxes_shape.dim_size(num_boxes_dim); - OP_REQUIRES(context, boxes_shape.dim_size(coords_dim) == 4, + const int64 num_boxes = boxes_shape.dim_size(0); + OP_REQUIRES(context, boxes_shape.dim_size(1) == 4, errors::InvalidArgument("boxes must have 4 columns", boxes_shape.DebugString())); const TensorShape& scores_shape = context->InputShape("scores"); @@ -347,9 +410,13 @@ class NonMaxSuppressionOp : public XlaOpKernel { OP_REQUIRES(context, pad_to_max_output_size_, errors::InvalidArgument( "XLA compilation requires pad_to_max_output_size == True")); + OP_REQUIRES(context, num_boxes <= kint32max, + errors::InvalidArgument("XLA compilation requires number of " + "boxes to be <= kint32max, got ", + num_boxes)); - xla::XlaOp boxes = context->Input("boxes"); - xla::XlaOp scores = context->Input("scores"); + const xla::XlaOp boxes_input = context->Input("boxes"); + const xla::XlaOp scores_input = context->Input("scores"); int64 output_size; OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &output_size)); OP_REQUIRES( @@ -358,90 +425,113 @@ class NonMaxSuppressionOp : public XlaOpKernel { OP_REQUIRES(context, output_size <= kint32max, errors::InvalidArgument("Need output_size <= kint32Max, got ", output_size)); - xla::XlaOp score_thresh = context->Input("score_threshold"); - xla::XlaOp iou_thresh = context->Input("iou_threshold"); - + const xla::XlaOp score_thresh = context->Input("score_threshold"); + const xla::XlaOp iou_thresh = context->Input("iou_threshold"); xla::XlaBuilder* const builder = context->builder(); // Choose a more convenient layout. 
- xla::XlaOp boxes_t = xla::Transpose(boxes, {1, 0}); - coords_dim = 0; - num_boxes_dim = 1; - - // Shapes are henceforth [1, num_boxes]. - xla::XlaOp coord_y0 = xla::SliceInDim(boxes_t, - /*start_index=*/0, - /*limit_index=*/1, - /*stride=*/1, - /*dimno=*/coords_dim); - xla::XlaOp coord_x0 = xla::SliceInDim(boxes_t, - /*start_index=*/1, - /*limit_index=*/2, - /*stride=*/1, - /*dimno=*/coords_dim); - xla::XlaOp coord_y1 = xla::SliceInDim(boxes_t, - /*start_index=*/2, - /*limit_index=*/3, - /*stride=*/1, - /*dimno=*/coords_dim); - xla::XlaOp coord_x1 = xla::SliceInDim(boxes_t, - /*start_index=*/3, - /*limit_index=*/4, - /*stride=*/1, - /*dimno=*/coords_dim); - xla::XlaOp y1 = - xla::Select(xla::Le(coord_y0, coord_y1), coord_y0, coord_y1); - xla::XlaOp y2 = - xla::Select(xla::Le(coord_y0, coord_y1), coord_y1, coord_y0); - xla::XlaOp x1 = - xla::Select(xla::Le(coord_x0, coord_x1), coord_x0, coord_x1); - xla::XlaOp x2 = - xla::Select(xla::Le(coord_x0, coord_x1), coord_x1, coord_x0); + const xla::XlaOp boxes = xla::Transpose(boxes_input, {1, 0}); + const xla::XlaOp boxes_sorted = xla::GetTupleElement( + xla::Sort(/*keys=*/-xla::Broadcast(scores_input, {4}), + /*values=*/{boxes}, + /*dimension=*/1), + 1); + // Track the mapping of indices into sorted domain. + const xla::XlaOp iota_indices = xla::Iota(builder, xla::S32, num_boxes); + const xla::XlaOp indices_sort = xla::Sort(-scores_input, {iota_indices}); + const xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1); + const xla::XlaOp scores = xla::Neg(xla::GetTupleElement(indices_sort, 0)); + + // Shapes are henceforth [1, num_boxes]. 'c_y0' denotes 'coordinate' y0. 
+ const xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted, + /*start_index=*/0, + /*limit_index=*/1, + /*stride=*/1, + /*dimno=*/0), + {num_boxes}); + const xla::XlaOp c_x0 = xla::Reshape(xla::SliceInDim(boxes_sorted, + /*start_index=*/1, + /*limit_index=*/2, + /*stride=*/1, + /*dimno=*/0), + {num_boxes}); + const xla::XlaOp c_y1 = xla::Reshape(xla::SliceInDim(boxes_sorted, + /*start_index=*/2, + /*limit_index=*/3, + /*stride=*/1, + /*dimno=*/0), + {num_boxes}); + const xla::XlaOp c_x1 = xla::Reshape(xla::SliceInDim(boxes_sorted, + /*start_index=*/3, + /*limit_index=*/4, + /*stride=*/1, + /*dimno=*/0), + {num_boxes}); + + xla::XlaOp y1 = xla::Select(xla::Le(c_y0, c_y1), c_y0, c_y1); + xla::XlaOp y2 = xla::Select(xla::Le(c_y0, c_y1), c_y1, c_y0); + xla::XlaOp x1 = xla::Select(xla::Le(c_x0, c_x1), c_x0, c_x1); + xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0); xla::XlaOp area = (y2 - y1) * (x2 - x1); - // Transpose the 1xN tensors, instead of the NxN tensors. - xla::XlaOp y1_t = xla::Transpose(y1, {1, 0}); - xla::XlaOp y2_t = xla::Transpose(y2, {1, 0}); - xla::XlaOp x1_t = xla::Transpose(x1, {1, 0}); - xla::XlaOp x2_t = xla::Transpose(x2, {1, 0}); - xla::XlaOp area_t = xla::Transpose(area, {1, 0}); + // Shapes are henceforth [1, num_boxes]. + y1 = xla::Broadcast(y1, {1}); + y2 = xla::Broadcast(y2, {1}); + x1 = xla::Broadcast(x1, {1}); + x2 = xla::Broadcast(x2, {1}); + area = xla::Broadcast(area, {1}); // Shapes are henceforth [num_boxes, num_boxes]. 
- xla::XlaOp i_xmin = xla::Max(x1, x1_t); - xla::XlaOp i_ymin = xla::Max(y1, y1_t); - xla::XlaOp i_xmax = xla::Min(x2, x2_t); - xla::XlaOp i_ymax = xla::Min(y2, y2_t); + xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0})); + xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0})); + xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0})); + xla::XlaOp i_ymax = xla::Min(y2, xla::Transpose(y2, {1, 0})); auto square_zero = xla::ZerosLike(i_xmin); xla::XlaOp i_area = xla::Max(i_xmax - i_xmin, square_zero) * xla::Max(i_ymax - i_ymin, square_zero); - xla::XlaOp u_area = area + area_t - i_area; + xla::XlaOp u_area = area + xla::Transpose(area, {1, 0}) - i_area; xla::XlaOp iou = i_area / u_area; xla::XlaOp iou_thresh_mask = xla::Gt(iou, iou_thresh + square_zero); - xla::XlaOp scores_2d = xla::Reshape(scores, {num_boxes, 1}); - xla::XlaOp score_cmp_mask = - xla::Gt(scores_2d, xla::Transpose(scores_2d, {1, 0})); - xla::XlaOp suppress = xla::And(iou_thresh_mask, score_cmp_mask); - - // Shapes are [num_boxes] after the reduce. 
- xla::XlaOp included_iou = xla::Not(xla::Reduce( - suppress, - /*init_value=*/xla::ConstantR0(builder, false), - /*computation=*/CreateScalarOrComputation(xla::PRED, builder), - /*dimensions_to_reduce=*/{0})); + xla::XlaOp included_iou = + xla::Broadcast(xla::ConstantR0(builder, true), {num_boxes}); + + std::vector init_values; + init_values.reserve(4); + init_values.push_back(xla::ConstantR0(builder, 0)); // col_idx + init_values.push_back(xla::ConstantR0(builder, 0)); // num_outputs + init_values.push_back(iou_thresh_mask); + init_values.push_back(included_iou); + + auto suppress_loop_result = + XlaWhileLoop(WhileCondFn(num_boxes, output_size), + SuppressBodyFn(num_boxes), init_values, "suppress_loop", + builder) + .ValueOrDie(); + xla::XlaOp included_score = xla::Gt(scores, xla::Broadcast(score_thresh, {num_boxes})); - xla::XlaOp included = xla::And(included_iou, included_score); + xla::XlaOp included = xla::And(included_score, suppress_loop_result[3]); + + // Only consider boxes over which we have iterated. This allows for accurate + // counting. DynamicSlice would require knowledge of the size of the output. + auto valid_elem = xla::Lt( + iota_indices, xla::Broadcast(suppress_loop_result[0], {num_boxes})); + included = xla::And(included, valid_elem); + xla::XlaOp neg_inf = xla::Broadcast(xla::MinValue(builder, xla::F32), {num_boxes}); xla::XlaOp scores_included = xla::Select(included, scores, neg_inf); - + xla::XlaOp output_tuple = TopK(scores_included, output_size); + xla::XlaOp selected_indices_sorted = xla::GetTupleElement(output_tuple, 1); + // Calculate num_valid. + // Note: num_valid cannot be taken from the loop outputs, because outputs + // can be suppressed by score threshold. xla::XlaOp ones_included = xla::Select( included, xla::Broadcast(xla::ConstantR0(builder, 1), {num_boxes}), xla::Broadcast(xla::ConstantR0(builder, 0), {num_boxes})); - // num_valid is scalar. Value should be bound by output_size. 
xla::XlaOp num_valid_total = xla::Reduce( ones_included, @@ -451,8 +541,17 @@ class NonMaxSuppressionOp : public XlaOpKernel { xla::XlaOp num_valid = xla::Min(num_valid_total, xla::ConstantR0(builder, output_size)); - xla::XlaOp output_tuple = TopK(scores_included, output_size); - xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1); + // Re-index into the original scores input tensor, using a Gather. + // Boxes were suppressed in the sorted domain. + xla::XlaOp selected_indices; + DataType gather_type = context->expected_output_dtype(0); + OP_REQUIRES_OK( + context, + XlaGather(indices_sorted, scores_shape, selected_indices_sorted, + TensorShape({output_size}), + /*axis=*/0, + /*indices_are_nd=*/false, + /*dtype=*/gather_type, DT_INT32, builder, &selected_indices)); context->SetOutput(0, selected_indices); context->SetOutput(1, num_valid); @@ -463,7 +562,7 @@ class NonMaxSuppressionOp : public XlaOpKernel { }; REGISTER_XLA_OP( - Name("NonMaxSuppressionV4").CompileTimeConstInput("max_output_size"), + Name("NonMaxSuppressionV4").CompileTimeConstantInput("max_output_size"), NonMaxSuppressionOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 7b2bb4a7c50fc954237e09a32f71009f790b60d0..5a10c52ba8b6d4fab73f0dda67cbd52fd625e76b 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" @@ -231,20 +230,22 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, num_extended[0] = upper_padding[0] / (dims.kernel_size[0]); num_extended[1] = upper_padding[1] / (dims.kernel_size[1]); + const int64 batch_dim_size = + builder->GetShape(input).ValueOrDie().dimensions(0); if (num_extended[0] > 0) { - auto slice = - xla::Slice(input_data, {0, in_size[0] - 1, 0, 0}, - {1, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); + auto slice = xla::Slice( + input_data, {0, in_size[0] - 1, 0, 0}, + {batch_dim_size, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); for (int i = 0; i < num_extended[0]; i++) { input_data = xla::ConcatInDim(builder, {input_data, slice}, 1); } } if (num_extended[1] > 0) { - auto slice = - xla::Slice(input_data, {0, 0, in_size[1] - 1, 0}, - {1, in_size[0] + num_extended[0], in_size[1], channels}, - {1, 1, 1, 1}); + auto slice = xla::Slice( + input_data, {0, 0, in_size[1] - 1, 0}, + {batch_dim_size, in_size[0] + num_extended[0], in_size[1], channels}, + {1, 1, 1, 1}); for (int i = 0; i < num_extended[1]; i++) { input_data = xla::ConcatInDim(builder, {input_data, slice}, 2); } @@ -511,7 +512,7 @@ class ResizeBilinearOp : public XlaOpKernel { bool align_corners_; }; -REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstInput("size"), +REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstantInput("size"), ResizeBilinearOp); class ResizeBilinearGradOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 
f3964748587c1b31cf8b1b76643ff19a9044bf44..843b6bb4e658af16fd753c1a20b35dd3d18df027 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -78,7 +78,7 @@ XlaArgMaxOp::XlaArgMaxOp(OpKernelConstruction* ctx) : XlaArgMinMaxOp(ctx, /*is_min=*/false) {} REGISTER_XLA_OP(Name("ArgMax") .Device(DEVICE_GPU_XLA_JIT) - .CompileTimeConstInput("dimension"), + .CompileTimeConstantInput("dimension"), XlaArgMaxOp); namespace { @@ -89,7 +89,8 @@ class XlaArgMinOp : public XlaArgMinMaxOp { }; XlaArgMinOp::XlaArgMinOp(OpKernelConstruction* ctx) : XlaArgMinMaxOp(ctx, /*is_min=*/true) {} -REGISTER_XLA_OP(Name("ArgMin").CompileTimeConstInput("dimension"), XlaArgMinOp); +REGISTER_XLA_OP(Name("ArgMin").CompileTimeConstantInput("dimension"), + XlaArgMinOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 3d81ae9eb89a80e5b89b180ad77521c5ed15e79d..e2c05b648bb194b1b452c527ddb1a2c5995b1217 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -30,7 +30,9 @@ limitations under the License. namespace tensorflow { namespace { -// The logic below uses a custom-call to implement argmax. +// The logic below uses a custom-call to implement argmax when possible. When +// custom-call is not allowed or input shapes are not supported, this kernel +// falls back to using XLA HLO native ArgMax. // // Also see b/29507024 for first-class XLA support for indexing ops. class ArgMaxCustomCallOp : public XlaOpKernel { @@ -48,30 +50,42 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // We require that the dimension argument is a constant, since it lets us // dispatch to a specialized custom-call function without any run-time // overhead, when compiling ahead-of-time. 
- xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); - const int32 dim = literal.Get({}); - OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0")); - OP_REQUIRES( - ctx, dim < input_shape.dims(), - errors::InvalidArgument("dim must be < input rank (", - input_shape.dims(), "), but got: ", dim)); - const int64 dim_size = input_shape.dim_size(dim); - OP_REQUIRES(ctx, dim_size > 0, + int64 dim; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim)); + + const int input_dims = input_shape.dims(); + const int axis = dim < 0 ? dim + input_dims : dim; + OP_REQUIRES(ctx, axis >= 0 && axis < input_dims, + errors::InvalidArgument("Expected dimension in the range [", + -input_dims, ", ", input_dims, + "), but got ", dim)); + + const int64 axis_size = input_shape.dim_size(axis); + OP_REQUIRES(ctx, axis_size > 0, errors::InvalidArgument( "Reduction axis ", dim, " is empty in shape: ", input_shape.DebugString())); - // The output shape is the input shape contracted along dim. + const DataType dtype = output_type(0); + xla::PrimitiveType output_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &output_type)); + + // Fall back to XLA ArgMax HLO when CustomCall is not allowed or when input + // shape isn't supported. + if (!ctx->compiler()->options().allow_cpu_custom_calls || + (input_dims != 1 && input_dims != 2)) { + xla::XlaOp output = XlaHelpers::ArgMax(ctx->Input(0), output_type, axis); + ctx->SetOutput(0, output); + return; + } + + xla::XlaOp output; + // The output shape is the input shape contracted along axis. TensorShape output_shape; for (int d = 0; d < input_shape.dims() - 1; ++d) { - output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1)); + output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1)); } - // For now we use a custom-call, only for the 1d and 2d cases. 
- OP_REQUIRES(ctx, XlaContext::Get(ctx).allow_cpu_custom_calls(), - errors::InvalidArgument( - "ArgMax implementation requires a CustomCall on CPU")); xla::XlaBuilder& b = *ctx->builder(); // XLA passes to the function, so it is not included here. @@ -85,31 +99,32 @@ class ArgMaxCustomCallOp : public XlaOpKernel { args.push_back(xla::ConstantLiteral( &b, xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); args.push_back( - xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0(dim))); + xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0(axis))); } - xla::Shape xla_shape = - xla::ShapeUtil::MakeShape(xla::S64, output_shape.dim_sizes()); + // The argmax function expects row-major layout. + xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( + xla::S64, output_shape.dim_sizes()); + std::vector arg_shapes; + for (const xla::XlaOp& arg : args) { + auto shape_status = b.GetShape(arg); + OP_REQUIRES_OK(ctx, shape_status.status()); + xla::Shape arg_shape = shape_status.ConsumeValueOrDie(); + *arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout( + xla::ShapeUtil::Rank(arg_shape)); + arg_shapes.push_back(std::move(arg_shape)); + } // Tell XLA to call the custom code, defined in - // index_ops_kernel_argmax_float_1d.cc. - xla::XlaOp output; - switch (input_shape.dims()) { - case 1: - output = - xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape); - break; - case 2: - output = - xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape); - break; - default: - OP_REQUIRES(ctx, false, - errors::Unimplemented( - "Argmax is only implemented for 1d and 2d tensors" - ", but got shape: ", - input_shape.DebugString())); + // index_ops_kernel_argmax_float_{1, 2}d.cc. 
+ if (input_dims == 1) { + output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args, + xla_shape, arg_shapes); + } else { + output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args, + xla_shape, arg_shapes); } + output = xla::ConvertElementType(output, output_type); ctx->SetOutput(0, output); } @@ -120,7 +135,7 @@ class ArgMaxCustomCallOp : public XlaOpKernel { REGISTER_XLA_OP(Name("ArgMax") .TypeConstraint("T", DT_FLOAT) .Device(DEVICE_CPU_XLA_JIT) - .CompileTimeConstInput("dimension"), + .CompileTimeConstantInput("dimension"), ArgMaxCustomCallOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index f028e361bccd51de0bd69a1d2227c7afaed53455..93f029731c34e84000a3dc00df8af05654cccf2d 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -37,12 +37,11 @@ class L2LossOp : public XlaOpKernel { // output = sum(t ** 2) / 2 const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto t = - XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); + auto t = XlaHelpers::ConvertElementType(ctx->Input(0), accumulation_type); auto square = xla::Mul(t, t); auto reduce = xla::Reduce(square, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), dims); - auto deconverted = XlaHelpers::ConvertElementType(b, reduce, dtype); + auto deconverted = XlaHelpers::ConvertElementType(reduce, dtype); auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); ctx->SetOutput(0, xla::Div(deconverted, two)); } diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index a11bbe918f7f8eb050aaa40d4344f9cc9e9a10a4..e46f4e72dc9cb245916b138d5365ee42371f0e4c 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -115,8 +115,8 @@ class ListDiffOp : public 
XlaOpKernel { REGISTER_XLA_OP(Name("ListDiff") .TypeConstraint("T", kListDiffTypes) - .CompileTimeConstInput("x") - .CompileTimeConstInput("y"), + .CompileTimeConstantInput("x") + .CompileTimeConstantInput("y"), ListDiffOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 87ee2d3aede50eb24e65570f106d49030e1d4236..987901d82b3f3798dd52f18ef2497b8f0cf80b11 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -49,16 +49,14 @@ class LRNOp : public XlaOpKernel { // We use a window of depth_radius_ * 2 + 1, to account for the current // element and a depth_radius_ on either side. auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); - auto converted = - XlaHelpers::ConvertElementType(builder, input, accumulation_type); + auto converted = XlaHelpers::ConvertElementType(input, accumulation_type); auto squared = xla::Mul(converted, converted); auto reduce = xla::ReduceWindow( squared, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); - auto sqr_sum = - XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); + auto sqr_sum = XlaHelpers::ConvertElementType(reduce, input_type(0)); auto scale = xla::Pow( xla::Add(xla::ConstantR0(builder, bias_), @@ -138,15 +136,14 @@ class LRNGradOp : public XlaOpKernel { auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = - XlaHelpers::ConvertElementType(builder, in_image, accumulation_type); + XlaHelpers::ConvertElementType(in_image, accumulation_type); auto squared = xla::Mul(converted, converted); auto reduce = xla::ReduceWindow( squared, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* 
window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); - auto sqr_sum = - XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); + auto sqr_sum = XlaHelpers::ConvertElementType(reduce, input_type(0)); auto norm = xla::Add(xla::ConstantR0(builder, bias_), @@ -157,15 +154,13 @@ class LRNGradOp : public XlaOpKernel { xla::Div(out_image, norm)), in_grads); - auto converted_dy = - XlaHelpers::ConvertElementType(builder, dy, accumulation_type); + auto converted_dy = XlaHelpers::ConvertElementType(dy, accumulation_type); auto dy_reduce = xla::ReduceWindow( converted_dy, XlaHelpers::Zero(builder, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); - auto dy_reduced = - XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0)); + auto dy_reduced = XlaHelpers::ConvertElementType(dy_reduce, input_type(0)); xla::XlaOp gradients = xla::Add( xla::Mul(in_image, dy_reduced), diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index 8dfd7de591c4a3c4768dd60b41e03d294ad49397..2dd0a710e47ec8cad6153402fdb3be59f5868205 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -61,11 +61,11 @@ class MatrixBandPartOp : public XlaOpKernel { // Compute 'offset', which is how many diagonals we are above/below the // diagonal. 
- xla::XlaOp iota_m = xla::Iota(builder, index_xla_type, m); - xla::XlaOp iota_n = xla::Iota(builder, index_xla_type, n); + xla::Shape iota_shape = xla::ShapeUtil::MakeShape(index_xla_type, {m, n}); + xla::XlaOp iota_m = xla::Iota(builder, iota_shape, /*iota_dimension=*/0); + xla::XlaOp iota_n = xla::Iota(builder, iota_shape, /*iota_dimension=*/1); - auto offset = xla::Sub(xla::Broadcast(iota_n, {m}), iota_m, - /*broadcast_dimensions=*/{0}); + auto offset = xla::Sub(iota_n, iota_m); // If num_lower or num_upper are negative, include all lower/upper // diagonals. diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index c0ca881ff82cee04e0c5e35f9a2d5732fabdd8a6..4f980b6d14ed667bdf4756ed740894098cae5919 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index f4def11d08c31513aec5aad15187016a7294c2fd..90c0ebefb24ec2c4378782e9b15d3f57c33032a4 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" namespace tensorflow { namespace { @@ -29,7 +29,7 @@ class MatrixTriangularSolveOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = TriangularSolve( + auto result = xla::TriangularSolve( ctx->Input(0), ctx->Input(1), /*left_side=*/true, /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); ctx->SetOutput(0, result); diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 2a42eeaf76ab3aa88ff3a93ef7eb7ab217964bb6..f6b8534f4d7c537e5b708ee000e00cb92123584b 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -41,10 +41,8 @@ class MirrorPadOp : public XlaOpKernel { for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { auto t_rev = xla::Rev(accum, {dimno}); - TF_ASSIGN_OR_RETURN(int64 lhs_padding, - pad_literal.GetIntegralAsS64({dimno, 0})); - TF_ASSIGN_OR_RETURN(int64 rhs_padding, - pad_literal.GetIntegralAsS64({dimno, 1})); + int64 lhs_padding = pad_literal.Get({dimno, 0}); + int64 rhs_padding = pad_literal.Get({dimno, 1}); int64 dim_size = original_shape.dimensions(dimno); // Padding amounts on each side must be no more than the size of the @@ -65,8 +63,8 @@ class MirrorPadOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape pad_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape pad_shape = ctx->InputShape("paddings"); MirrorPadMode mode; OP_REQUIRES_OK(ctx, GetNodeAttr(def(), 
"mode", &mode)); @@ -81,23 +79,19 @@ class MirrorPadOp : public XlaOpKernel { TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2, errors::InvalidArgument("paddings must be a matrix with 2 columns: ", pad_shape.DebugString())); - const int fixed_dims = - (allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1) - ? 1 - : dims; OP_REQUIRES( - ctx, fixed_dims == pad_shape.dim_size(0), + ctx, dims == pad_shape.dim_size(0), errors::InvalidArgument( "The first dimension of paddings must be the rank of inputs", pad_shape.DebugString(), " ", input_shape.DebugString())); // Evaluate the 'padding' constant input, reshaping to a matrix. xla::Literal pad_literal; - OP_REQUIRES_OK( - ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal)); + OP_REQUIRES_OK(ctx, + ctx->ConstantInputAsInt64Literal("paddings", &pad_literal)); xla::XlaBuilder* b = ctx->builder(); - auto in0 = ctx->Input(0); + auto in0 = ctx->Input("input"); xla::StatusOr in0_shape = b->GetShape(in0); OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status()); xla::StatusOr accum_status = @@ -112,7 +106,7 @@ class MirrorPadOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(MirrorPadOp); }; -REGISTER_XLA_OP(Name("MirrorPad").CompileTimeConstInput("paddings"), +REGISTER_XLA_OP(Name("MirrorPad").CompileTimeConstantInput("paddings"), MirrorPadOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc index cac2eea96eeed723b2a63bc9193070cad04b005d..aba54578d97c1e455f67efa2877ddc25dab68ac0 100644 --- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc @@ -76,7 +76,7 @@ class OneHotOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp); }; -REGISTER_XLA_OP(Name("OneHot").CompileTimeConstInput("depth"), OneHotOp); +REGISTER_XLA_OP(Name("OneHot").CompileTimeConstantInput("depth"), OneHotOp); } // namespace } // namespace tensorflow diff --git 
a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index e5937b56c17d01892928b073da09f38941ea1bbb..36ea70ac392ff18fb52d400efa886533f8335eba 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -29,40 +30,36 @@ class PadOp : public XlaOpKernel { explicit PadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape pad_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape pad_shape = ctx->InputShape("paddings"); const int dims = input_shape.dims(); OP_REQUIRES( ctx, TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2, errors::InvalidArgument("paddings must be a matrix with 2 columns: ", pad_shape.DebugString())); - const int fixed_dims = - (allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1) - ? 1 - : dims; OP_REQUIRES( - ctx, fixed_dims == pad_shape.dim_size(0), + ctx, dims == pad_shape.dim_size(0), errors::InvalidArgument( "The first dimension of paddings must be the rank of inputs", pad_shape.DebugString(), " ", input_shape.DebugString())); - if (fixed_dims == 0) { + xla::XlaOp input = ctx->Input("input"); + if (dims == 0) { // Tensor is rank 0. Return it unchanged. - ctx->SetOutput(0, ctx->Input(0)); + ctx->SetOutput(0, input); return; } - // Evaluate the 'padding' constant input, reshaping to a matrix. 
xla::Literal pad_literal; - OP_REQUIRES_OK( - ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal)); + OP_REQUIRES_OK(ctx, + ctx->ConstantInputAsInt64Literal("paddings", &pad_literal)); xla::PaddingConfig config; - for (int i = 0; i < fixed_dims; ++i) { + for (int i = 0; i < dims; ++i) { auto* dim = config.add_dimensions(); - int before = pad_literal.Get({i, 0}); - int after = pad_literal.Get({i, 1}); + int before = pad_literal.Get({i, 0}); + int after = pad_literal.Get({i, 1}); OP_REQUIRES(ctx, before >= 0 && after >= 0, errors::InvalidArgument( "Paddings must be non-negative: ", before, " ", after)); @@ -73,18 +70,19 @@ class PadOp : public XlaOpKernel { // PadV2 added a "constant_values" input that indicates the pad value. xla::XlaOp constant_values; if (ctx->num_inputs() == 3) { - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)), - errors::InvalidArgument("constant_values must be a scalar.")); - ctx->SetOutput(0, xla::Pad(ctx->Input(0), ctx->Input(2), config)); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(ctx->InputShape("constant_values")), + errors::InvalidArgument("constant_values must be a scalar.")); + ctx->SetOutput(0, xla::Pad(input, ctx->Input("constant_values"), config)); } else { auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); - ctx->SetOutput(0, xla::Pad(ctx->Input(0), zero, config)); + ctx->SetOutput(0, xla::Pad(input, zero, config)); } } }; -REGISTER_XLA_OP(Name("Pad").CompileTimeConstInput("paddings"), PadOp); -REGISTER_XLA_OP(Name("PadV2").CompileTimeConstInput("paddings"), PadOp); +REGISTER_XLA_OP(Name("Pad").CompileTimeConstantInput("paddings"), PadOp); +REGISTER_XLA_OP(Name("PadV2").CompileTimeConstantInput("paddings"), PadOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/permute_op.cc b/tensorflow/compiler/tf2xla/kernels/permute_op.cc index 0764e5503db583351e92a144b2c361e8875161d3..71920bf5c1e6aa5981aafa8b611cc01c0917e02b 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/permute_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/permute_op.cc @@ -75,10 +75,9 @@ class DataFormatVecPermuteOp : public XlaOpKernel { } auto keys = xla::ConstantR1(builder, absl::Span(dst_indices)); if (input_rank == 2) { - keys = xla::BroadcastInDim( - keys, xla::ShapeUtil::MakeShape(xla::S32, {4, 2}), {0}); + keys = xla::BroadcastInDim(keys, {4, 2}, {0}); } - auto sorted = xla::Sort(keys, ctx->Input(0), 0); + auto sorted = xla::Sort(keys, {ctx->Input(0)}, 0); auto output = xla::GetTupleElement(sorted, 1); ctx->SetOutput(0, output); } @@ -90,9 +89,9 @@ class DataFormatVecPermuteOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(DataFormatVecPermuteOp); }; -// TODO(b/115384656): Support DT_INT64. -REGISTER_XLA_OP(Name("DataFormatVecPermute").TypeConstraint("T", DT_INT32), - DataFormatVecPermuteOp); +REGISTER_XLA_OP( + Name("DataFormatVecPermute").TypeConstraint("T", {DT_INT32, DT_INT64}), + DataFormatVecPermuteOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 27690c156e4da129ad139f3880bba3a208b5606d..06c6cc37ec90192486ba15010bfeb763a9ffb987 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -152,7 +152,12 @@ class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, - /*reduction_type=*/ctx->input_type(0)) {} + /*reduction_type=*/ctx->input_type(0)) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } void Compile(XlaOpKernelContext* ctx) override { auto ksize_or_error = GetKernelSize(ctx); @@ -180,16 +185,12 @@ class MaxPool2DOp : public MaxPoolOp { public: explicit 
MaxPool2DOp(OpKernelConstruction* ctx) : MaxPoolOp(ctx, /*num_spatial_dims=*/2) { - string data_format_str; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); - OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp); REGISTER_XLA_OP(Name("MaxPoolV2") - .CompileTimeConstInput("ksize") - .CompileTimeConstInput("strides"), + .CompileTimeConstantInput("ksize") + .CompileTimeConstantInput("strides"), MaxPool2DOp); class MaxPool3DOp : public MaxPoolOp { @@ -204,7 +205,12 @@ class AvgPoolOp : public PoolingOp { AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ - XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} + XlaHelpers::SumAccumulationType(ctx->input_type(0))) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } void Compile(XlaOpKernelContext* ctx) override { auto ksize_or_error = GetKernelSize(ctx); @@ -241,10 +247,6 @@ class AvgPool2DOp : public AvgPoolOp { public: explicit AvgPool2DOp(OpKernelConstruction* ctx) : AvgPoolOp(ctx, /*num_spatial_dims=*/2) { - string data_format_str; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); - OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp); @@ -360,8 +362,8 @@ class MaxPool2DGradOp : public MaxPoolGradOp { }; REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp); REGISTER_XLA_OP(Name("MaxPoolGradV2") - .CompileTimeConstInput("ksize") - .CompileTimeConstInput("strides"), + .CompileTimeConstantInput("ksize") + .CompileTimeConstantInput("strides"), MaxPool2DGradOp); class MaxPool3DGradOp : public 
MaxPoolGradOp { @@ -390,6 +392,11 @@ class AvgPoolGradOp : public XlaOpKernel { OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1, errors::Unimplemented( "Pooling is not yet supported on the batch dimension.")); + + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); } int num_dims() const { return num_spatial_dims_ + 2; } @@ -449,22 +456,20 @@ class AvgPool2DGradOp : public AvgPoolGradOp { public: explicit AvgPool2DGradOp(OpKernelConstruction* ctx) : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) { - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); } }; -REGISTER_XLA_OP(Name("AvgPoolGrad").CompileTimeConstInput("orig_input_shape"), - AvgPool2DGradOp); +REGISTER_XLA_OP( + Name("AvgPoolGrad").CompileTimeConstantInput("orig_input_shape"), + AvgPool2DGradOp); class AvgPool3DGradOp : public AvgPoolGradOp { public: explicit AvgPool3DGradOp(OpKernelConstruction* ctx) : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {} }; -REGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"), - AvgPool3DGradOp); +REGISTER_XLA_OP( + Name("AvgPool3DGrad").CompileTimeConstantInput("orig_input_shape"), + AvgPool3DGradOp); class MaxPoolGradGradOp : public XlaOpKernel { public: @@ -632,8 +637,8 @@ REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT), MaxPool2DGradGradOp); REGISTER_XLA_OP(Name("MaxPoolGradGradV2") .TypeConstraint("T", DT_FLOAT) - .CompileTimeConstInput("ksize") - .CompileTimeConstInput("strides"), + .CompileTimeConstantInput("ksize") + .CompileTimeConstantInput("strides"), MaxPool2DGradGradOp); class MaxPool3DGradGradOp : public MaxPoolGradGradOp { diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc 
b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index 6f4ed496a1774dde68dd9d5fbd37995d615b678c..7fe102428db1cc5ce16037f56fa301d1941da8e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/platform/macros.h" @@ -26,12 +27,26 @@ limitations under the License. namespace tensorflow { namespace { +enum QuantizerRoundMode { + // Round half up: if the fraction of y is exactly 0.5, then + // round(y) = y + 0.5 + // E.g., -5.5 gets rounded to -5, -5.4 goes to -5, + // 5.4 goes to 5, and 5.5 goes to 6. + ROUND_HALF_UP, + // Round half to even: if the fraction of y is exactly 0.5, then round(y) is + // the nearest even integer to y. + // E.g., 23.5 gets rounded to 24, 24.5 gets rounded to 24, while -23.5 becomes + // -24, and -24.5 gets rounded to 24. + ROUND_HALF_TO_EVEN, +}; + class QuantizeAndDequantizeOp : public XlaOpKernel { public: explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_)); + round_mode_ = ROUND_HALF_TO_EVEN; } void Compile(XlaOpKernelContext* ctx) override { @@ -117,8 +132,17 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // in that case they were measured from the tensor. 
input = Clamp(min_range, input, max_range); } - xla::XlaOp result = - Floor((input - min_range) * scale + half) * inverse_scale + min_range; + xla::XlaOp result; + switch (round_mode_) { + case ROUND_HALF_TO_EVEN: { + result = xla::RoundToEven(input * scale) * inverse_scale; + break; + } + case ROUND_HALF_UP: { + result = Floor(input * scale + half) * inverse_scale; + break; + } + } ctx->SetOutput(0, result); } @@ -126,6 +150,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { int64 num_bits_ = -1; bool signed_input_; bool range_given_; + QuantizerRoundMode round_mode_; }; class QuantizeAndDequantizeV2Op : public QuantizeAndDequantizeOp { @@ -136,6 +161,20 @@ class QuantizeAndDequantizeV2Op : public QuantizeAndDequantizeOp { OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63), errors::InvalidArgument("num_bits is out of range: ", num_bits_, " with signed_input_ ", signed_input_)); + string round_mode_string; + OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string)); + OP_REQUIRES( + ctx, + (round_mode_string == "HALF_UP" || round_mode_string == "HALF_TO_EVEN"), + errors::InvalidArgument("Round mode string must be " + "'HALF_UP' or " + "'HALF_TO_EVEN', is '" + + round_mode_string + "'")); + if (round_mode_string == "HALF_UP") { + round_mode_ = ROUND_HALF_UP; + } else if (round_mode_string == "HALF_TO_EVEN") { + round_mode_ = ROUND_HALF_TO_EVEN; + } } }; diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index afd5986846705f66eb4c7ced9dbe2f4757f5af7f..8822e29f7e77b1cbc6fa6ca61d0062d9b1b0c36e 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -26,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -58,7 +57,7 @@ class RandomUniformOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp); }; -REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"), +REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstantInput("shape"), RandomUniformOp); class RandomShuffleOp : public XlaOpKernel { @@ -135,7 +134,7 @@ class RandomShuffleOp : public XlaOpKernel { xla::XlaOp curr = input; for (int i = 0; i < rounds; ++i) { xla::XlaOp keys = xla::RngUniform(zero, max_value, key_shape); - xla::XlaOp sorted = xla::Sort(keys, curr); + xla::XlaOp sorted = xla::Sort(keys, {curr}); curr = xla::GetTupleElement(sorted, 1); } @@ -227,7 +226,7 @@ class RandomUniformIntOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp); }; -REGISTER_XLA_OP(Name("RandomUniformInt").CompileTimeConstInput("shape"), +REGISTER_XLA_OP(Name("RandomUniformInt").CompileTimeConstantInput("shape"), RandomUniformIntOp); class RandomStandardNormalOp : public XlaOpKernel { @@ -256,7 +255,7 @@ class RandomStandardNormalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp); }; -REGISTER_XLA_OP(Name("RandomStandardNormal").CompileTimeConstInput("shape"), +REGISTER_XLA_OP(Name("RandomStandardNormal").CompileTimeConstantInput("shape"), RandomStandardNormalOp); class TruncatedNormalOp : public XlaOpKernel { @@ -282,7 +281,7 @@ class TruncatedNormalOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("TruncatedNormal") - .CompileTimeConstInput("shape") + .CompileTimeConstantInput("shape") .TypeConstraint("dtype", DT_FLOAT), TruncatedNormalOp); diff --git 
a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index 8102faad28db71075fb8da269c55edbdb667193e..dacdbc88e005304bc64ea35c8985711afca41eae 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -40,10 +40,16 @@ class ReduceWindowOp : public XlaOpKernel { std::vector window_dimensions; std::vector window_strides; + std::vector base_dilations; + std::vector window_dilations; OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( "window_dimensions", &window_dimensions)); OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", &window_strides)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("base_dilations", + &base_dilations)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dilations", &window_dilations)); const int rank = input_shape.dims(); OP_REQUIRES(context, rank == window_dimensions.size(), @@ -56,6 +62,16 @@ class ReduceWindowOp : public XlaOpKernel { "The size of window_strides must be equal to the input " "rank (", window_strides.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == base_dilations.size(), + errors::InvalidArgument( + "The size of base_dilations must be equal to the input " + "rank (", + base_dilations.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_dilations.size(), + errors::InvalidArgument( + "The size of window_dilations must be equal to the input " + "rank (", + window_dilations.size(), " vs. ", rank, ")")); // Build the reducer function. 
XlaCompiler::Argument reducer_arg; @@ -102,7 +118,8 @@ class ReduceWindowOp : public XlaOpKernel { xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( context->Input(0), context->Input(1), *reducer.computation, - window_dimensions, window_strides, padding); + window_dimensions, window_strides, base_dilations, window_dilations, + padding); context->SetOutput(0, output); } @@ -113,9 +130,11 @@ class ReduceWindowOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("XlaReduceWindow") - .CompileTimeConstInput("window_dimensions") - .CompileTimeConstInput("window_strides") - .CompileTimeConstInput("padding"), + .CompileTimeConstantInput("window_dimensions") + .CompileTimeConstantInput("window_strides") + .CompileTimeConstantInput("base_dilations") + .CompileTimeConstantInput("window_dilations") + .CompileTimeConstantInput("padding"), ReduceWindowOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 0d260fa8fcaa513d7854c1e9215952404d555c70..65e158d64fdd7df62d50b81c9e488b2d03476fb7 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -41,7 +41,8 @@ class SumOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Sum").CompileTimeConstInput("reduction_indices"), SumOp); +REGISTER_XLA_OP(Name("Sum").CompileTimeConstantInput("reduction_indices"), + SumOp); class ProdOp : public XlaReductionOp { public: @@ -59,7 +60,7 @@ class ProdOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Prod").CompileTimeConstInput("reduction_indices"), +REGISTER_XLA_OP(Name("Prod").CompileTimeConstantInput("reduction_indices"), ProdOp); class MinOp : public XlaReductionOp { @@ -77,7 +78,8 @@ class MinOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Min").CompileTimeConstInput("reduction_indices"), MinOp); +REGISTER_XLA_OP(Name("Min").CompileTimeConstantInput("reduction_indices"), + MinOp); class MaxOp : public 
XlaReductionOp { public: @@ -94,7 +96,8 @@ class MaxOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Max").CompileTimeConstInput("reduction_indices"), MaxOp); +REGISTER_XLA_OP(Name("Max").CompileTimeConstantInput("reduction_indices"), + MaxOp); class MeanOp : public XlaReductionOp { public: @@ -110,16 +113,25 @@ class MeanOp : public XlaReductionOp { xla::Add(scalar_lhs, scalar_rhs); } - xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, - const xla::XlaOp& reduce_output, - int64 num_elements_reduced) override { - auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), - num_elements_reduced); - return reduce_output / divisor; + xla::XlaOp BuildFinalizer( + xla::XlaBuilder* /*builder*/, const xla::XlaOp& input, + const xla::XlaOp& reduce_output, + const std::vector& dimensions_to_reduce) override { + if (dimensions_to_reduce.empty()) { + return reduce_output; + } + auto divisor = xla::GetDimensionSize(input, dimensions_to_reduce[0]); + for (int i = 1; i < dimensions_to_reduce.size(); i++) { + auto size = xla::GetDimensionSize(input, dimensions_to_reduce[i]); + divisor = xla::Mul(divisor, size); + } + divisor = xla::ConvertElementType(divisor, xla_reduction_type_); + return XlaHelpers::ConvertElementType(reduce_output / divisor, + input_type(0)); } }; -REGISTER_XLA_OP(Name("Mean").CompileTimeConstInput("reduction_indices"), +REGISTER_XLA_OP(Name("Mean").CompileTimeConstantInput("reduction_indices"), MeanOp); class AllOp : public XlaReductionOp { @@ -137,7 +149,8 @@ class AllOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("All").CompileTimeConstInput("reduction_indices"), AllOp); +REGISTER_XLA_OP(Name("All").CompileTimeConstantInput("reduction_indices"), + AllOp); class AnyOp : public XlaReductionOp { public: @@ -154,7 +167,8 @@ class AnyOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Any").CompileTimeConstInput("reduction_indices"), AnyOp); +REGISTER_XLA_OP(Name("Any").CompileTimeConstantInput("reduction_indices"), + AnyOp); } // 
namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 466e79828d111ee7cadcf713703e8f252c63e62c..af716eab79886791e8507a84984b7ca60865d00e 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -48,13 +48,14 @@ class XlaReductionOp : public XlaOpKernel { const xla::XlaOp& scalar_rhs) = 0; // Applies a transformation to the output of the reduction. The desired - // computation should be added to 'builder'. Argument 'reduce_output' is the - // output of the reduction. 'num_elements_reduced' is the number of elements - // that contributed to the reduction. Returns the transformed reduction - // output, Defaults to returning 'reduce_output' unchanged. - virtual xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, - const xla::XlaOp& reduce_output, - int64 num_elements_reduced); + // computation should be added to 'builder'. Argument 'input' is the original + // input of the reduction; 'reduce_output' is the output of the reduction. + // Returns the transformed reduction output. Defaults to returning + // 'reduce_output' converted to the input type. 
+ virtual xla::XlaOp BuildFinalizer( + xla::XlaBuilder* builder, const xla::XlaOp& input, + const xla::XlaOp& reduce_output, + const std::vector& dimensions_to_reduce); void Compile(XlaOpKernelContext* ctx) override; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 118f2798d559f43acb7f6394a7337426164325ef..2ca2a85244b4edfe75db3d4fff6c2058adc2bf71 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -35,12 +35,13 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); } -// Unless BuildFinalizer is overridden the reduction has no -// finalizer. -xla::XlaOp XlaReductionOp::BuildFinalizer(xla::XlaBuilder* builder, - const xla::XlaOp& reduce_output, - int64 num_elements_reduced) { - return reduce_output; +// The default finalizer converts the results back into the input type. This can +// be overridden. 
+xla::XlaOp XlaReductionOp::BuildFinalizer( + xla::XlaBuilder* /*builder*/, const xla::XlaOp& /*input*/, + const xla::XlaOp& reduce_output, + const std::vector& /*dimensions_to_reduce*/) { + return XlaHelpers::ConvertElementType(reduce_output, input_type(0)); } void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { @@ -71,7 +72,6 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { absl::InlinedVector bitmap(data_shape.dims(), false); std::vector xla_axes; - int64 num_elements_reduced = 1LL; for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { int64 index = axes[i]; OP_REQUIRES(ctx, @@ -82,7 +82,6 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { index = (index + data_shape.dims()) % data_shape.dims(); bitmap[index] = true; xla_axes.push_back(index); - num_elements_reduced *= data_shape.dim_size(index); } std::vector final_shape; @@ -118,8 +117,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie(); auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes); - auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); - auto finalized = BuildFinalizer(b, deconverted, num_elements_reduced); + auto finalized = BuildFinalizer(b, data, reduce, xla_axes); auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized; ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index d35777ccb1271ec6a7c9972c714d06b2415d9c34..a8e230ba107ce8a73f3e80f0e55fa27eea31338f 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -15,14 +15,12 @@ limitations under the License. 
// Native XLA implementations of XLA Relu Ops -#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/no_op.h" namespace tensorflow { namespace { @@ -37,6 +35,7 @@ class ReluOp : public XlaOpKernel { ctx->SetOutput(0, xla::Max(zero, ctx->Input(0))); } }; +REGISTER_XLA_OP(Name("Relu"), ReluOp); class Relu6Op : public XlaOpKernel { public: @@ -49,6 +48,22 @@ class Relu6Op : public XlaOpKernel { ctx->SetOutput(0, xla::Clamp(zero, ctx->Input(0), six)); } }; +REGISTER_XLA_OP(Name("Relu6"), Relu6Op); + +class LeakyReluOp : public XlaOpKernel { + public: + explicit LeakyReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_)); + } + void Compile(XlaOpKernelContext* ctx) override { + auto features = ctx->Input("features"); + auto output = + xla::Max(features, features * xla::ScalarLike(features, alpha_)); + ctx->SetOutput(0, output); + } + float alpha_; +}; +REGISTER_XLA_OP(Name("LeakyRelu"), LeakyReluOp); class ReluGradOp : public XlaOpKernel { public: @@ -64,6 +79,7 @@ class ReluGradOp : public XlaOpKernel { ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), zero)); } }; +REGISTER_XLA_OP(Name("ReluGrad"), ReluGradOp); class Relu6GradOp : public XlaOpKernel { public: @@ -83,11 +99,24 @@ class Relu6GradOp : public XlaOpKernel { ctx->SetOutput(0, out); } }; - -REGISTER_XLA_OP(Name("Relu"), ReluOp); -REGISTER_XLA_OP(Name("Relu6"), Relu6Op); -REGISTER_XLA_OP(Name("ReluGrad"), ReluGradOp); REGISTER_XLA_OP(Name("Relu6Grad"), Relu6GradOp); +class LeakyReluGradOp : public 
XlaOpKernel { + public: + explicit LeakyReluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_)); + } + void Compile(XlaOpKernelContext* ctx) override { + auto gradients = ctx->Input("gradients"); + auto features = ctx->Input("features"); + auto output = + xla::Select(xla::Gt(features, xla::ScalarLike(features, 0)), gradients, + gradients * xla::ScalarLike(gradients, alpha_)); + ctx->SetOutput(0, output); + } + float alpha_; +}; +REGISTER_XLA_OP(Name("LeakyReluGrad"), LeakyReluGradOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..54d34a38abc4948a1a08197d72e3e7f763649093 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -0,0 +1,576 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +using xla::XlaOp; + +// Calculates the bilinear weight tensor, given basis ratio (px, py) of the +// sampling position: +// W = [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py] +// 'ratio' tensor has dimensions [batch, dim_0, ...dim_n, 2]. +// +// The returned tensor has dimensions [batch, dim_0, ... dim_n, 4]. 
+XlaOp BilinearWeights(XlaOpKernelContext* ctx, XlaOp ratio, + const TensorShape warp_shape, + xla::PrimitiveType xla_type) { + auto first_term = xla::ConstantR2( + ctx->builder(), {{1.0, 1.0}, {0.0, 1.0}, {1.0, 0.0}, {0.0, 0.0}}); + first_term = xla::ConvertElementType(first_term, xla_type); + + auto warp_dims = warp_shape.dim_sizes(); + std::vector broadcast_dims(warp_dims.begin(), warp_dims.end() - 1); + broadcast_dims.push_back(4); + broadcast_dims.push_back(2); + + const int64 broadcast_dims_size = broadcast_dims.size(); + + std::vector last_two_dims_indices = {(broadcast_dims_size - 2), + (broadcast_dims_size - 1)}; + + auto broadcast_first_term = + xla::BroadcastInDim(first_term, broadcast_dims, last_two_dims_indices); + + // Ratio is of the same dimension as warp, which is [batch, dim_0,... dim_n, + // 2], we broadcast ratio tensor to 'broadcast_dim' by keeping the + // [batch, dim_0,...dim_n] dimensions and the [2] dimension as the last + // dimension. + std::vector ratio_broadcast_indices(broadcast_dims.size()); + std::iota(ratio_broadcast_indices.begin(), ratio_broadcast_indices.end(), 0); + ratio_broadcast_indices.erase(ratio_broadcast_indices.end() - 2); + + auto broadcast_ratio = + xla::BroadcastInDim(ratio, broadcast_dims, ratio_broadcast_indices); + + auto first_term_subtract_weights = broadcast_first_term - broadcast_ratio; + + // Now we have [(1-px, 1-py), (-px, 1-py), (1-px, -py), (px, py)], need to + // flip the signs of the second and the third term. 
+ auto sign_change = xla::ConstantR2( + ctx->builder(), {{1.0, 1.0}, {-1.0, 1.0}, {-1.0, 1.0}, {1.0, 1.0}}); + sign_change = xla::ConvertElementType(sign_change, xla_type); + + auto broadcast_sign_change = + xla::BroadcastInDim(sign_change, broadcast_dims, last_two_dims_indices); + + auto flipped = first_term_subtract_weights * broadcast_sign_change; + + // Build up the final bilinear weight tensor by multiply reduction, which + // gives: + // [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py] + // for each 4 neighboring pixels where px and py are the weight of the target + // pixel we are sampling from. + return xla::Reduce( + flipped, xla::One(ctx->builder(), xla_type), + xla::CreateScalarMultiplyComputation(xla_type, ctx->builder()), + {broadcast_dims_size - 1}); +} + +// Concatenates the batch indices to the (x, y) coordinate indices. +// This is done by first creating an Iota tensor that represents the current +// batch it is in, then concatenate with the givin (coordinate) indices. +// +// The resulting tensor has dimension (batch, dim_0, ... dim_n, 3) where +// the last dimension of size 3 in turn is [batch_number, x, y]. +// The [batch_number, x, y] dimension is needed because the indices +// [x,y] alone cannot allow the xla::Gather operation to gather from the input +// data, which is of dimension [batch, height(y), width(x), channel] with +// 'batch' being the first dimension. +XlaOp ConcatenateIota(xla::XlaBuilder* b, XlaOp indices, + const TensorShape& warp_shape) { + // We need to create an iota tensor with the same batch dimension. + std::vector dimensions; + for (auto dim : warp_shape) { + dimensions.push_back(dim.size); + } + // Except the last dimension, which is of size 1. 
+ dimensions.back() = 1; + + auto batch_indices = + xla::Iota(b, xla::ShapeUtil::MakeShape(xla::U32, dimensions), + /*iota_dimension=*/0); + + return xla::ConcatInDim(b, {batch_indices, indices}, dimensions.size() - 1); +} + +// Gathers the 2x2 neighbors of the input starting_indices, and return a +// tensor of dimension [batch, dim_0, ... dim_n, 4, data_channels]. +// 'gather_indices' is of dimension [batch, dim_0, ..., dim_n, 3] where the last +// dimension of size 3 is (batch_no, x, y). +XlaOp Gather2by2Neighbors(xla::XlaBuilder* b, XlaOp data, XlaOp gather_indices, + int64 data_channels, int warp_dims) { + xla::GatherDimensionNumbers gather_dim_numbers; + const int64 neighbor_data_dimensions = warp_dims + 2; + // Since the Gather output dimensions are [batch, dim_0, ... dim_n, 2, 2, + // data_channels], the offset dimensions for Gather is the last 3 dimensions. + gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 3); + gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 2); + gather_dim_numbers.add_offset_dims(neighbor_data_dimensions - 1); + // The last dimension of 'gather_indices' is the starting indices for gather. + gather_dim_numbers.set_index_vector_dim(warp_dims - 1); + gather_dim_numbers.add_collapsed_slice_dims(0); + gather_dim_numbers.add_start_index_map(0); + // Since input is of dimension [batch, height(y), width(x), channel], and warp + // is of dimension [batch, x, y], the ordering of x, y here needs to be + // swapped when gathering. + gather_dim_numbers.add_start_index_map(2); + gather_dim_numbers.add_start_index_map(1); + // Data dimensions are [batch, x, y, channel]. + // Output dimensions are [batch, dim_0, ... dim_n, 2, 2, data_channels]. + auto neighbors_data = xla::Gather(data, gather_indices, gather_dim_numbers, + /*slice_sizes=*/{1, 2, 2, data_channels}); + // Collapse the ...,2,2,... dimensions into ...,4,... 
+ return xla::Collapse(neighbors_data, {warp_dims - 1, warp_dims}); +} + +// Scatter 'updates' tensor to 'grad_data' based on 'indices'. Returns the +// resulting tensor of dimension: [batch, dim_0, ...dim_n, 2, 2, data_channels]. +// This function can also be seen as the inverse of 'Gather2by2Neighbors'. +XlaOp ScatterToGradData(XlaOpKernelContext* ctx, XlaOp grad_data, XlaOp indices, + XlaOp updates, int64 warp_dims, + xla::PrimitiveType xla_type) { + xla::ScatterDimensionNumbers scatter_dim_numbers; + const int64 neighbor_data_dimensions = warp_dims + 2; + // Since the Scatter output dimensions are [batch, dim_0, ... dim_n, 2, 2, + // data_channels], the update window dimensions is the last 3 dimensions. + scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 3); + scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 2); + scatter_dim_numbers.add_update_window_dims(neighbor_data_dimensions - 1); + scatter_dim_numbers.set_index_vector_dim(warp_dims - 1); + + scatter_dim_numbers.add_inserted_window_dims(0); + scatter_dim_numbers.add_scatter_dims_to_operand_dims(0); + // Since input is of dimension [batch, height(y), width(x), channel], and warp + // is of dimension [batch, x, y], the ordering of x, y here needs to be + // swapped when scattering. + scatter_dim_numbers.add_scatter_dims_to_operand_dims(2); + scatter_dim_numbers.add_scatter_dims_to_operand_dims(1); + + return xla::Scatter(grad_data, indices, updates, + xla::CreateScalarAddComputation(xla_type, ctx->builder()), + scatter_dim_numbers); +} + +// Build computation the backprop into input 'data'. 
+// Where input: +// grad_output is of dimension [batch, dim_0, ...dim_n, channel] +// ratio is of dimension [batch, dim_0, ...dim_n, 2] +// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] +// +// Output: +// scatter-add to each 2x2 grad_data neighbor: +// grad_data[fx, fy, chan] += output_grad * dx * dy +// grad_data[cx, fy, chan] += output_grad * (1 - dx) * dy +// grad_data[fx, cy, chan] += output_grad * dx * (1 - dy) +// grad_data[cx, cy, chan] += output_grad * (1 - dx) * (1 - dy) +// where (dx, dy) is (1 - ratio). +XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, + XlaOp gather_indices, xla::PrimitiveType warp_type, + TensorShape warp_shape, int64 data_channels, + xla::Shape data_shape) { + // Weights tensor has dimension [batch, dim_0, ... dim_n, 4]. + auto weights = BilinearWeights(ctx, ratio, warp_shape, warp_type); + + auto warp_dims = warp_shape.dim_sizes(); + std::vector warp_dims_without_last_dims(warp_dims.begin(), + warp_dims.end() - 1); + + std::vector reshaped_weights_dims = warp_dims_without_last_dims; + // Reshape the last dimension of size 4 to two dimensions [2, 2]. + reshaped_weights_dims.push_back(2); + reshaped_weights_dims.push_back(2); + std::vector reshape_dims(warp_shape.dims()); + std::iota(reshape_dims.begin(), reshape_dims.end(), 0); + // The dimension is [batch, dim_0,..., dim_n, 2, 2]. + auto reshaped_weights = xla::Reshape(weights, /*dimensions=*/reshape_dims, + /*new_sizes=*/reshaped_weights_dims); + + std::vector weights_with_channels_dims = reshaped_weights_dims; + weights_with_channels_dims.push_back(data_channels); + std::vector reshaped_weights_indices(reshaped_weights_dims.size()); + std::iota(reshaped_weights_indices.begin(), reshaped_weights_indices.end(), + 0); + + // The dimension is [batch, dim_0, ..., dim_n, 2, 2, data_channel]. 
+ auto broadcast_reshaped_weights = xla::BroadcastInDim( + reshaped_weights, weights_with_channels_dims, reshaped_weights_indices); + + std::vector grad_output_indices(warp_dims_without_last_dims.size()); + std::iota(grad_output_indices.begin(), grad_output_indices.end(), 0); + grad_output_indices.push_back(weights_with_channels_dims.size() - 1); + XlaOp broadcast_grad_output = xla::BroadcastInDim( + grad_output, weights_with_channels_dims, grad_output_indices); + + auto grad_output_multiply_weights = + broadcast_grad_output * broadcast_reshaped_weights; + + auto grad_data = xla::ConstantLiteral( + ctx->builder(), xla::Literal::CreateFromShape(data_shape)); + + return ScatterToGradData(ctx, grad_data, gather_indices, + grad_output_multiply_weights, warp_shape.dims(), + warp_type); +} + +// Build computation for the backprop into input 'warp'. +// Where input: +// warp is of dimension [batch, dim_0, ...dim_n, 2] +// grad_output is of dimension [batch, dim_0, ...dim_n, channel] +// ratio is of dimension [batch, dim_0, ...dim_n, 2] +// gather_indices is of dimension [batch, dim_0, ...dim_n, 3] +// data is of dimension [batch, x, y, channel] +// +// Output (simplified by ignoring the batch dimensions): +// Since the forward path has: +// output = dot(weights * neighbors) +// The backprop into warp will therefore be: +// grad_warp = output_grad * d_output / d_warp +// = output_grad * (d_weights / d_warp * neighbors + d_neighbors / +// d_warp * weight) +// Where: +// d_weights / d_warp_x = [-(1 - py), (1 - py), -py, py] +// d_weights / d_warp_y = [-(1 - px), -px, (1-px), px] +// and +// d_neighbors / d_warp_x = 0 +// +// Therefore: +// grad_warp_x = py * (img_cxcy - img_fxcy) + (1-py) * (img_cxfy-img_fxfy) +// grad_warp_y = px * (img_cxcy - img_cxfy) + (1-px) * (img_fxcy-img_fxfy) +// +// where (px, py) is warp, (fx, fy) is the left top corner and (cx, cy) is the +// bottom right corner in a 2x2 neighborhood. 
+XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, + XlaOp gather_indices, XlaOp data, + TensorShape warp_shape, int64 data_channels, + xla::PrimitiveType data_type) { + auto warp_dims = warp_shape.dim_sizes(); + std::vector warp_dims_without_last_dims(warp_dims.begin(), + warp_dims.end() - 1); + + // With dimension [batch, dim_0, ...dim_n, 4] + std::vector neighbor_broadcast_dims = warp_dims_without_last_dims; + neighbor_broadcast_dims.push_back(4); + + // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] + auto neighbors_data = Gather2by2Neighbors( + ctx->builder(), data, gather_indices, data_channels, warp_shape.dims()); + + const int64 last_warp_dim = warp_shape.dims() - 1; + + // Since we will be creating the dot product of: + // lhs: [batch, dim_0, ...dim_n, 4] + // and + // rhs: [batch, dim_0, ...dim_n, 4, data_channels] + // we choose the last dimension of lhs and the second last dimension of rhs, + // with size 4, as the contracting dimension. 
+ xla::DotDimensionNumbers dot_dims; + for (int i = 0; i < warp_shape.dims() - 1; ++i) { + dot_dims.add_lhs_batch_dimensions(i); + dot_dims.add_rhs_batch_dimensions(i); + } + dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1); + dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1); + + // img_cxcy - img_fxcy + auto bottom_right_minus_bottom_left = xla::DotGeneral( + xla::BroadcastInDim( + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {0, 0, -1, 1}), data_type), + neighbor_broadcast_dims, {last_warp_dim}), + neighbors_data, dot_dims, /*precision_config=*/nullptr); + + // img_cxfy - img_fxfy + auto top_right_minus_top_left = xla::DotGeneral( + xla::BroadcastInDim( + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {-1, 1, 0, 0}), data_type), + neighbor_broadcast_dims, {last_warp_dim}), + neighbors_data, dot_dims, /*precision_config=*/nullptr); + + // img_cxcy - img_cxfy + auto bottom_right_minus_top_right = xla::DotGeneral( + xla::BroadcastInDim( + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {0, -1, 0, 1}), data_type), + neighbor_broadcast_dims, {last_warp_dim}), + neighbors_data, dot_dims, /*precision_config=*/nullptr); + + // img_fxcy - img_fxfy + auto bottom_left_minus_top_left = xla::DotGeneral( + xla::BroadcastInDim( + xla::ConvertElementType( + xla::ConstantR1(ctx->builder(), {-1, 0, 1, 0}), data_type), + neighbor_broadcast_dims, {last_warp_dim}), + neighbors_data, dot_dims, /*precision_config=*/nullptr); + + // Slice out x and y. + auto weight_x = xla::SliceInDim(ratio, /*start_index=*/0, /*limit_index=*/1, + /*stride=*/1, /*dimno=*/last_warp_dim); + auto weight_y = xla::SliceInDim(ratio, /*start_index=*/1, /*limit_index=*/2, + /*stride=*/1, /*dimno=*/last_warp_dim); + + // Build 1 - y and 1 - x. 
+ auto one_minus_y = xla::One(ctx->builder(), data_type) - weight_y; + auto one_minus_x = xla::One(ctx->builder(), data_type) - weight_x; + + auto x_before_reduce = + grad_output * weight_y * bottom_right_minus_bottom_left + + one_minus_y * top_right_minus_top_left; + + std::vector reshaped_sizes = warp_dims_without_last_dims; + reshaped_sizes.push_back(1); + + std::vector reshaped_dims(warp_dims_without_last_dims.size()); + std::iota(reshaped_dims.begin(), reshaped_dims.end(), 0); + + // Reduce-add along the channel dimension. + auto x_result = + xla::Reduce(x_before_reduce, xla::Zero(ctx->builder(), data_type), + xla::CreateScalarAddComputation(data_type, ctx->builder()), + {last_warp_dim}); + // Reshape before concatenating with y values. + XlaOp reshaped_x = xla::Reshape(x_result, reshaped_dims, reshaped_sizes); + + auto y_before_reduce = grad_output * weight_x * bottom_right_minus_top_right + + one_minus_x * bottom_left_minus_top_left; + // Reduce-add along the channel dimension. + auto y_result = + xla::Reduce(y_before_reduce, xla::Zero(ctx->builder(), data_type), + + xla::CreateScalarAddComputation(data_type, ctx->builder()), + {last_warp_dim}); + XlaOp reshaped_y = xla::Reshape(y_result, reshaped_dims, reshaped_sizes); + + return xla::ConcatInDim(ctx->builder(), {reshaped_x, reshaped_y}, + last_warp_dim); +} + +class ResamplerOp : public XlaOpKernel { + public: + explicit ResamplerOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape data_shape = ctx->InputShape("data"); + OP_REQUIRES(ctx, data_shape.dims() == 4, + errors::InvalidArgument("data must be 4-dimensional", + data_shape.DebugString())); + const int64 data_channels = data_shape.dim_size(3); + xla::PrimitiveType data_type = ctx->input_xla_type(0); + + TensorShape warp_shape = ctx->InputShape("warp"); + OP_REQUIRES(ctx, warp_shape.dims() >= 2, + errors::InvalidArgument("warp must be at least 2-dimensional", + 
warp_shape.DebugString())); + for (int size : warp_shape.dim_sizes()) { + OP_REQUIRES(ctx, size > 0, + errors::InvalidArgument("warp sizes must be positive, got [", + size, "]")); + } + const int64 last_warp_dim = warp_shape.dims() - 1; + // Last dimension of warp shape must be of size 2. + OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2, + errors::InvalidArgument( + "the last dimension of warp must be exactly size 2.")); + xla::PrimitiveType warp_type = ctx->input_xla_type(1); + + XlaOp data = ctx->Input("data"); + XlaOp warp = ctx->Input("warp"); + + // Find the coordinates of the top left corner for the 2x2 region to be + // sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the + // last dimension of size 2 in turn is [x, y]. + XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + + auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); + + // The dimension is [batch, dim_0, ... dim_n, 4, data_channels] + auto neighbors_data = Gather2by2Neighbors( + ctx->builder(), data, gather_indices, data_channels, warp_shape.dims()); + + // Dimensions are [batch, dim_0, ... dim_n, 2]. + XlaOp ratio = warp - xla::ConvertElementType(top_left, data_type); + + // Obtain the bilinear blending weights, the dimension is [batch, dim_0, + // ...dim_n, 4]. + auto weights = BilinearWeights(ctx, ratio, warp_shape, data_type); + + // Since we will be creating the dot product of: + // lhs: [batch, dim_0, ...dim_n, 4] + // and + // rhs: [batch, dim_0, ...dim_n, 4, data_channels] + // we choose the last dimension of lhs and the second last dimension of rhs, + // with size 4, as the contracting dimension. 
+ xla::DotDimensionNumbers dot_dims; + for (int i = 0; i < warp_shape.dims() - 1; ++i) { + dot_dims.add_lhs_batch_dimensions(i); + dot_dims.add_rhs_batch_dimensions(i); + } + dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1); + dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1); + + // The dimension is [batch, dim_0, ...dim_n, data_channels]. + auto blended_pixels = xla::DotGeneral(weights, neighbors_data, dot_dims, + /*precision_config=*/nullptr); + + // Handle out of boundary cases by constructing a predicate mask array based + // on the in-bound condition, and output 0 for the blended pixel value if + // out-bound. The dimension is the same as top_left: [batch, dim_0, + // ...dim_n, 2] where the last dimension of size 2 is the [x, y] coordinate. + + auto is_ge_zero = xla::Ge(warp, xla::ZerosLike(warp)); + + auto is_lt_image_size = xla::Lt( + warp, + xla::ConvertElementType( + xla::ConstantR1( + ctx->builder(), + {/*width=*/static_cast(data_shape.dim_size(2) - 1), + /*height=*/static_cast(data_shape.dim_size(1) - 1)}), + warp_type), + /*broadcast_dimensions=*/{warp_shape.dims() - 1}); + + auto is_in_bound_x_y = xla::And(is_ge_zero, is_lt_image_size); + // Reduce along last dimension. The resulting dimension is: + // [batch, dim_0, ...dim_n]. + auto is_in_bound = xla::Reduce( + is_in_bound_x_y, xla::ConstantR0(ctx->builder(), true), + xla::CreateScalarAndComputation(xla::PrimitiveType::PRED, + ctx->builder()), + {last_warp_dim}); + + // Broadcast 'is_in_bound' to the same dimension as 'blended_pixels', which + // is the dimension of the result: + // [batch, dim_0, ...dim_n, data_channels]. 
+ auto warp_dims = warp_shape.dim_sizes(); + std::vector result_dims(warp_dims.begin(), warp_dims.end() - 1); + result_dims.push_back(data_channels); + + std::vector broadcasted_dims(warp_dims.size() - 1); + std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0); + auto broadcasted_is_in_bound = + xla::BroadcastInDim(is_in_bound, result_dims, broadcasted_dims); + + // Set out of bound samples to zero. + auto zeros = + xla::Broadcast(xla::Zero(ctx->builder(), data_type), result_dims); + auto result = xla::Select(broadcasted_is_in_bound, blended_pixels, zeros); + + ctx->SetOutput(0, result); + } +}; + +REGISTER_XLA_OP(Name("Resampler"), ResamplerOp); + +class ResamplerGradOp : public XlaOpKernel { + public: + explicit ResamplerGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + DataType output_dtype; + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype)); + } + + // TODO(b/112295522): note that sampling from image boundary is not currently + // being handled properly. + void Compile(XlaOpKernelContext* ctx) override { + TensorShape data_shape_tf = ctx->InputShape("data"); + OP_REQUIRES(ctx, data_shape_tf.dims() == 4, + errors::InvalidArgument("data must be 4-dimensional", + data_shape_tf.DebugString())); + const int64 data_channels = data_shape_tf.dim_size(3); + xla::PrimitiveType data_type = ctx->input_xla_type(0); + + TensorShape warp_shape = ctx->InputShape("warp"); + OP_REQUIRES(ctx, warp_shape.dims() >= 2, + errors::InvalidArgument("warp must be at least 2-dimensional", + warp_shape.DebugString())); + for (int size : warp_shape.dim_sizes()) { + OP_REQUIRES(ctx, size > 0, + errors::InvalidArgument("warp sizes must be positive, got [", + size, "]")); + } + // Last dimension of warp shape must be of size 2. 
+ OP_REQUIRES(ctx, warp_shape.dim_size(warp_shape.dims() - 1) == 2, + errors::InvalidArgument( + "the last dimension of warp must be exactly size 2.")); + xla::PrimitiveType warp_type = ctx->input_xla_type(1); + + TensorShape output_grad_shape = ctx->InputShape("grad_output"); + OP_REQUIRES( + ctx, output_grad_shape.dims() >= 2, + errors::InvalidArgument("output_grad must be at least 2-dimensional", + output_grad_shape.DebugString())); + + // Dimensions are [batch, x, y, channel]. + XlaOp data = ctx->Input("data"); + xla::Shape data_shape = TensorShapeToXLAShape(data_type, data_shape_tf); + + // Dimensions are [batch, dim_0, ...dim_n, 2]. + XlaOp warp = ctx->Input("warp"); + // Dimensions are [batch, dim_0, ...dim_n, channel]. + XlaOp grad_output = ctx->Input("grad_output"); + + // Find the top left corner coordinate for the region to be sampled from. + // The dimensions are [batch, dim_0, ... dim_n, 2] where the last dimension + // of size 2 in turn is [x, y]. + XlaOp top_left = xla::ConvertElementType(warp, xla::U32); + + // Dimensions are [batch, dim_0, ... dim_n, 2] + XlaOp ratio = warp - xla::ConvertElementType(top_left, warp_type); + + // Indices for gathering neighboring pixels. 
+ auto gather_indices = ConcatenateIota(ctx->builder(), top_left, warp_shape); + + auto grad_data = + CalculateGradData(ctx, grad_output, ratio, gather_indices, warp_type, + warp_shape, data_channels, data_shape); + + auto grad_warp = + CalculateGradWarp(ctx, grad_output, ratio, gather_indices, data, + warp_shape, data_channels, data_type); + + ctx->SetOutput(0, grad_data); + ctx->SetOutput(1, grad_warp); + } +}; + +REGISTER_XLA_OP(Name("ResamplerGrad"), ResamplerGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 366ce42866e9f1375ee0ff6f4985c8f461fc0885..fa1b6b91710f5507f41f3f69b0715398ae879aaf 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { namespace { @@ -36,7 +37,7 @@ class ReshapeOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); const TensorShape sizes_shape = ctx->InputShape(1); // Preliminary validation of sizes. 
- OP_REQUIRES(ctx, IsLegacyVector(sizes_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(sizes_shape), errors::InvalidArgument("sizes input must be 1-D, not shape ", sizes_shape.DebugString())); const int64 num_dims = sizes_shape.num_elements(); @@ -95,7 +96,7 @@ class ReshapeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Reshape").CompileTimeConstInput("shape"), ReshapeOp); +REGISTER_XLA_OP(Name("Reshape").CompileTimeConstantInput("shape"), ReshapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index e172c649325adb6f7761ce0be141f21e8d545bc1..e4046c795577983bff1a8053743bf4d3a258e583 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -46,61 +47,7 @@ class RetvalOp : public XlaOpKernel { // compilation. 
OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input)); } else { - xla::XlaOp input = ctx->Input(0); - const TensorShape input_shape = ctx->InputShape(0); - DataType input_type = ctx->input_type(0); - XlaContext& tc = XlaContext::Get(ctx); - - if (input_type == DT_RESOURCE) { - XlaResource* resource; - OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); - ctx->SetStatus(tc.AddResourceRetval(index_, resource)); - return; - } - - auto is_constant = ctx->builder()->IsConstant(input); - if (!is_constant.ok()) { - ctx->SetStatus(is_constant.status()); - return; - } - - if (tc.resolve_compile_time_constants() && - (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); - OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); - } else { - TensorShape shape = ctx->InputShape(0); - ctx->SetStatus(is_constant.status()); - TensorShape representation_shape; - if (tc.is_entry_computation()) { - xla::StatusOr shape_or_status = - tc.RepresentationShape(shape, ctx->input_type(0)); - if (!shape_or_status.ok()) { - ctx->SetStatus(shape_or_status.status()); - return; - } else { - representation_shape = shape_or_status.ValueOrDie(); - } - } else { - representation_shape = shape; - } - - xla::XlaOp output = input; - if (tc.is_entry_computation()) { - output = xla::Reshape(input, representation_shape.dim_sizes()); - } else { - // The core from which a return value is returned depends on the - // device assignment of the input to the retval. Since we can't change - // the device assignment of "input" at this point, we must always - // introduce an operator here, even if the shape does not change. - // TODO(b/76097077): propagate device assignments onto arguments and - // return values of functions, and then reshape unconditionally. 
- output = - xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0); - } - tc.AddRetval(index_, dtype_, shape, output); - } + ctx->xla_context()->SetRetval(index_, ctx->InputExpression(0)); } } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 8494864b33a44b03a07e3fea7766285f54074e7d..2ceadaf79c5cef35ad50aa84a0d66a46527a6458 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -51,14 +51,11 @@ class ReverseOp : public XlaOpKernel { } // XlaBuilder::Rev() requires concrete values for dimensions arg. xla::Literal lax; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax)); - std::vector revdims(x_shape.dims()); - std::copy(lax.data().begin(), lax.data().end(), - revdims.begin()); - std::vector dimensions; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &lax)); + std::vector dimensions; for (int d = 0; d < x_shape.dims(); ++d) { - if (revdims[d]) { + if (lax.Get({d})) { dimensions.push_back(d); } } @@ -67,7 +64,7 @@ class ReverseOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Reverse").CompileTimeConstInput("dims"), ReverseOp); +REGISTER_XLA_OP(Name("Reverse").CompileTimeConstantInput("dims"), ReverseOp); class ReverseV2Op : public XlaOpKernel { public: @@ -119,7 +116,8 @@ class ReverseV2Op : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("ReverseV2").CompileTimeConstInput("axis"), ReverseV2Op); +REGISTER_XLA_OP(Name("ReverseV2").CompileTimeConstantInput("axis"), + ReverseV2Op); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 03a50ef8a059e5a005c4cc2e5e98acedfea8619a..d7b38e86cc985d608116488f9e76756a8e904f9c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -17,8 +17,9 @@ limitations under 
the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -61,113 +62,79 @@ class ReverseSequenceOp : public XlaOpKernel { const auto seq_lens = context->Input(1); const int64 batch_size = input_shape.dim_size(batch_dim_); + if (batch_size == 0) { + context->SetOutput(0, input); + return; + } - const DataType input_type = context->input_type(0); - const DataType seq_lens_type = context->input_type(1); + // Given the input + // + // 012345 + // 6789AB + // + // and sequence lens {2, 3} we: + // + // 1. Reverse and pad each row to get + // + // 543210XXXXXX + // BA9876XXXXXX + // + // 2. Gather out the suffix from each row to get + // + // 10XXXX + // 876XXX + // + // 3. Select from the input and the array created by (2) to get the result. + // + // 102345 + // 8769AB + const xla::PrimitiveType input_type = context->input_xla_type(0); + const xla::PrimitiveType seq_lens_type = context->input_xla_type(1); const int64 max_seq_len = input_shape.dim_size(seq_dim_); - xla::Shape input_xla_shape; - OP_REQUIRES_OK(context, TensorShapeToXLAShape(input_type, input_shape, - &input_xla_shape)); - xla::Shape seq_lens_xla_shape; - OP_REQUIRES_OK(context, TensorShapeToXLAShape(seq_lens_type, seq_lens_shape, - &seq_lens_xla_shape)); - - const auto tuple_shape = xla::ShapeUtil::MakeTupleShape({ - xla::ShapeUtil::MakeShape(seq_lens_xla_shape.element_type(), {}), - seq_lens_xla_shape, - input_xla_shape, - }); - - // For each entry in the batch, reverse the sequence. 
- // TODO(b/65689298): generalize the Map() operator to non-scalar cases and - // use it here, instead of a While loop. - - // Condition: lambda (i, _, _): i < batch_size - auto condition_builder = - builder->CreateSubBuilder("reverse_sequence_condition"); - { - auto param = - xla::Parameter(condition_builder.get(), 0, tuple_shape, "param"); - auto i = xla::GetTupleElement(param, 0); - xla::Lt(i, XlaHelpers::IntegerLiteral(condition_builder.get(), - seq_lens_type, batch_size)); - } - auto condition = condition_builder->Build(); - OP_REQUIRES_OK(context, condition.status()); - - auto body_builder = builder->CreateSubBuilder("reverse_sequence_body"); - { - auto param = xla::Parameter(body_builder.get(), 0, tuple_shape, "param"); - auto i = xla::GetTupleElement(param, 0); - auto seq_lens = xla::GetTupleElement(param, 1); - auto output = xla::GetTupleElement(param, 2); - - // seq_len is the sequence length of the current batch element (rank 1) - auto seq_len = xla::DynamicSlice(seq_lens, xla::Reshape(i, {1}), {1}); - - // Indices is the offset of the batch element in the input. - auto batch_element_indices = - xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type), - {input_shape.dims()}); - batch_element_indices = xla::DynamicUpdateSlice( - batch_element_indices, xla::Reshape(i, {1}), - xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(), - seq_lens_type, batch_dim_), - {1})); - - // Slice out the current batch element and pad it out in the sequence - // dimension. 
- TensorShape slice_shape = input_shape; - slice_shape.set_dim(batch_dim_, 1); - slice_shape.set_dim(seq_dim_, max_seq_len); - auto slice = xla::DynamicSlice(output, batch_element_indices, - slice_shape.dim_sizes()); - auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims()); - padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( - slice_shape.dim_size(seq_dim_)); - slice = xla::Pad(slice, XlaHelpers::Zero(body_builder.get(), input_type), - padding_config); - - // Now slice out the reversed sequence from its actual start. - // sequence_start_indices is the offset of the start of the reversed - // sequence in the input. The slice will go into the padding, however, we - // will mask off these elements and replace them with elements from the - // original input so their values do not matter. - auto sequence_start_indices = - xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type), - {slice_shape.dims()}); - sequence_start_indices = xla::DynamicUpdateSlice( - sequence_start_indices, - xla::Sub(XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, - max_seq_len), - seq_len), - xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(), - seq_lens_type, seq_dim_), - {1})); - slice = xla::DynamicSlice(slice, sequence_start_indices, - slice_shape.dim_sizes()); - - // Shift the reversed sequence to the left. - output = xla::DynamicUpdateSlice(output, slice, batch_element_indices); - - xla::Tuple( - body_builder.get(), - {xla::Add(i, XlaHelpers::One(body_builder.get(), seq_lens_type)), - seq_lens, output}); + xla::XlaOp rev = xla::Rev(input, {seq_dim_}); + + auto padding_config = xla::MakeNoPaddingConfig(input_shape.dims()); + padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( + max_seq_len); + xla::XlaOp padded = + xla::Pad(rev, xla::Zero(builder, input_type), padding_config); + + // Form a start indices tensor with shape [2, batch_size]. For each batch + // entry we have a (batch offset, seq offset) pair. 
+ xla::XlaOp start_indices = xla::ConcatInDim( + builder, + { + xla::Iota(builder, + xla::ShapeUtil::MakeShape(seq_lens_type, {1, batch_size}), + /*iota_dimension=*/1), + xla::Reshape(xla::ScalarLike(seq_lens, max_seq_len) - seq_lens, + {1, batch_size}), + }, + /*dimension=*/0); + + xla::GatherDimensionNumbers dnums; + // The first dimension of start_indices contains the batch/seq dim choice. + dnums.set_index_vector_dim(0); + dnums.add_start_index_map(batch_dim_); + dnums.add_start_index_map(seq_dim_); + + // All other dimensions other than the batch dim are offset dimensions. + for (int i = 0; i < input_shape.dims(); ++i) { + if (i != batch_dim_) { + dnums.add_offset_dims(i); + } } - auto body = body_builder->Build(); - OP_REQUIRES_OK(context, body.status()); - - auto loop_output = xla::While( - condition.ValueOrDie(), body.ValueOrDie(), - xla::Tuple(builder, {XlaHelpers::Zero(builder, seq_lens_type), seq_lens, - xla::Rev(input, {seq_dim_})})); - auto output = xla::GetTupleElement(loop_output, 2); - - // Mask out elements after the sequence length. - xla::XlaOp iota = - xla::Iota(builder, seq_lens_xla_shape.element_type(), max_seq_len); + dnums.add_collapsed_slice_dims(batch_dim_); + + auto slice_sizes = input_shape.dim_sizes(); + slice_sizes[batch_dim_] = 1; + + xla::XlaOp output = xla::Gather(padded, start_indices, dnums, slice_sizes); + + // Mask out elements after the sequence length, and copy the corresponding + // elements from the input. 
+ xla::XlaOp iota = xla::Iota(builder, seq_lens_type, max_seq_len); std::vector dims(input_shape.dims(), 1); dims[batch_dim_] = batch_size; auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_}); diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index ab094d7dd1ce9856a3c2854fd2776827d6c4b76f..4b9e1a578be2445091228953df7e5c5e82b42c28 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -39,8 +39,8 @@ namespace { // TODO(phawkins): implement double-sized windowed reductions in XLA and remove // the type constraint. -constexpr std::array kScanOpTypes = { - {DT_HALF, DT_BFLOAT16, DT_FLOAT}}; +constexpr std::array kScanOpTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_INT32}}; class ScanOp : public XlaOpKernel { public: @@ -103,10 +103,10 @@ class ScanOp : public XlaOpKernel { reducer = ctx->GetOrCreateMul(dtype); } auto output = xla::ReduceWindowWithGeneralPadding( - XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, - *reducer, window_dims, window_strides, padding); - output = - XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0)); + XlaHelpers::ConvertElementType(ctx->Input(0), dtype), init, *reducer, + window_dims, window_strides, + /*base_dilations=*/{}, /*window_dilations=*/{}, padding); + output = XlaHelpers::ConvertElementType(output, ctx->input_type(0)); // In exclusive mode, we have computed an extra element containing the sum // of all the input elements. Slice off this extra "last" element. 
@@ -135,7 +135,7 @@ class CumsumOp : public ScanOp { }; REGISTER_XLA_OP(Name("Cumsum") .TypeConstraint("T", kScanOpTypes) - .CompileTimeConstInput("axis"), + .CompileTimeConstantInput("axis"), CumsumOp); class CumprodOp : public ScanOp { @@ -144,7 +144,7 @@ class CumprodOp : public ScanOp { }; REGISTER_XLA_OP(Name("Cumprod") .TypeConstraint("T", kScanOpTypes) - .CompileTimeConstInput("axis"), + .CompileTimeConstantInput("axis"), CumprodOp); } // anonymous namespace diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index f1f32699fee5f03f603f830722fe65622dee5d3e..a95e7adacf194ba6eb33cbeb56abe1a5a2479337 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -116,7 +116,8 @@ class ScatterNdOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstInput("shape"), ScatterNdOp); +REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstantInput("shape"), + ScatterNdOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index b22ecb7c6dbb42a33a4f4d90b18b20816df16a50..97359f81eee4aa0b46f03941ab6ca3ea3d468f1f 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -105,7 +105,7 @@ class UnsortedSegmentSum : public UnsortedSegmentReduce { }; REGISTER_XLA_OP( - Name("UnsortedSegmentSum").CompileTimeConstInput("num_segments"), + Name("UnsortedSegmentSum").CompileTimeConstantInput("num_segments"), UnsortedSegmentSum); class UnsortedSegmentProd : public UnsortedSegmentReduce { @@ -120,7 +120,7 @@ class UnsortedSegmentProd : public UnsortedSegmentReduce { }; REGISTER_XLA_OP( - Name("UnsortedSegmentProd").CompileTimeConstInput("num_segments"), + Name("UnsortedSegmentProd").CompileTimeConstantInput("num_segments"), 
UnsortedSegmentProd); class UnsortedSegmentMin : public UnsortedSegmentReduce { @@ -137,7 +137,7 @@ class UnsortedSegmentMin : public UnsortedSegmentReduce { }; REGISTER_XLA_OP( - Name("UnsortedSegmentMin").CompileTimeConstInput("num_segments"), + Name("UnsortedSegmentMin").CompileTimeConstantInput("num_segments"), UnsortedSegmentMin); class UnsortedSegmentMax : public UnsortedSegmentReduce { @@ -154,7 +154,7 @@ class UnsortedSegmentMax : public UnsortedSegmentReduce { }; REGISTER_XLA_OP( - Name("UnsortedSegmentMax").CompileTimeConstInput("num_segments"), + Name("UnsortedSegmentMax").CompileTimeConstantInput("num_segments"), UnsortedSegmentMax); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index a7f5a8f1698b9d02560de427d356e9e6be5caa7c..84470b230d421658e0d79dcecb175a24155f49b7 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -42,7 +42,7 @@ SendOp::SendOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { } void SendOp::Compile(XlaOpKernelContext* ctx) { - XlaCompiler* compiler = XlaContext::Get(ctx).compiler(); + XlaCompiler* compiler = ctx->compiler(); xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); xla::Send(ctx->Input(0), channel); @@ -73,7 +73,7 @@ RecvOp::RecvOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { } void RecvOp::Compile(XlaOpKernelContext* ctx) { - XlaCompiler* compiler = XlaContext::Get(ctx).compiler(); + XlaCompiler* compiler = ctx->compiler(); xla::ChannelHandle channel; OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel)); ctx->SetOutput(0, xla::Recv(ctx->builder(), shape_, channel)); diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 0c32b8def0f7b741c93e803f8359b6504087e257..b1fa2915d59e4e5e2f2523e20e9a37898d087117 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -30,31 +30,6 @@ limitations under the License. namespace tensorflow { namespace { -template -Status GetValue(int index, XlaOpKernelContext* ctx, T* value) { - xla::Literal literal; - TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); - *value = literal.Get({}); - return Status::OK(); -} - -Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { - xla::Literal literal; - TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); - switch (literal.shape().element_type()) { - case xla::S32: - *value = literal.Get({}); - break; - case xla::S64: - *value = literal.Get({}); - break; - default: - return errors::InvalidArgument("Invalid argument type for argument", - index); - } - return Status::OK(); -} - // The type-specific part of the implementation of Range. 
template xla::StatusOr CreateRangeTensor( @@ -98,13 +73,13 @@ class RangeOp : public XlaOpKernel { const TensorShape start_in_shape = ctx->InputShape(0); const TensorShape limit_in_shape = ctx->InputShape(1); const TensorShape delta_in_shape = ctx->InputShape(2); - OP_REQUIRES(ctx, IsLegacyScalar(start_in_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape), errors::InvalidArgument("start must be a scalar, not shape ", start_in_shape.DebugString())); - OP_REQUIRES(ctx, IsLegacyScalar(limit_in_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(limit_in_shape), errors::InvalidArgument("limit must be a scalar, not shape ", limit_in_shape.DebugString())); - OP_REQUIRES(ctx, IsLegacyScalar(delta_in_shape), + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(delta_in_shape), errors::InvalidArgument("delta must be a scalar, not shape ", delta_in_shape.DebugString())); xla::Literal start, limit, delta; @@ -137,9 +112,9 @@ class RangeOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("Range") - .CompileTimeConstInput("start") - .CompileTimeConstInput("limit") - .CompileTimeConstInput("delta"), + .CompileTimeConstantInput("start") + .CompileTimeConstantInput("limit") + .CompileTimeConstantInput("delta"), RangeOp); class LinSpaceOp : public XlaOpKernel { @@ -147,9 +122,9 @@ class LinSpaceOp : public XlaOpKernel { explicit LinSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape start_in_shape = ctx->InputShape(0); - const TensorShape stop_in_shape = ctx->InputShape(1); - const TensorShape num_in_shape = ctx->InputShape(2); + const TensorShape start_in_shape = ctx->InputShape("start"); + const TensorShape stop_in_shape = ctx->InputShape("stop"); + const TensorShape num_in_shape = ctx->InputShape("num"); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape), errors::InvalidArgument("start must be a scalar, not shape ", start_in_shape.DebugString())); @@ -163,16 +138,20 @@ class LinSpaceOp 
: public XlaOpKernel { DataType type = ctx->input_type(0); int64 num; - OP_REQUIRES_OK(ctx, GetIntValue(2, ctx, &num)); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("num", &num)); OP_REQUIRES(ctx, num > 0, errors::InvalidArgument("Requires num > 0: ", num)); Tensor out_constant(type, TensorShape({num})); + xla::Literal start_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput("start", &start_literal)); + xla::Literal stop_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput("stop", &stop_literal)); + switch (type) { case DT_FLOAT: { - float start, stop; - OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); - OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop)); + float start = start_literal.GetFirstElement(); + float stop = stop_literal.GetFirstElement(); auto flat = out_constant.flat(); if (num == 1) { flat(0) = start; @@ -185,9 +164,8 @@ class LinSpaceOp : public XlaOpKernel { break; } case DT_DOUBLE: { - double start, stop; - OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); - OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop)); + double start = start_literal.GetFirstElement(); + double stop = stop_literal.GetFirstElement(); auto flat = out_constant.flat(); if (num == 1) { flat(0) = start; @@ -210,9 +188,9 @@ class LinSpaceOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("LinSpace") - .CompileTimeConstInput("start") - .CompileTimeConstInput("stop") - .CompileTimeConstInput("num"), + .CompileTimeConstantInput("start") + .CompileTimeConstantInput("stop") + .CompileTimeConstantInput("num"), LinSpaceOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index c8a0f31a0375abacaca26688a23f4835e11c692e..12830816ec16c9797f0fe4d8f3f13f5a8176161d 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { @@ -108,21 +109,16 @@ class ExpandDimsOp : public XlaOpKernel { explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape dim_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape dim_shape = ctx->InputShape("dim"); - // TODO(phawkins): the standard implementation of ExpandDimsOp seems to - // accept legacy scalars, even when they should be forbidden by the graphdef - // version. - OP_REQUIRES(ctx, dim_shape.num_elements() == 1, + std::vector dims; + OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dim", &dims)); + OP_REQUIRES(ctx, dims.size() == 1, errors::InvalidArgument(absl::StrCat( "dim input to ExpandDims must be a scalar; got ", dim_shape.DebugString()))); - - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal)); - - int dim = literal.data()[0]; + int dim = dims[0]; OP_REQUIRES(ctx, (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()), @@ -148,10 +144,11 @@ class ExpandDimsOp : public XlaOpKernel { dim = std::min(dim, existing_dims_size); new_shape.emplace(new_shape.begin() + dim, 1); - ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); + ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape)); } }; -REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp); +REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"), + ExpandDimsOp); class SqueezeOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc 
b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 537b71f3c0cf3622a8a45a717ac406da69f5c3c7..88da64e5a217a0c026106f03cb26958f6738446c 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/mem.h" @@ -42,8 +43,8 @@ class SliceOp : public XlaOpKernel { OP_REQUIRES( ctx, - IsLegacyVector(begin_tensor_shape) && - IsLegacyVector(size_tensor_shape) && + TensorShapeUtils::IsVector(begin_tensor_shape) && + TensorShapeUtils::IsVector(size_tensor_shape) && begin_tensor_shape.num_elements() == input_shape.dims() && size_tensor_shape.num_elements() == input_shape.dims(), errors::InvalidArgument( @@ -111,9 +112,10 @@ class SliceOp : public XlaOpKernel { } }; -REGISTER_XLA_OP( - Name("Slice").CompileTimeConstInput("begin").CompileTimeConstInput("size"), - SliceOp); +REGISTER_XLA_OP(Name("Slice") + .CompileTimeConstantInput("begin") + .CompileTimeConstantInput("size"), + SliceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index d6bd927135c013ac1ec3f6547aef358dc2741896..20da8033536e3af3da0fcb216db45f808cacc1d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -71,7 +71,7 @@ class SoftmaxOp : public XlaOpKernel { auto reduce = xla::Reduce(converted, xla::Zero(b, xla_accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); - auto sum = XlaHelpers::ConvertElementType(b, reduce, type); + auto sum = XlaHelpers::ConvertElementType(reduce, type); auto softmax = 
log_ // softmax = shifted_logits - log(sum(exp(shifted_logits))) @@ -111,11 +111,11 @@ std::pair CrossEntropyWithLogits( // sum_{class} (exp(logits - max_logits)) const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); auto converted = - XlaHelpers::ConvertElementType(b, exp_shifted_logits, accumulation_type); + XlaHelpers::ConvertElementType(exp_shifted_logits, accumulation_type); auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); - auto sum_exp = XlaHelpers::ConvertElementType(b, reduce, type); + auto sum_exp = XlaHelpers::ConvertElementType(reduce, type); // log(sum(exp(logits - max_logits))) auto log_sum_exp = xla::Log(sum_exp); @@ -126,11 +126,10 @@ std::pair CrossEntropyWithLogits( // (The subtraction broadcasts along the batch dimension.) auto sub = xla::Sub(shifted_logits, log_sum_exp, {kBatchDim}); auto mul = xla::Mul(xla::Neg(labels), sub); - auto sum = - xla::Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), - XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); - auto loss = XlaHelpers::ConvertElementType(b, sum, type); + auto sum = xla::Reduce(XlaHelpers::ConvertElementType(mul, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto loss = XlaHelpers::ConvertElementType(sum, type); // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc index aaeeae01ccb303091a6d37d1aeb4b2a3377dc638..6cfdf4a5ae479e9851454df97160754f122bc6ff 100644 --- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -25,11 +25,26 @@ class XlaSortOp : public XlaOpKernel { explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {} void 
Compile(XlaOpKernelContext* context) override { - context->SetOutput(0, xla::Sort(context->Input(0))); + context->SetOutput(0, xla::Sort(context->Input("input"))); } }; REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp); +class XlaKeyValueSortOp : public XlaOpKernel { + public: + explicit XlaKeyValueSortOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaOp result = + xla::Sort(context->Input("keys"), {context->Input("values")}); + context->SetOutput(0, xla::GetTupleElement(result, 0)); + context->SetOutput(1, xla::GetTupleElement(result, 1)); + } +}; + +REGISTER_XLA_OP(Name("XlaKeyValueSort"), XlaKeyValueSortOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 76b79be6f6f6b5ecbe9edcffb81f2834fdac9a56..622efac81766fc3ddaf538b58170f34fce06927a 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -161,8 +161,8 @@ class SpaceToBatchNDOp : public XlaOpKernel { } }; REGISTER_XLA_OP(Name("SpaceToBatchND") - .CompileTimeConstInput("paddings") - .CompileTimeConstInput("block_shape"), + .CompileTimeConstantInput("paddings") + .CompileTimeConstantInput("block_shape"), SpaceToBatchNDOp); class SpaceToBatchOp : public XlaOpKernel { @@ -185,7 +185,7 @@ class SpaceToBatchOp : public XlaOpKernel { private: int block_size_; }; -REGISTER_XLA_OP(Name("SpaceToBatch").CompileTimeConstInput("paddings"), +REGISTER_XLA_OP(Name("SpaceToBatch").CompileTimeConstantInput("paddings"), SpaceToBatchOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc index e831dc30a9d3c27ec3b1494e7d8a6de836ff2a11..def3c147bf3fc619784044357e95bf32b404954b 100644 --- a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc @@ -80,7 +80,7 @@ class SparseToDenseOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("SparseToDense").CompileTimeConstInput("output_shape"), +REGISTER_XLA_OP(Name("SparseToDense").CompileTimeConstantInput("output_shape"), SparseToDenseOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 93fc14e9efca868e84444dd0e07d7f0dfa84c042..7a0e240400b344ab25743997ce3baad81bd5f476 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -35,26 +35,16 @@ class SplitOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const int32 num_split = num_outputs(); - const TensorShape index_shape = ctx->InputShape(0); + const TensorShape split_dim_shape = ctx->InputShape("split_dim"); const TensorShape input_shape = ctx->InputShape(1); - xla::Literal literal_index; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index)); - - int32 split_dim_orig; - if (index_shape.dims() == 0) { - split_dim_orig = literal_index.Get({}); - } else { - OP_REQUIRES( - ctx, index_shape.dims() == 1, - errors::InvalidArgument("split_index input to Split Op must be a " - "scalar or a vector with 1 element")); - OP_REQUIRES( - ctx, index_shape.dim_size(0) == 1, - errors::InvalidArgument("split_index input to Split Op must be a " - "scalar or a vector with 1 element")); - split_dim_orig = literal_index.Get({0}); - } + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(split_dim_shape), + errors::InvalidArgument("split_dim must be a scalar but has rank ", + split_dim_shape.dims())); + int64 split_dim_orig; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &split_dim_orig)); + int32 split_dim = split_dim_orig < 0 ? 
split_dim_orig + input_shape.dims() : split_dim_orig; OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(), @@ -104,7 +94,7 @@ class SplitOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Split").CompileTimeConstInput("split_dim"), SplitOp); +REGISTER_XLA_OP(Name("Split").CompileTimeConstantInput("split_dim"), SplitOp); class SplitVOp : public XlaOpKernel { public: @@ -138,7 +128,6 @@ class SplitVOp : public XlaOpKernel { // Check that sizes are correct. int total_split_size = 0; int neg_one_dim = -1; - std::vector split_sizes_vec(num_split, -1); const TensorShape split_size_shape = ctx->InputShape(1); OP_REQUIRES(ctx, split_size_shape.dims() == 1 && @@ -150,12 +139,11 @@ class SplitVOp : public XlaOpKernel { split_size_shape.dims(), "-D and ", split_size_shape.num_elements(), " elements")); // Get the dimension of this split. - xla::Literal split_size_literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal)); + std::vector split_sizes; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &split_sizes)); for (int i = 0; i < num_split; ++i) { - int slice_size; - slice_size = split_size_literal.Get({i}); + int64 slice_size = split_sizes[i]; if (slice_size == -1) { OP_REQUIRES( ctx, neg_one_dim == -1, @@ -164,7 +152,6 @@ class SplitVOp : public XlaOpKernel { i)); neg_one_dim = i; } else { - split_sizes_vec[i] = slice_size; total_split_size += slice_size; } } @@ -183,7 +170,7 @@ class SplitVOp : public XlaOpKernel { total_split_size)); if (neg_one_dim >= 0) { - split_sizes_vec[neg_one_dim] = + split_sizes[neg_one_dim] = input_shape.dim_size(split_dim) - total_split_size; } @@ -195,7 +182,7 @@ class SplitVOp : public XlaOpKernel { std::vector strides(input_shape.dims(), 1); for (int i = 0; i < num_split; ++i) { TensorShape output_shape(input_shape); - int slice_size = split_sizes_vec[i]; + int slice_size = split_sizes[i]; output_shape.set_dim(split_dim, slice_size); // Slice out the ith split from the split dimension. 
@@ -207,8 +194,8 @@ class SplitVOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("SplitV") - .CompileTimeConstInput("split_dim") - .CompileTimeConstInput("size_splits"), + .CompileTimeConstantInput("split_dim") + .CompileTimeConstantInput("size_splits"), SplitVOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index ee70f508a9586d5f47bd7bb7670506d4c92b369f..8e9e4daf99d3dd3b8e149e3f3e5f6c27665c0fcb 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -69,7 +69,7 @@ Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource, } TensorShape stack_shape; - stack_shape.AddDim(resource->tensor_array_size()); + stack_shape.AddDim(resource->max_array_size()); stack_shape.AppendShape(elem_shape); if (!resource->initialized()) { @@ -97,10 +97,10 @@ class StackOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - int64 size; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size)); + int64 max_size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &max_size)); OP_REQUIRES( - ctx, size >= 0, + ctx, max_size >= 0, errors::InvalidArgument( "XLA compilation requires a fixed stack size upper bound. If " "you are using tf.while_loop, set the maximum_iterations parameter " @@ -108,14 +108,9 @@ class StackOp : public XlaOpKernel { // We defer initializing the Stack resource until we see the first push. // Otherwise we do not know the shape of the stack elements. 
- xla::XlaOp value; - XlaContext& xc = XlaContext::Get(ctx); - XlaResource* resource; - string name = absl::StrCat("Stack: ", stack_name_); - OP_REQUIRES_OK( - ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_, - TensorShape(), value, /*tensor_array_size=*/size, - /*tensor_array_gradients=*/{}, &resource)); + XlaResource* resource = + ctx->xla_context()->AddResource(XlaResource::CreateStack( + /*name=*/absl::StrCat("Stack: ", stack_name_), dtype_, max_size)); ctx->SetResourceOutput(0, resource); } @@ -126,7 +121,9 @@ class StackOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackOp); }; -REGISTER_XLA_OP(Name("StackV2").CompileTimeConstInput("max_size"), StackOp); +REGISTER_XLA_OP( + Name("StackV2").CompileTimeConstantInput("max_size").CompilationOnly(), + StackOp); class StackPushOp : public XlaOpKernel { public: @@ -173,7 +170,7 @@ class StackPushOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp); }; -REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp); +REGISTER_XLA_OP(Name("StackPushV2").CompilationOnly(), StackPushOp); class StackPopOp : public XlaOpKernel { public: @@ -227,7 +224,7 @@ class StackPopOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp); }; -REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp); +REGISTER_XLA_OP(Name("StackPopV2").CompilationOnly(), StackPopOp); class StackCloseOp : public XlaOpKernel { public: @@ -241,7 +238,7 @@ class StackCloseOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp); }; -REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp); +REGISTER_XLA_OP(Name("StackCloseV2").CompilationOnly(), StackCloseOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 5412e135478361d08965e4621ec52cfb4a792f1d..50653d7b3973b73d580cdeec5d71943b575d7cc9 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -17,27 +17,43 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/random.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/lib/prng.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/math/math_util.h" namespace tensorflow { namespace { +xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { + // Mask the last 16 bit. With normal rounding, values near "maxval" would be + // converted to "maxval" which is out of range ["minval", "maxval"). In + // addition, the distribution near the limit is not uniform. 
+ if (dtype == DT_BFLOAT16) { + xla::XlaBuilder* builder = input.builder(); + auto output = xla::BitcastConvertType(input, xla::U32) & + xla::ConstantR0(builder, 0xFFFF0000); + return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32), + xla::BF16); + } else { + return input; + } +} + class StatelessRandomUniformOp : public XlaOpKernel { public: explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) - : XlaOpKernel(ctx) {} + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { xla::XlaBuilder* builder = ctx->builder(); @@ -60,24 +76,81 @@ class StatelessRandomUniformOp : public XlaOpKernel { auto uniform = xla::StatelessRngUniform( {seed0, seed1}, xla_shape, xla::ConstantR0(builder, 0.0), xla::ConstantR0(builder, 1.0)); + uniform = MaybeConvertF32ToBF16(uniform, dtype_); ctx->SetOutput(0, uniform); } private: + DataType dtype_; + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp); }; // TODO(phawkins): generalize to non-float, non-int32 seed types. 
REGISTER_XLA_OP(Name("StatelessRandomUniform") - .CompileTimeConstInput("shape") - .TypeConstraint("dtype", DT_FLOAT) + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}) .TypeConstraint("Tseed", DT_INT32), StatelessRandomUniformOp); +class StatelessRandomUniformIntOp : public XlaOpKernel { + public: + explicit StatelessRandomUniformIntOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + TensorShape seed_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2, + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + TensorShape minval_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape), + errors::InvalidArgument("minval must be scalar, got shape ", + minval_shape.DebugString())); + TensorShape maxval_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape), + errors::InvalidArgument("minval must be scalar, got shape ", + maxval_shape.DebugString())); + + xla::XlaOp seed = ctx->Input(1); + xla::XlaOp minval = ctx->Input(2); + xla::XlaOp maxval = ctx->Input(3); + + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); + + auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + + auto uniform = + xla::StatelessRngUniform({seed0, seed1}, xla_shape, minval, maxval); + ctx->SetOutput(0, uniform); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformIntOp); +}; + +// TODO(phawkins): generalize to non-int32 seed types. 
+REGISTER_XLA_OP(Name("StatelessRandomUniformInt") + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", {DT_INT32, DT_INT64}) + .TypeConstraint("Tseed", DT_INT32), + StatelessRandomUniformIntOp); + class StatelessRandomNormalOp : public XlaOpKernel { public: explicit StatelessRandomNormalOp(OpKernelConstruction* ctx) - : XlaOpKernel(ctx) {} + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { TensorShape shape; @@ -103,24 +176,29 @@ class StatelessRandomNormalOp : public XlaOpKernel { // sqrt(2) * erfinv(x) auto normal = xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform); + normal = MaybeConvertF32ToBF16(normal, dtype_); ctx->SetOutput(0, normal); } private: + DataType dtype_; + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp); }; // TODO(phawkins): generalize to non-float, non-int32 seed types. REGISTER_XLA_OP(Name("StatelessRandomNormal") - .CompileTimeConstInput("shape") - .TypeConstraint("dtype", DT_FLOAT) + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}) .TypeConstraint("Tseed", DT_INT32), StatelessRandomNormalOp); class StatelessTruncatedNormalOp : public XlaOpKernel { public: explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx) - : XlaOpKernel(ctx) {} + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { TensorShape shape; @@ -142,17 +220,20 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { {seed0, seed1}, xla_shape, xla::ConstantR0(builder, std::numeric_limits::min()), xla::ConstantR0(builder, 1.0)); - - ctx->SetOutput(0, TruncatedNormal(uniform)); + auto output = TruncatedNormal(uniform); + output = MaybeConvertF32ToBF16(output, dtype_); + ctx->SetOutput(0, output); } private: + DataType dtype_; + TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp); }; REGISTER_XLA_OP(Name("StatelessTruncatedNormal") - 
.CompileTimeConstInput("shape") - .TypeConstraint("dtype", DT_FLOAT) + .CompileTimeConstantInput("shape") + .TypeConstraint("dtype", {DT_FLOAT, DT_BFLOAT16}) .TypeConstraint("Tseed", DT_INT32), StatelessTruncatedNormalOp); diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 2b2e3de64fd0db9d99efa46ecaf7a0fefbae6645..10d990b3213ab882cf44a4df20a977633de3fdab 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -109,9 +109,9 @@ class StridedSliceOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("StridedSlice") - .CompileTimeConstInput("begin") - .CompileTimeConstInput("end") - .CompileTimeConstInput("strides"), + .CompileTimeConstantInput("begin") + .CompileTimeConstantInput("end") + .CompileTimeConstantInput("strides"), StridedSliceOp); class StridedSliceGradOp : public XlaOpKernel { @@ -218,10 +218,10 @@ class StridedSliceGradOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("StridedSliceGrad") - .CompileTimeConstInput("shape") - .CompileTimeConstInput("begin") - .CompileTimeConstInput("end") - .CompileTimeConstInput("strides"), + .CompileTimeConstantInput("shape") + .CompileTimeConstantInput("begin") + .CompileTimeConstantInput("end") + .CompileTimeConstantInput("strides"), StridedSliceGradOp); class StridedSliceAssignOp : public XlaOpKernel { @@ -331,9 +331,9 @@ class StridedSliceAssignOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("ResourceStridedSliceAssign") - .CompileTimeConstInput("begin") - .CompileTimeConstInput("end") - .CompileTimeConstInput("strides"), + .CompileTimeConstantInput("begin") + .CompileTimeConstantInput("end") + .CompileTimeConstantInput("strides"), StridedSliceAssignOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 
94108b764fd32fc77520f9a8ea16065c27e6accf..939d7e19515a1cb41e3e23e9d1fa957ae09ecab7 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -61,8 +61,8 @@ Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, " but op has dtype ", DataTypeString(dtype), "."); } - TF_RET_CHECK(resource->tensor_array_size() >= 0) - << resource->name() << " size " << resource->tensor_array_size(); + TF_RET_CHECK(resource->max_array_size() >= 0) + << resource->name() << " size " << resource->max_array_size(); if (!resource->initialized()) { TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); @@ -78,7 +78,7 @@ Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape)); TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size()); + ta_shape.AddDim(resource->max_array_size()); ta_shape.AppendShape(elem_shape); if (ta_shape != shape) { return errors::InvalidArgument( @@ -114,7 +114,7 @@ Status CheckTensorArrayIsInitialized(const string& op_name, Status GetTensorArrayShape(const XlaResource* resource, xla::XlaBuilder* builder, TensorShape* shape) { *shape = resource->shape(); - shape->InsertDim(0, resource->tensor_array_size()); + shape->InsertDim(0, resource->max_array_size()); return Status::OK(); } @@ -123,9 +123,10 @@ Status GetTensorArrayShape(const XlaResource* resource, xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, absl::Span update_dims, - const xla::XlaOp& start_indices) { + const xla::XlaOp& start_indices, DataType dtype) { xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); - xla::XlaOp sum = xla::Add(current, update); + xla::XlaOp sum = + dtype == DT_BOOL ? 
xla::Or(current, update) : xla::Add(current, update); return xla::DynamicUpdateSlice(operand, sum, start_indices); } @@ -165,13 +166,10 @@ class TensorArrayOp : public XlaOpKernel { value = xla::Broadcast(zero, ta_shape.dim_sizes()); } - XlaContext& xc = XlaContext::Get(ctx); - XlaResource* var; - string name = absl::StrCat("TensorArray: ", tensor_array_name_); - OP_REQUIRES_OK( - ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), - dtype_, shape, value, /*tensor_array_size=*/size, - /*tensor_array_gradients=*/{}, &var)); + XlaResource* var = + ctx->xla_context()->AddResource(XlaResource::CreateTensorArray( + /*name=*/absl::StrCat("TensorArray: ", tensor_array_name_), dtype_, + shape, /*initial_value=*/value, /*max_array_size=*/size)); ctx->SetResourceOutput(0, var); Tensor flow(DT_FLOAT, TensorShape({})); @@ -187,7 +185,7 @@ class TensorArrayOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp); }; -REGISTER_XLA_OP(Name("TensorArrayV3").CompileTimeConstInput("size"), +REGISTER_XLA_OP(Name("TensorArrayV3").CompileTimeConstantInput("size"), TensorArrayOp); class TensorArrayWriteOp : public XlaOpKernel { @@ -222,9 +220,16 @@ class TensorArrayWriteOp : public XlaOpKernel { slice_shape.InsertDim(0, 1LL); auto update = xla::Reshape(value, slice_shape.dim_sizes()); - xla::XlaOp written = - DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); - + xla::XlaOp written; + if (resource->tensor_array_multiple_writes_aggregate()) { + written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), + start_indices, dtype_); + } else { + // TODO(b/117569591): Ideally we would report an error in the case that we + // see multiple writes to the same offset. Unfortunately there is no way + // to report errors at the moment, so we silently overwrite. 
+ written = xla::DynamicUpdateSlice(ta, update, start_indices); + } OP_REQUIRES_OK(ctx, resource->SetValue(written)); ctx->SetOutput(0, flow); } @@ -391,7 +396,11 @@ class TensorArrayScatterOp : public XlaOpKernel { } if (scatter_all_elements_in_order) { - ta = xla::Add(ta, value); + if (dtype_ == DT_BOOL) { + ta = xla::Or(ta, value); + } else { + ta = xla::Add(ta, value); + } } else { auto slice_dims = value_shape.dim_sizes(); slice_dims[0] = 1LL; @@ -414,7 +423,7 @@ class TensorArrayScatterOp : public XlaOpKernel { auto start_indices = xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0(b, 0), xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); - ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices, dtype_); } } @@ -505,14 +514,13 @@ class TensorArraySplitOp : public XlaOpKernel { xla::XlaOp ta = resource->value(); TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size()); + ta_shape.AddDim(resource->max_array_size()); ta_shape.AppendShape(elem_shape); - OP_REQUIRES( - ctx, lengths.size() == resource->tensor_array_size(), - errors::InvalidArgument( - "TensorArray's size is not equal to the size of lengths (", - lengths.size(), " vs. ", resource->tensor_array_size(), ")")); + OP_REQUIRES(ctx, lengths.size() == resource->max_array_size(), + errors::InvalidArgument( + "TensorArray's size is not equal to the size of lengths (", + lengths.size(), " vs. ", resource->max_array_size(), ")")); const xla::XlaOp value = ctx->Input(1); const xla::XlaOp flow = ctx->Input(3); @@ -522,8 +530,13 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. 
", ta_shape.DebugString())); - OP_REQUIRES_OK(ctx, resource->SetValue(xla::Add( - ta, xla::Reshape(value, ta_shape.dim_sizes())))); + const xla::XlaOp reshape = xla::Reshape(value, ta_shape.dim_sizes()); + if (dtype_ == DT_BOOL) { + ta = xla::Or(ta, reshape); + } else { + ta = xla::Add(ta, reshape); + } + OP_REQUIRES_OK(ctx, resource->SetValue(ta)); ctx->SetOutput(0, flow); } @@ -534,7 +547,7 @@ class TensorArraySplitOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp); }; -REGISTER_XLA_OP(Name("TensorArraySplitV3").CompileTimeConstInput("lengths"), +REGISTER_XLA_OP(Name("TensorArraySplitV3").CompileTimeConstantInput("lengths"), TensorArraySplitOp); class TensorArraySizeOp : public XlaOpKernel { @@ -545,8 +558,7 @@ class TensorArraySizeOp : public XlaOpKernel { XlaResource* var; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var)); Tensor size_tensor(DT_INT32, {}); - size_tensor.scalar()() = - static_cast(var->tensor_array_size()); + size_tensor.scalar()() = static_cast(var->max_array_size()); ctx->SetConstantOutput(0, size_tensor); } diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 74d4fcc425bdadb70a7bedf2487deaf6c4a4f7b9..64a24703ae1460abfedb6d9298e1e164076a199a 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -79,8 +79,8 @@ class TensorListReserveOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("TensorListReserve") - .CompileTimeConstInput("element_shape") - .CompileTimeConstInput("num_elements"), + .CompileTimeConstantInput("element_shape") + .CompileTimeConstantInput("num_elements"), TensorListReserveOp); class EmptyTensorListOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 93d5996b5eaf10221b1d7067e7650b78cd6b8fef..e1c764f3d5c28cf0d812519e4a16786e1f2d3a3a 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -16,7 +16,9 @@ limitations under the License. // XLA-specific Tile Op. #include +#include "absl/algorithm/container.h" #include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -25,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/macros.h" @@ -38,11 +41,11 @@ class TileOp : public XlaOpKernel { explicit TileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape multiples_shape = ctx->InputShape(1); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape multiples_shape = ctx->InputShape("multiples"); OP_REQUIRES( - ctx, IsLegacyVector(multiples_shape), + ctx, TensorShapeUtils::IsVector(multiples_shape), errors::InvalidArgument("Expected multiples to be 1-D, but got shape ", multiples_shape.DebugString())); OP_REQUIRES(ctx, input_shape.dims() == multiples_shape.num_elements(), @@ -51,78 +54,64 @@ class TileOp : public XlaOpKernel { input_shape.dims(), " but got length ", multiples_shape.dim_size(0))); const int input_dims = input_shape.dims(); - + auto input = ctx->Input(0); // If input is a scalar then multiples has 0 elements and this is // a NoOp. 
if (input_dims == 0) { - ctx->SetOutput(0, ctx->Input(0)); + ctx->SetOutput(0, input); return; } - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); - - // zero_element_result is true if the final shape has 0 elements, - // i.e. if any of the input dimensions or multiples is zero. - std::vector multiples_array(input_dims); - std::vector output_shape; - bool all_multiples_are_one = true; - bool one_dimension_is_broadcasted_without_multiple = true; - for (int i = 0; i < input_dims; ++i) { - int multiple = literal.Get({i}); - OP_REQUIRES(ctx, multiple >= 0, + std::vector multiples; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("multiples", &multiples)); + std::vector output_dims(input_shape.dims()); + for (int64 i = 0; i < input_shape.dims(); ++i) { + OP_REQUIRES(ctx, multiples[i] >= 0, errors::InvalidArgument("Expected multiples[", i, - "] >= 0, but got ", multiple)); - int64 new_dim = input_shape.dim_size(i) * multiple; - output_shape.push_back(new_dim); - multiples_array[i] = multiple; - all_multiples_are_one = all_multiples_are_one && multiple == 1; - // If the multiple of a non-one dimensions is not one, then binary - // operation broadcast semantics will not be sufficient to implement the - // tile operation. - one_dimension_is_broadcasted_without_multiple = - one_dimension_is_broadcasted_without_multiple && - ((input_shape.dim_size(i) > 1 && multiple == 1) || - input_shape.dim_size(i) == 1); + "] >= 0, but got ", multiples[i])); + output_dims[i] = input_shape.dim_size(i) * multiples[i]; } - auto input = ctx->Input(0); + // If all multiples are 1, than the input is the same as the output.
- if (all_multiples_are_one) { + if (absl::c_all_of(multiples, + [](int64 multiple) { return multiple == 1; })) { ctx->SetOutput(0, input); return; } - if (one_dimension_is_broadcasted_without_multiple) { + + bool can_tile_with_implicit_broadcast = true; + for (int i = 0; i < input_dims; ++i) { + int64 multiple = multiples[i]; + // If the multiple and input dimension are not 1, then tile cannot be + // implemented with a single hlo broadcast. + if (multiple != 1 && input_shape.dim_size(i) != 1) { + can_tile_with_implicit_broadcast = false; + } + } + + if (can_tile_with_implicit_broadcast) { // Create a constant Zero the size of the output shape to leverage binary // operation broadcast semantics. auto broadcasted_zero = xla::Broadcast( - XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_shape); - ctx->SetOutput(0, xla::Add(broadcasted_zero, input)); + XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_dims); + if (ctx->input_type(0) == DT_BOOL) { + ctx->SetOutput(0, xla::Or(broadcasted_zero, input)); + } else { + ctx->SetOutput(0, xla::Add(broadcasted_zero, input)); + } return; } - // First broadcast the requisite number of multiples along each - // dimension. This prepends the broadcasted dimensions, so an - // input of shape [2,3,1] broadcast with multiples [5,4,3] will - // end up with shape [5,4,3,2,3,1]. - auto broadcasted = xla::Broadcast(input, multiples_array); - // Now flatten and reshape. The broadcasted dimensions are - // paired with the original dimensions so in the above example - // we flatten [0,3,1,4,2,5] then reshape to [10,12,3]. 
- std::vector flattened; - for (int i = 0; i < output_shape.size(); ++i) { - flattened.push_back(i); - flattened.push_back(i + output_shape.size()); - } - xla::XlaOp output = xla::Reshape(broadcasted, flattened, output_shape); - - ctx->SetOutput(0, output); + auto result = BroadcastTo(ctx->Input("input"), output_dims); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()); } private: TF_DISALLOW_COPY_AND_ASSIGN(TileOp); }; -REGISTER_XLA_OP(Name("Tile").CompileTimeConstInput("multiples"), TileOp); +REGISTER_XLA_OP(Name("Tile").CompileTimeConstantInput("multiples"), TileOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 183879c7602ccbbd74fca6cb9fa3fc94c066c37d..ee3bdf3394e37c757f31724e73e95417becaa534 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -15,7 +15,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/lib/sorting.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" @@ -59,7 +58,7 @@ class TopKOp : public XlaOpKernel { bool sorted_; }; -REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstInput("k").TypeConstraint( +REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstantInput("k").TypeConstraint( "T", {DT_UINT32, DT_INT32, DT_FLOAT, DT_BFLOAT16}), TopKOp); diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 7077c2e3a546e198bdb4ff944ea531f3158810f2..960c1462ceb8c00a2d6c96564f6c985fd1caef0f 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -320,9 +320,8 @@ class ResourceApplyAdagradDA : public XlaOpKernel { xla::XlaOp lr = ctx->Input(4); xla::XlaOp l1 = ctx->Input(5); xla::XlaOp l2 = ctx->Input(6); - xla::XlaBuilder* const b = ctx->builder(); xla::XlaOp global_step = - XlaHelpers::ConvertElementType(b, ctx->Input(7), dtype_); + XlaHelpers::ConvertElementType(ctx->Input(7), dtype_); accum = accum + grad; squared_accum = squared_accum + xla::Square(grad); diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 6b303b31d43ce2249a87f25723caf34f84c8387d..c9b324a243e4cc3ec64daa3ca0d285336a0d0154 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -37,8 +37,8 @@ class TransposeOp : public XlaOpKernel { : XlaOpKernel(ctx), conjugate_(conjugate) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape perm_tensor_shape = ctx->InputShape(1); + const TensorShape input_shape = 
ctx->InputShape("x"); + const TensorShape perm_tensor_shape = ctx->InputShape("perm"); // Preliminary validation of sizes. OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape), @@ -52,19 +52,15 @@ class TransposeOp : public XlaOpKernel { ". But input(1) is a vector of size ", perm_tensor_shape.num_elements())); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal)); - - std::vector perm(dims); - std::copy(literal.data().begin(), literal.data().end(), - perm.begin()); + std::vector perm; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector("perm", &perm)); std::vector transposed_order; // Check whether permutation is a permutation of integers of [0 .. dims). absl::InlinedVector bits(dims); bool is_identity = true; for (int i = 0; i < dims; ++i) { - const int32 d = perm[i]; + const int64 d = perm[i]; OP_REQUIRES( ctx, 0 <= d && d < dims, errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")")); @@ -83,9 +79,9 @@ class TransposeOp : public XlaOpKernel { xla::XlaOp transposed; // 0-D, 1-D, and identity transposes do nothing. if (dims <= 1 || is_identity) { - transposed = ctx->Input(0); + transposed = ctx->Input("x"); } else { - transposed = xla::Transpose(ctx->Input(0), transposed_order); + transposed = xla::Transpose(ctx->Input("x"), transposed_order); } // Conjugate the transposed result if this is ConjugateTransposeOp. @@ -106,9 +102,10 @@ class ConjugateTransposeOp : public TransposeOp { : TransposeOp(ctx, /*conjugate=*/true) {} }; -REGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp); +REGISTER_XLA_OP(Name("Transpose").CompileTimeConstantInput("perm"), + TransposeOp); -REGISTER_XLA_OP(Name("ConjugateTranspose").CompileTimeConstInput("perm"), +REGISTER_XLA_OP(Name("ConjugateTranspose").CompileTimeConstantInput("perm"), ConjugateTransposeOp); // InvertPermutation frequently forms part of the gradient of Transpose. 
@@ -153,7 +150,7 @@ class InvertPermutationOp : public XlaOpKernel { REGISTER_XLA_OP(Name("InvertPermutation") .TypeConstraint("T", DT_INT32) - .CompileTimeConstInput("x"), + .CompileTimeConstantInput("x"), InvertPermutationOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 0bdfc05726105e2d18362a691cbe2aab00bf77f3..a0ea6422d732b00fc1b8cf855d9c9ad603b87c82 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -80,24 +80,8 @@ XLAJIT_MAKE_UNARY(Invert, xla::Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x)); XLAJIT_MAKE_UNARY(Neg, -x); -// Implements Banker's rounding: numbers that are equidistant between two -// integers are rounded towards even. -xla::XlaOp RoundToEven(xla::XlaOp x) { - auto half = xla::ScalarLike(x, 0.5); - auto one = xla::ScalarLike(x, 1.0); - auto two = xla::ScalarLike(x, 2.0); - - auto round_val = xla::Floor(x); - auto fraction = x - round_val; - auto nearest_even_int = round_val - two * xla::Floor(half * x); - auto is_odd = xla::Eq(nearest_even_int, one); - return xla::Select(xla::Or(xla::Gt(fraction, half), - xla::And(xla::Eq(fraction, half), is_odd)), - round_val + one, round_val); -} - -XLAJIT_MAKE_UNARY(Rint, RoundToEven(x)); -XLAJIT_MAKE_UNARY(Round, RoundToEven(x)); +XLAJIT_MAKE_UNARY(Rint, xla::RoundToEven(x)); +XLAJIT_MAKE_UNARY(Round, xla::RoundToEven(x)); XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x)); diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 559414eeaa5fec75e5a9d1866baaf738c024cd15..ce007fc04a818869686b9936a1607cee42665e87 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -64,7 +64,7 @@ Status MakeXlaCompilerArgumentsFromInputs( if (!arg.initialized) { *has_uninitialized_vars = true; } - arg.tensor_array_size = resource->tensor_array_size(); + 
arg.max_array_size = resource->max_array_size(); for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc index 412afeaaad96842521fbd306f5b666e837e675fd..ad8e707e1116d01d492575986a7ab9586022f6b3 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc @@ -89,13 +89,10 @@ class XlaBroadcastHelperOp : public XlaOpKernel { lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); broadcast_shape[dim] = min_rank_shape->dim_size(i); } - xla::PrimitiveType type = context->input_xla_type(0); - xla::Shape broadcast_xla_shape = - xla::ShapeUtil::MakeShape(type, broadcast_shape); if (broadcast_lhs) { - lhs = xla::BroadcastInDim(lhs, broadcast_xla_shape, broadcast_dims); + lhs = xla::BroadcastInDim(lhs, broadcast_shape, broadcast_dims); } else { - rhs = xla::BroadcastInDim(rhs, broadcast_xla_shape, broadcast_dims); + rhs = xla::BroadcastInDim(rhs, broadcast_shape, broadcast_dims); } context->SetOutput(0, lhs); context->SetOutput(1, rhs); @@ -108,7 +105,7 @@ class XlaBroadcastHelperOp : public XlaOpKernel { }; REGISTER_XLA_OP( - Name("XlaBroadcastHelper").CompileTimeConstInput("broadcast_dims"), + Name("XlaBroadcastHelper").CompileTimeConstantInput("broadcast_dims"), XlaBroadcastHelperOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc index fecc7c556eb4121b912796e5811632c46769b479..4612f19971a3ce6994aef303f751748b77ccda9a 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -90,11 +90,11 @@ class XlaConvOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("XlaConv") - .CompileTimeConstInput("window_strides") - 
.CompileTimeConstInput("lhs_dilation") - .CompileTimeConstInput("rhs_dilation") - .CompileTimeConstInput("feature_group_count") - .CompileTimeConstInput("padding"), + .CompileTimeConstantInput("window_strides") + .CompileTimeConstantInput("lhs_dilation") + .CompileTimeConstantInput("rhs_dilation") + .CompileTimeConstantInput("feature_group_count") + .CompileTimeConstantInput("padding"), XlaConvOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc index 59502d83c7338bd1b05b3323a97761fff2da186a..a3c2eef993c80e43e7cf9e1f6147e5b337c41cfe 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc @@ -96,9 +96,9 @@ class XlaPadOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("XlaPad") - .CompileTimeConstInput("padding_low") - .CompileTimeConstInput("padding_high") - .CompileTimeConstInput("padding_interior"), + .CompileTimeConstantInput("padding_low") + .CompileTimeConstantInput("padding_high") + .CompileTimeConstantInput("padding_interior"), XlaPadOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc index 089776fcf74fcf6b363dfff5de8d86d7449eacd6..9043af995386a179f74d95bbc6c17a1cac7881cd 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc @@ -138,9 +138,9 @@ class XlaSelectAndScatterOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("XlaSelectAndScatter") - .CompileTimeConstInput("window_dimensions") - .CompileTimeConstInput("window_strides") - .CompileTimeConstInput("padding"), + .CompileTimeConstantInput("window_dimensions") + .CompileTimeConstantInput("window_strides") + .CompileTimeConstantInput("padding"), XlaSelectAndScatterOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/lib/BUILD 
b/tensorflow/compiler/tf2xla/lib/BUILD index 8597e7f139d8d32b7e08782e70a4ee44d02618f2..3e7a761120317ff85947559b7b2e52be9232afb7 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -18,16 +18,18 @@ filegroup( load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") cc_library( - name = "batch_dot", - srcs = ["batch_dot.cc"], - hdrs = ["batch_dot.h"], + name = "broadcast", + srcs = ["broadcast.cc"], + hdrs = ["broadcast.h"], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -36,8 +38,6 @@ cc_library( srcs = ["cholesky.cc"], hdrs = ["cholesky.h"], deps = [ - ":batch_dot", - ":triangular_solve", ":util", ":while_loop", "//tensorflow/compiler/xla:literal", @@ -47,6 +47,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/core:lib", ], ) @@ -71,7 +74,6 @@ cc_library( srcs = ["qr.cc"], hdrs = ["qr.h"], deps = [ - ":batch_dot", ":util", ":while_loop", "//tensorflow/compiler/xla:literal_util", @@ -83,7 +85,8 @@ cc_library( "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/core:lib", ], ) @@ -108,51 +111,6 @@ cc_library( ], ) 
-cc_library( - name = "triangular_solve", - srcs = ["triangular_solve.cc"], - hdrs = ["triangular_solve.h"], - deps = [ - ":batch_dot", - ":util", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:numeric", - "//tensorflow/core:lib", - ], -) - -xla_test( - name = "triangular_solve_test", - srcs = ["triangular_solve_test.cc"], - tags = ["noasan"], # sometimes times out, http://b/78650012 - deps = [ - ":triangular_solve", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - cc_library( name = "util", srcs = ["util.cc"], @@ -171,29 +129,6 @@ cc_library( ], ) -xla_test( - name = "util_test", - srcs = ["util_test.cc"], - deps = [ - ":batch_dot", - ":util", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - 
"//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - cc_library( name = "while_loop", srcs = ["while_loop.cc"], diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc deleted file mode 100644 index 5400e8834cb9807f6dd71abe7789b2672e29e905..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" - -#include -#include - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace tensorflow { - -xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, bool conjugate_y, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); - - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { - return errors::InvalidArgument( - "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(x_shape), " vs. ", - xla::ShapeUtil::HumanString(y_shape)); - } - const int ndims = xla::ShapeUtil::Rank(x_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to BatchedDot must have rank >= 2: ", ndims); - } - - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector batch_dimension_numbers; - for (int i = 0; i < ndims - 2; ++i) { - if (x_shape.dimensions(i) != y_shape.dimensions(i)) { - return errors::InvalidArgument( - "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " vs ", - xla::ShapeUtil::HumanString(y_shape)); - } - batch_dimension_numbers.push_back(i); - } - - int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); - int y_inner_dim = transpose_y ? 
(ndims - 1) : (ndims - 2); - if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { - return errors::InvalidArgument( - "Dimensions ", x_inner_dim, " and ", y_inner_dim, - " of arguments to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, - " vs. ", xla::ShapeUtil::HumanString(y_shape), - " transpose: ", transpose_y); - } - - // Check for zero lhs/rhs dim size. - if (xla::ShapeUtil::IsZeroElementArray(x_shape) || - xla::ShapeUtil::IsZeroElementArray(y_shape)) { - std::vector dimensions(batch_dimension_numbers.size()); - for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); - } - int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); - int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape.dimensions(x_outer_dim)); - dimensions.push_back(y_shape.dimensions(y_outer_dim)); - return xla::Broadcast( - xla::ConstantLiteral(builder, - xla::LiteralUtil::Zero(x_shape.element_type())), - dimensions); - } - - if (x_shape.element_type() == xla::C64 && conjugate_x) { - x = xla::Conj(x); - } - if (y_shape.element_type() == xla::C64 && conjugate_y) { - y = xla::Conj(y); - } - - xla::PrecisionConfig precision_proto; - precision_proto.add_operand_precision(precision); - precision_proto.add_operand_precision(precision); - - xla::DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); - dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); - for (auto batch_dimension_number : batch_dimension_numbers) { - dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); - dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); - } - - return xla::DotGeneral(x, y, dot_dnums, &precision_proto); - }); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h deleted file mode 100644 index 
6edd63a4d3b66c21aa4cce8c9f36eef0dc363cd8..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace tensorflow { - -// Multiplies slices of two tensors in batches. - -// Multiplies all slices of `Tensor` `x` and `y` (each slice can be -// viewed as an element of a batch), and arranges the individual results -// in a single output tensor of the same batch size. Each of the -// individual slices can optionally be transposed before multiplication by -// setting the `transpose_x` or `transpose_y` flag to `true`. Similarly, each -// can be elementwise-complex-conjugated by setting the `conjugate_x` or -// `conjugate_y` flag to `true`. To apply a Hermitian adjoint to `x`, set both -// `transpose_x` and `conjugate_x` to `true`, and analogously for `y`. -// -// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` -// and `[..., r_y, c_y]`. 
-// -// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: -// -// r_o = c_x if transpose_x else r_x -// c_o = r_y if transpose_y else c_y -// -// It is computed as: -// -// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::XlaOp BatchDot( - xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, - bool transpose_y = false, bool conjugate_x = false, - bool conjugate_y = false, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc new file mode 100644 index 0000000000000000000000000000000000000000..be31f116686a2e302ece730e9d03312a45888a61 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -0,0 +1,91 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +namespace tensorflow { + +xla::StatusOr BroadcastTo(xla::XlaOp input, + absl::Span output_dims) { + xla::XlaBuilder* builder = input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); + absl::Span input_dims = + xla::AsInt64Slice(input_shape.dimensions()); + + if (input_dims == output_dims) { + return input; + } + + if (input_dims.size() > output_dims.size()) { + return errors::InvalidArgument( + "Input shape (", xla::ShapeUtil::HumanString(input_shape), + ") must have rank less than or equal to the output shape [", + absl::StrJoin(output_dims, ","), "]"); + } + + std::vector broadcast_dims; + std::vector broadcast_shape; + auto input_it = input_dims.rbegin(); + for (auto output_it = output_dims.rbegin(); output_it != output_dims.rend(); + ++output_it) { + if (input_it != input_dims.rend()) { + if (!(*output_it == 0 && *input_it == 0) && + !(*input_it != 0 && *output_it % *input_it == 0)) { + return errors::InvalidArgument("Invalid shape broadcast from ", + xla::ShapeUtil::HumanString(input_shape), + " to [", absl::StrJoin(output_dims, ","), + "]"); + } + + broadcast_dims.push_back(broadcast_shape.size()); + if (*output_it == *input_it) { + broadcast_shape.push_back(*output_it); + } else if (*output_it != *input_it) { + // Add dimensions [I, O/I], which we will later flatten to just + // [O]. We must do this in two phases since XLA broadcasting does not + // support tiling. 
+ broadcast_shape.push_back(*input_it); + broadcast_shape.push_back(*output_it / *input_it); + } + ++input_it; + } else { + broadcast_shape.push_back(*output_it); + } + } + TF_RET_CHECK(input_it == input_dims.rend()); + + absl::c_reverse(broadcast_dims); + int broadcast_shape_size = broadcast_shape.size(); + for (int64& broadcast_dim : broadcast_dims) { + broadcast_dim = broadcast_shape_size - broadcast_dim - 1; + } + absl::c_reverse(broadcast_shape); + xla::XlaOp output = + xla::BroadcastInDim(input, broadcast_shape, broadcast_dims); + if (broadcast_shape != output_dims) { + output = xla::Reshape(output, output_dims); + } + return output; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.h b/tensorflow/compiler/tf2xla/lib/broadcast.h new file mode 100644 index 0000000000000000000000000000000000000000..591e696f06b994a7fdea58bc95ba785f683ce7d1 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/broadcast.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace tensorflow { + +// Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting +// rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling. +xla::StatusOr BroadcastTo(xla::XlaOp input, + absl::Span output_dims); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index ab3d0a566839343828d176d9a46672824e425613..550ab5b05693b79e60e49577309328ac6846d3f9 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -18,11 +18,12 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -101,10 +102,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // a[..., i, i] auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); // np.dot(row, np.swapaxes(row, -1, -2)) - auto diag_dot = BatchDot(row, row, - /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto diag_dot = BatchDot(row, TransposeInMinorDims(row), precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) auto l_ii = @@ -122,10 +120,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, // The columns in [i, n] are zeroed out in `row`, so we just have to // zero out rows above i+1 after the BatchDot. 
np.dot(l[..., :, :i], // r.T) - auto dot = BatchDot(body_l, row, - /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto dot = BatchDot(body_l, TransposeInMinorDims(row), precision); // np.dot(l[..., i+1:, :i], r.T) auto dot_ip1 = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); @@ -185,9 +180,7 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); - auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto delta = BatchDot(lhs, TransposeInMinorDims(rhs), precision); auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); a = UpdateSliceInMinorDims(a, before - delta, {i, i}); } diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index 6b3f2b6e065b5c99e2d0248237369ecc30188aa5..d6007748609fdd161cb89692a167eb7ed12fe00c 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -18,13 +18,13 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -191,12 +191,8 @@ xla::StatusOr QRBlock( auto v_broadcast = xla::Reshape(v, shape); // a[:, :] -= tau * np.dot(v[:, np.newaxis], // np.dot(v[np.newaxis, :], a[:, :])) - auto vva = - BatchDot(v_broadcast, a, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - vva = - BatchDot(v_broadcast, vva, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto vva = BatchDot(v_broadcast, a, precision); + vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision); a = a - xla::Mul(tau, vva, /*broadcast_dimensions=*/batch_dim_indices); @@ -278,12 +274,9 @@ xla::StatusOr ComputeWYRepresentation( auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); // yv has shape [..., n, 1] - auto yv = BatchDot(y, v, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto yv = BatchDot(TransposeInMinorDims(y), v, precision); // wyv has shape [..., m, 1] - auto wyv = - BatchDot(w, yv, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto wyv = BatchDot(w, yv, precision); auto z = xla::Mul( -beta, v + wyv, @@ -375,23 +368,15 @@ xla::StatusOr QRDecomposition( // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, 
i+k:])) auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n}); - auto a_update = - BatchDot(w, a_panel, /*transpose_x=*/true, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - a_update = - BatchDot(y, a_update, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + auto a_update = BatchDot(TransposeInMinorDims(w), a_panel, precision); + a_update = BatchDot(y, a_update, precision); a_panel = a_panel + a_update; a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); - auto q_update = - BatchDot(q_panel, w, /*transpose_x=*/false, /*transpose_y=*/false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - q_update = BatchDot(q_update, y, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); + auto q_update = BatchDot(q_panel, w, precision); + q_update = BatchDot(q_update, TransposeInMinorDims(y), precision); q_panel = q_panel + q_update; q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc deleted file mode 100644 index 6524c2a9b1ada632d80edd234272760c2b545cc4..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ /dev/null @@ -1,416 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" - -#include -#include - -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/math/math_util.h" - -namespace tensorflow { - -// Get the diagonal blocks of the coefficient matrix -xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(a)); - int ndims = xla::ShapeUtil::Rank(shape); - int64 n = xla::ShapeUtil::GetDimension(shape, -1); - int64 num_blocks = n / block_size; - - xla::XlaOp diag_blocks; - - // If the coefficient matrix is exactly the block size, we just add a - // singleton dimension i.e. 
[..., n, n] -> [..., 1, n, n] - if (n == block_size) { - std::vector permutation(ndims); - std::iota(permutation.begin(), permutation.end(), 1); - permutation.insert(permutation.end() - 2, 0); - return Transpose(Broadcast(a, /*broadcast_sizes=*/{1}), permutation); - } - - // We can grab entire blocks using gather - if (n > block_size) { - // Construct the starting indices of the diagonal blocks - auto start_indices = - Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks), - xla::ConstantR0(builder, block_size)), - /*broadcast_sizes=*/{2}), - /*permutation=*/{1, 0}); - - // Gather the diagonal blocks - xla::GatherDimensionNumbers dim_numbers; - dim_numbers.add_offset_dims(ndims - 1); - dim_numbers.add_offset_dims(ndims); - dim_numbers.add_start_index_map(ndims - 2); - dim_numbers.add_start_index_map(ndims - 1); - dim_numbers.set_index_vector_dim(1); - diag_blocks = Gather(a, start_indices, dim_numbers, - /*slice_sizes=*/{block_size, block_size}); - } - - // The last block might be smaller than the block size, - // so we will need to pad it - if (n % block_size != 0) { - // Pad with zeros - auto last_blocks = - SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n}); - xla::PaddingConfig config = xla::MakeNoPaddingConfig(ndims); - int64 padding = block_size - n % block_size; - config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding); - config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding); - last_blocks = - Pad(last_blocks, Zero(builder, shape.element_type()), config); - - // Add a singleton dimension - // i.e. 
[..., block_size, block_size] -> [..., 1, block_size, block_size] - TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, - builder->GetShape(last_blocks)); - auto shape_dims = xla::AsInt64Slice(blocks_shape.dimensions()); - auto last_blocks_dims = std::vector(ndims); - std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin()); - last_blocks_dims.insert(last_blocks_dims.end() - 2, 1); - last_blocks = Reshape(last_blocks, last_blocks_dims); - - // Concatenate with the other blocks if necessary - if (n > block_size) { - diag_blocks = - xla::ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2); - } else { - diag_blocks = last_blocks; - } - } - - return diag_blocks; - }); -} - -xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, - bool transpose_a, bool conjugate_a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = diag_blocks.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - // Input is a batch of square lower triangular square matrices. Its shape is - // (..., size, size). We resize this to (num_blocks, size, size). 
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(diag_blocks)); - int64 block_size = xla::ShapeUtil::GetDimension(shape, -1); - int64 num_blocks = xla::ShapeUtil::ElementsIn(shape) / - tensorflow::MathUtil::IPow(block_size, 2); - diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); - - // The input must be triangular because we rely on that when doing - // multiplications later on - diag_blocks = Triangle(diag_blocks, /*lower=*/lower); - - // Rescale blocks to be unit triangular, but avoid dividing by - // zero (which can happen if the last block was padded) otherwise it will - // introduce nans which will propagate - auto diags = GetMatrixDiagonal(diag_blocks); - TF_ASSIGN_OR_RETURN(xla::Shape diags_shape, builder->GetShape(diags)); - auto one = ScalarLike(diags, 1); - auto ones = Broadcast(one, xla::AsInt64Slice(diags_shape.dimensions())); - diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); - auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); - - // We can now use the fact that for an upper triangular matrix - // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have - // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks - // have been rescaled to be unit triangular, so L22 = L22' = 1. - - // Initialize the output matrix with -1s on the diagonal. We use -1 instead - // of 1 because we cannot do matrix-vector multiplies with variable shapes - // inside of a loop, or do irregularly shaped in-place updates. Hence, - // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the - // entire row i.e. we calculate - // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I]) - // which means [L21 L22 0] <- [-L21 * L11', L22, 0]. 
- auto identity = - IdentityMatrix(builder, shape.element_type(), block_size, block_size); - auto neg_identity = -identity; - - // The first or last diagonal element should be set to 1 instead of -1 - // though, since we never update it - auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); - auto start_index = (lower) ? 0 : block_size - 1; - auto output_block = DynamicUpdateSlice( - neg_identity, pos_one, - /*start_indices=*/xla::ConstantR1(builder, 2, start_index)); - - // Broadcast diag([1, -1, -1, ...]) to every block - xla::XlaOp output = Broadcast(output_block, - /*broadcast_sizes=*/{num_blocks}); - - // Now we construct a loop that performs matrix-vector multiplications - // inverting the blocks one row at a time - std::vector tuple_shapes = { - // The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The output has the shape of A, with one row updated each iteration. - xla::ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size}), - // The input is a loop invariant. - xla::ShapeUtil::MakeShape(shape.element_type(), - {num_blocks, block_size, block_size})}; - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - - auto init_i = One(builder, xla::S32); - auto init = xla::Tuple(builder, {init_i, output, scaled_diag_blocks}); - - // Construct the loop condition function. - std::unique_ptr condb = - builder->CreateSubBuilder("InvertDiagCond"); - { - auto i = GetTupleElement( - Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); - Lt(i, xla::ConstantR0(condb.get(), block_size)); - } - TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - - // Construct the loop body function. 
- std::unique_ptr bodyb = - builder->CreateSubBuilder("InvertDiagBody"); - { - auto input_tuple = - Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple"); - - auto i = GetTupleElement(input_tuple, 0); - auto body_out = GetTupleElement(input_tuple, 1); - auto body_input = GetTupleElement(input_tuple, 2); - - auto zero = xla::ConstantR1(bodyb.get(), 1, 0); - auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i; - auto start_indices = - xla::ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); - auto input_row = - DynamicSlice(body_input, start_indices, - /*slice_sizes=*/{num_blocks, 1, block_size}); - - // We want -L21 L11^{-1} - xla::DotDimensionNumbers dnums; - dnums.add_lhs_batch_dimensions(0); - dnums.add_rhs_batch_dimensions(0); - dnums.add_lhs_contracting_dimensions(2); - dnums.add_rhs_contracting_dimensions(1); - xla::PrecisionConfig precision_proto; - precision_proto.add_operand_precision(precision); - precision_proto.add_operand_precision(precision); - auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); - - body_out = DynamicUpdateSlice(body_out, update, start_indices); - - auto next_i = i + ScalarLike(i, 1); - xla::Tuple(bodyb.get(), {next_i, body_out, body_input}); - } - TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); - - // Construct the While loop and return the result, - // return while_loop(cond_fun, body_fun, init)[1] - auto invert_while = While(cond, body, init); - auto inv_diag_blocks = GetTupleElement(invert_while, 1); - - // Undo the scaling - inv_diag_blocks = Div(inv_diag_blocks, diags, - /*broadcast_dimensions=*/{0, 1}); - - // Reshape back to original batch major dimensions - return Reshape(inv_diag_blocks, xla::AsInt64Slice(shape.dimensions())); - }); -} - -xla::XlaOp SolveWithInvertedDiagonalBlocks( - xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side, - bool lower, bool transpose_a, bool conjugate_a, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = 
a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, - builder->GetShape(inv_diag_blocks)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - int64 block_size = xla::ShapeUtil::GetDimension(blocks_shape, -1); - - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - int64 ndims = xla::ShapeUtil::Rank(a_shape); - int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - int64 num_blocks = n / block_size + (n % block_size != 0); - int64 m_dim = (left_side) ? -1 : -2; - int64 m = xla::ShapeUtil::GetDimension(b_shape, m_dim); - - // Initialize the solution - auto x = ZerosLike(b); - - // This loop is unrolled for performance reasons, but it could be expressed - // rolled as well since the matrices are of the same size each iteration - for (int i = 0; i < num_blocks; i++) { - // High-level intuition: We have B[i] = L[i] @ X. Since L is upper - // triangular this means B[i] = L[i, :i + 1] @ X[:i + 1]. We can split - // this into two parts: B[i] = L[i, :i] @ X[:i] + L[i, i] @ X[i] which - // can be solved for X[i] as X[i] = inv(L[i, i]) @ B[i] - L[i, :i] @ X[:i] - - // Decide whether we go from first block to last or vice versa - auto j = (left_side ^ lower ^ transpose_a) ? num_blocks - 1 - i : i; - - // Get the size of the inverse blocks (the last one might be smaller) - int64 block = (n % block_size != 0 && j + 1 == num_blocks) - ? 
n % block_size - : block_size; - auto inv_block = - MaybeConjugate(Collapse(SliceInMinorDims(inv_diag_blocks, {j, 0, 0}, - {j + 1, block, block}), - /*dimensions=*/{ndims - 2, ndims - 1}), - conjugate_a); - - // Get the corresponding row of B - int64 k = std::min((j + 1) * block_size, n); - std::vector start = {j * block_size, 0}; - std::vector end = {k, m}; - if (!left_side) { - std::swap(start[0], start[1]); - std::swap(end[0], end[1]); - } - auto b_row = SliceInMinorDims(b, start, end); - - xla::XlaOp remainder; - if (i == 0) { - remainder = b_row; - } else { - // This matrix multiply involves a lot of multiplying with zero (namely, - // X[i * block_size:] = 0), but this is faster than slicing... - end = {k, n}; - if (!left_side) { - std::swap(end[0], end[1]); - } - if (transpose_a) { - std::swap(start[0], start[1]); - std::swap(end[0], end[1]); - } - auto a_row = - MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); - if (left_side) { - remainder = b_row - BatchDot(a_row, x, transpose_a, false, - /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); - } else { - remainder = b_row - BatchDot(x, a_row, false, transpose_a, - /*conjugate_x=*/false, - /*conjugate_y=*/false, precision); - } - } - - xla::XlaOp x_update; - auto zero = Zero(builder, xla::S32); - auto start_index = - xla::ConstantR0WithType(builder, xla::S32, j * block_size); - std::vector update_starts = {start_index, zero}; - if (left_side) { - x_update = - BatchDot(inv_block, remainder, transpose_a, false, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - } else { - x_update = - BatchDot(remainder, inv_block, false, transpose_a, - /*conjugate_x=*/false, /*conjugate_y=*/false, precision); - std::swap(update_starts[0], update_starts[1]); - } - x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); - } - - return x; - }); -} - -xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, - bool lower, bool transpose_a, bool conjugate_a, - int64 
block_size, - xla::PrecisionConfig::Precision precision) { - xla::XlaBuilder* builder = a.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); - if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have different ranks: ", - xla::ShapeUtil::HumanString(a_shape), " vs. ", - xla::ShapeUtil::HumanString(b_shape)); - } - const int64 ndims = xla::ShapeUtil::Rank(a_shape); - if (ndims < 2) { - return errors::InvalidArgument( - "Arguments to TriangularSolve must have rank >= 2: ", ndims); - } - // The batch dimensions must be equal. - std::vector batch_dimensions; - for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape.dimensions(i); - int64 b_size = b_shape.dimensions(i); - if (a_size != b_size) { - return errors::InvalidArgument( - "Batch dimensions of arguments to TriangularSolve must be equal: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); - } - batch_dimensions.push_back(a_size); - } - - if (xla::ShapeUtil::GetDimension(a_shape, -1) != - xla::ShapeUtil::GetDimension(a_shape, -2)) { - return errors::InvalidArgument( - "The 'a' arguments to TriangularSolve must be square matrices: ", - xla::ShapeUtil::HumanString(a_shape)); - } - const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); - if ((left_side ? 
m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { - return errors::InvalidArgument( - "Arguments to TriangularSolve have incompatible matrix shapes: ", - xla::ShapeUtil::HumanString(a_shape), " vs ", - xla::ShapeUtil::HumanString(b_shape)); - } - - if (block_size < 1) { - return errors::InvalidArgument( - "block_size argument to TriangularSolve must be >= 1; got ", - block_size); - } - - // We find the diagonal blocks of the coefficient matrix - auto diag_blocks = DiagonalBlocks(a, block_size); - - // We invert these blocks in parallel using batched matrix-vector products - auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a, - conjugate_a, precision); - - // We now find the solution using GEMMs - auto x = - SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower, - transpose_a, conjugate_a, precision); - - return x; - }); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 804671fbc75b0a5a6e04b204822b6f084013cd8b..c0bd172d17c192435ba8ee196f9def0491c0bf5c 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -113,36 +113,6 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, return xla::ConstantLiteral(builder, literal); } -xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, - absl::Span end) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_RET_CHECK(start.size() == end.size()); - int64 n_minor_dims = start.size(); - - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_minor_dims <= n_dims); - auto major_dims = xla::AsInt64Slice(shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims - n_minor_dims); - - // Prepends 0s in the major dim - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - 
padded_start.begin() + major_dims.size()); - - // Prepends the shape of the major dims. - std::vector padded_end(n_dims); - std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); - std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); - - std::vector strides(n_dims, 1); - return xla::Slice(x, padded_start, padded_end, strides); - }); -} std::vector ConcatVectors(absl::Span xs, absl::Span ys) { @@ -152,100 +122,4 @@ std::vector ConcatVectors(absl::Span xs, return output; } -xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - absl::Span starts, - absl::Span sizes) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - int64 n_minor_dims = starts.size(); - TF_RET_CHECK(n_minor_dims == sizes.size()); - TF_RET_CHECK(n_minor_dims <= n_dims); - auto major_dims = xla::AsInt64Slice(shape.dimensions()) - .subspan( - /*pos=*/0, - /*len=*/n_dims - sizes.size()); - auto padded_starts = PrependZerosInMajorDims(x, starts); - auto padded_sizes = ConcatVectors(major_dims, sizes); - return xla::DynamicSlice(x, padded_starts, padded_sizes); - }); -} - -xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - absl::Span start) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - // TODO(phawkins): make int64 work on all backends, remove the int32 cast. 
- std::vector start_as_int32(start.begin(), start.end()); - auto start_constant = xla::ConstantR1(builder, start_as_int32); - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, - builder->GetShape(start_constant)); - const int64 start_length = - xla::ShapeUtil::GetDimension(start_constant_shape, -1); - TF_RET_CHECK(start_length == n_dims); - return xla::DynamicUpdateSlice(x, update, start_constant); - }); -} - -xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span start) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - const int64 n_minor_dims = start.size(); - TF_RET_CHECK(n_minor_dims <= n_dims); - std::vector padded_start(n_dims, 0); - std::copy(start.begin(), start.end(), - padded_start.begin() + (n_dims - n_minor_dims)); - return UpdateSlice(x, update, padded_start); - }); -} - -xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span starts) { - auto padded_starts = PrependZerosInMajorDims(x, starts); - return xla::DynamicUpdateSlice(x, update, padded_starts); -} - -xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - absl::Span starts) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - auto zero = xla::Reshape(xla::ConstantR0(builder, 0), {1}); - std::vector padded_starts(n_dims, zero); - for (int i = 0; i < starts.size(); ++i) { - padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1}); - } - return xla::ConcatInDim(builder, padded_starts, 0); - }); -} - -xla::XlaOp TransposeInMinorDims(xla::XlaOp x) { - xla::XlaBuilder* builder = 
x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - std::vector permutation(n_dims); - std::iota(permutation.begin(), permutation.end(), 0); - std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); - return xla::Transpose(x, permutation); - }); -} - -xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) { - xla::XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { - TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - auto perform_conj = shape.element_type() == xla::C64 && conjugate; - return perform_conj ? xla::Conj(x) : x; - }); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 80e9e5b002d49581209e608b98606e02709c5876..aec8061cb4322b8d315b6cdc80c7fff1e0cb4cb1 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -38,44 +38,10 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, int64 value); -// Builds a vector of zeros of length rank(x) with the last values being -// those in `starts`. -xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - absl::Span starts); - -// Performs a slice in the minor dimensions of a Tensor. -xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, - absl::Span end); - // Returns the concatenation of `xs` and `ys`. std::vector ConcatVectors(absl::Span xs, absl::Span ys); -// Performs a dynamic slice in the minor dimensions of a Tensor. 
-xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - absl::Span starts, - absl::Span sizes); - -// Updates a slice of 'x', i.e., -// x[start[0], ..., start[n]] = update -xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - absl::Span start); - -// Updates a slice of 'x', where 'start' contains a list of minor dimensions: -// x[..., start[0], ..., start[n]] = update -xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span start); - -xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - absl::Span starts); - -// Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::XlaOp TransposeInMinorDims(xla::XlaOp x); - -// Applies a complex conjugation operation if `a` is complex and `conjugate_a` -// is true, otherwise returns its argument. -xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc deleted file mode 100644 index 442fe92c34ca26cb1a854cc90da8dc034bca79bb..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/lib/util.h" - -#include -#include -#include - -#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" -#include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -namespace tensorflow { -namespace { - -using UtilTest = xla::ClientLibraryTestBase; -using UtilLeftLookingTest = xla::ClientLibraryTestBase; - -xla::Array2D BValsRight() { - return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; -} - -xla::Array2D BValsLeft() { - return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; -} - -xla::Array2D AValsFull() { - return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}}; -} - -xla::Array3D BatchedAValsFull() { - return {{ - {2, 0, 1, 2}, - {3, 6, 0, 1}, - {4, 7, 9, 0}, - {5, 8, 10, 11}, - }, - { - {16, 24, 8, 12}, - {24, 61, 82, 48}, - {8, 82, 456, 106}, - {12, 48, 106, 62}, - }}; -} - -XLA_TEST_F(UtilTest, Simple2dLookup) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, x, y; - auto a_data = CreateR2Parameter(BValsRight(), 0, "a", &builder, &a); - auto x_data = CreateR0Parameter(2, 1, "x", &builder, &x); - auto y_data = CreateR0Parameter(1, 2, "y", &builder, &y); - DynamicSliceInMinorDims(a, {x, y}, {1, 1}); - - ComputeAndCompareR2(&builder, {{10}}, - {a_data.get(), x_data.get(), y_data.get()}, - xla::ErrorSpec(1e-2, 1e-2)); -} - -XLA_TEST_F(UtilTest, Simple3dLookup) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, index; - auto a_data = - CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); - auto index_data 
= CreateR0Parameter(1, 1, "index", &builder, &index); - - DynamicSliceInMinorDims(a, {index, xla::ConstantR0(&builder, 0)}, - {1, 4}); - - ComputeAndCompareR3(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}}, - {a_data.get(), index_data.get()}); -} - -XLA_TEST_F(UtilTest, SimpleSliceUpdate) { - xla::XlaBuilder builder(TestName()); - - xla::XlaOp a, b, x, y; - auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); - auto b_data = CreateR2Parameter({{9, 1, -10}}, 1, "b", &builder, &b); - auto x_data = CreateR0Parameter(2, 2, "x", &builder, &x); - auto y_data = CreateR0Parameter(1, 3, "y", &builder, &y); - - DynamicUpdateSliceInMinorDims(a, b, {x, y}); - - xla::Array2D expected( - {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}}); - - ComputeAndCompareR2( - &builder, expected, - {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); -} - -XLA_TEST_F(UtilTest, RowBatchDot) { - xla::XlaBuilder builder(TestName()); - - int n = 4; - - xla::XlaOp a, row, index; - auto a_data = - CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); - auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, - "row", &builder, &row); - // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). 
- auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); - - auto l_index = DynamicSliceInMinorDims( - a, {index, xla::ConstantR0(&builder, 0)}, {1, n}); - BatchDot(l_index, row, /*transpose_x=*/false, /*transpose_y=*/true); - - ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, - {a_data.get(), row_data.get(), index_data.get()}); -} - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 20103ec3ae00b57723e05326dbbb1b0f6e1a671a..67d08290033361f16dfff42b06af9b253e84963a 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -32,6 +32,12 @@ Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, return Status::OK(); } +xla::StatusOr HostTensorToLiteral(const Tensor& host_tensor) { + xla::BorrowingLiteral literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(host_tensor, &literal)); + return literal.Clone(); +} + Status HostTensorToMutableBorrowingLiteral( Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) { xla::Shape xla_shape; diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 1db7470ee2a839099454b772d4833492e033bc92..a153dddee6127ff9c0858220f2d8a735ab3f0e19 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -30,6 +30,11 @@ namespace tensorflow { // 'host_tensor'. Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, xla::BorrowingLiteral* literal); + +// Returns a Literal with the contents of 'host_tensor', backed by its own +// storage (i.e., not reusing 'host_tensor's buffers.) +xla::StatusOr HostTensorToLiteral(const Tensor& host_tensor); + // Returns a MutableBorrowingLiteral that utilizes the same underlying buffer // owned by 'host_tensor', but is mutable via the xla::Literal methods. 
Status HostTensorToMutableBorrowingLiteral( diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 733eeed3c661c9ed683f0fb7fd90f7f997b8dc2b..bd2c0a5ee88869ba60701c0a7ace05857452eed9 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -283,6 +283,8 @@ REGISTER_OP("XlaReduceWindow") .Input("init_value: T") .Input("window_dimensions: Tindices") .Input("window_strides: Tindices") + .Input("base_dilations: Tindices") + .Input("window_dilations: Tindices") .Input("padding: Tindices") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") @@ -354,12 +356,33 @@ Wraps the XLA Sort operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#sort . -Sorts a tensor. Currently only rank 1 sorts in ascending order are supported. +Sorts a tensor. Currently only sorts in ascending order are supported. input: A `Tensor` of type T. output: A `Tensor` of type T. )doc"); +REGISTER_OP("XlaKeyValueSort") + .Input("keys: K") + .Input("values: V") + .Output("sorted_keys: K") + .Output("sorted_values: V") + .Attr("K: realnumbertype") + .Attr("V: type") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA Sort operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#sort +. + +Sorts a tensor. Currently only sorts in ascending order are supported. + +keys: A `Tensor` of type K. +values: A `Tensor` of type V. +sorted_keys: A `Tensor` of type K. +sorted_values: A `Tensor` of type V. +)doc"); + // TODO(b/37549631) setting the While Op to always be stateful is too // conservative. 
REGISTER_OP("XlaWhile") diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index 69ca39436013ec5cf09ba502a1540d5df322e213..fef97b98c376d9df8bbfd9cb6651216895e46bf4 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -1,9 +1,13 @@ licenses(["notice"]) # Apache 2.0 +package_group( + name = "friends", + includes = ["//tensorflow:internal"], +) + package( default_visibility = [ - "//learning/tfx:__subpackages__", - "//tensorflow:internal", + ":friends", ], ) diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 27dd18a9bbd5aceece41aaf61eb185acb537b3b6..147e562658bbfc445f99268812e2c3ae1ee61e30 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -212,9 +212,9 @@ bitcast_convert_type = array_ops.bitcast def broadcast(x, dims, name=None): x = ops.convert_to_tensor(x) - shape = array_ops.concat( - [constant_op.constant(dims), - array_ops.shape(x)], axis=0) + shape = array_ops.concat([constant_op.constant(dims), + array_ops.shape(x)], + axis=0) return array_ops.broadcast_to(x, shape, name=name) @@ -250,7 +250,7 @@ def conv(lhs, rhs_dilation: dilation to apply between kernel elements dimension_numbers: a `ConvolutionDimensionNumbers` proto. feature_group_count: number of feature groups for grouped convolution. - precision_config: a `PrecisionConfigProto` proto. + precision_config: a `xla.PrecisionConfig` proto. name: an optional name for the operator Returns: @@ -320,6 +320,8 @@ def reduce_window(operand, reducer, window_dimensions, window_strides=None, + base_dilations=None, + window_dilations=None, padding=None, name=None): """Wraps the XLA ReduceWindow operator. @@ -332,22 +334,27 @@ def reduce_window(operand, init: a scalar tensor representing the initial value for the reduction reducer: a reduction function that combines a pair of scalars. 
window_dimensions: shape of the window, as a list of integers - window_strides: inter-window strides, as a list of integers. Optional; - if omitted, defaults to strides of 1. + window_strides: inter-window strides, as a list of integers. Optional; if + omitted, defaults to strides of 1. padding: padding to apply to 'operand'. List of (low, high) pairs of integers that specify the padding to apply before and after each dimension. Optional; if omitted, defaults to no padding. name: the operator name, or None. + Returns: A tensor that represents the output of the reduce_window operator. """ window_strides = window_strides or [1] * len(window_dimensions) + base_dilations = base_dilations or [1] * len(window_dimensions) + window_dilations = window_dilations or [1] * len(window_dimensions) padding = padding or [(0, 0)] * len(window_dimensions) return gen_xla_ops.xla_reduce_window( input=operand, init_value=init, window_dimensions=window_dimensions, window_strides=window_strides, + base_dilations=base_dilations, + window_dilations=window_dilations, padding=padding, computation=reducer, name=name) @@ -377,4 +384,5 @@ def slice(x, start_dims, limit_dims, strides): sort = gen_xla_ops.xla_sort +key_value_sort = gen_xla_ops.xla_key_value_sort while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index f7e34a5b40c2f9244c029ed325a76322b8cf54dd..0b231ea8e7a2d8e303e91911e2e0a36fc83e78b4 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -18,6 +18,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index 6cd7b24592f30d7202b985f3dfd082ea2d85e344..b233e6b2c28e1968bb74901fc684e808ae45ab60 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/side_effect_util.h" +#include "absl/strings/numbers.h" #include "tensorflow/core/graph/algorithm.h" namespace tensorflow { @@ -64,4 +65,28 @@ bool HasSideEffectingNodes(const Graph& g) { return false; } +Status ParseHostComputeCoreList(absl::Span list_from_attr, + std::map* host_compute_core) { + for (const auto& hc_core : list_from_attr) { + std::vector parts = str_util::Split(hc_core, ":"); + if (parts.size() != 2) { + return errors::InvalidArgument( + "Malformed host_compute_core entry ", hc_core, + " should be :."); + } + int core; + if (!absl::numbers_internal::safe_strto32_base(parts[1], &core, 10)) { + return errors::InvalidArgument("Malformed host_compute_core entry ", + hc_core, + " part after ':' should be an integer."); + } + if (host_compute_core->find(parts[0]) != host_compute_core->end()) { + return errors::InvalidArgument( + "Duplicate host_compute_core entry for cluster ", parts[0]); + } + (*host_compute_core)[parts[0]] = core; + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index ad07624729f0b0d2443b2fc43d32dfa3377ce115..f22ddb2f58e1fa5c10ca0fdb956d9136942388b7 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ 
b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -42,6 +42,12 @@ std::set CalculateTokenInputsForOutputToken(const Graph& g); // Returns whether a graph contains side-effecting nodes. bool HasSideEffectingNodes(const Graph& g); +// Parse the mapping from outside_compilation_subgraph name to core number, +// which is specified in an attr as a list of strings +// :. +Status ParseHostComputeCoreList(absl::Span list_from_attr, + std::map* host_compute_core); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index f31bfb45a2f4db270446eb59259969dc0ab63a8e..3c6c9a91b6d2fb47f6dee1c347e9b852f1eea3ec 100644 --- a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -40,12 +40,4 @@ Status InstantiateFunctionForTest(const string& name, return Status::OK(); } -std::unordered_map BuildNodeIndex(const Graph& graph) { - std::unordered_map index; - for (Node* node : graph.nodes()) { - index[node->name()] = node; - } - return index; -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h index 350a868568531c0d073e0cf600327d1ff9d62e3a..4ffc94ae3bc7c930720cd625a7856443c77be666 100644 --- a/tensorflow/compiler/tf2xla/test_util.h +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -44,9 +44,6 @@ Status InstantiateFunctionForTest(const string& name, const FunctionLibraryDefinition& library, InstantiationResultForTest* result); -// Builds a map from node name to Node* for `graph`. 
-std::unordered_map BuildNodeIndex(const Graph& graph); - } // namespace tensorflow // Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index b22d53805d83069052cc5e16020d6c540d618a82..9fac16a9700419b189bf5393c2b8bd7d76c6c1cc 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -218,7 +218,7 @@ Status CollectArgNodes(const Graph& graph, std::vector* arg_nodes) { const Node* dup = insert_result.first->second; return errors::InvalidArgument( "Multiple ", kArgOp, " nodes with index ", index, ", ", - n->DebugString(), " and ", dup->DebugString()); + FormatNodeForError(*n), " and ", FormatNodeForError(*dup)); } } } diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 01dd3ba10fec85e6b1d411fbd32fbf9c58b5fe11..cc81772e8c5da710bc733f7e4f5fe820b2c2d110 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -31,6 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/versions.pb.h" @@ -76,6 +76,222 @@ Status CheckFeedFetchNameConflicts(const string& kind, return Status::OK(); } +// For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to +// `fld`. 
This is to ensure that `fld` can instantiate FunctionDef of graph `g`. +Status CopyAssociatedFunctions(Graph* g, + const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld) { + for (Node* n : g->op_nodes()) { + for (const auto& associated_function : + GetAssociatedFunctions(*n, lookup_fld)) { + switch (associated_function.type()) { + case AssociatedFunctionInfo::kFunctionCallNode: { + const FunctionDef* fdef = + lookup_fld->Find(associated_function.func_name()); + if (!fdef) { + return errors::Internal( + "Cannot find function ", associated_function.func_name(), + " for function call node ", n->DebugString()); + } + TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef)); + break; + } + case AssociatedFunctionInfo::kSymbolicGradient: + case AssociatedFunctionInfo::kFunctionAttr: + break; + } + } + } + return Status::OK(); +} + +// For graph `g`, replaces _Arg nodes whose "index" attribute is in +// `const_input_index_to_node` with Const nodes. +Status ReplaceArgUsageWithConstNode( + Graph* g, + const std::unordered_map& const_input_index_to_node) { + // Collect all _Arg nodes. + std::unordered_map arg_nodes; + for (Node* n : g->op_nodes()) { + if (n->type_string() == FunctionLibraryDefinition::kArgOp) { + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + arg_nodes[index] = n; + } + } + + for (const auto& iter : const_input_index_to_node) { + int arg_index = iter.first; + Node* const_node = g->CopyNode(iter.second); + Node* arg_node = arg_nodes[arg_index]; + + // Collect all usages of the _Arg node. + struct OutEdgeInfo { + int dst_node_id, dst_input; + }; + std::vector usages; + for (const Edge* e : arg_node->out_edges()) { + if (e->IsControlEdge()) { + continue; + } + usages.push_back({e->dst()->id(), e->dst_input()}); + } + + for (int i = 0; i < usages.size(); i++) { + // Make a copy of `usage_node`, and change its input to const node. 
+ Node* usage_node = g->FindNodeId(usages[i].dst_node_id); + NodeDef replace_def = usage_node->def(); + *replace_def.mutable_input(usages[i].dst_input) = const_node->name(); + TF_ASSIGN_OR_RETURN(Node * replace_node, + ReplaceNode(g, usage_node, replace_def)); + const Edge* usage_edge; + TF_RETURN_IF_ERROR( + replace_node->input_edge(usages[i].dst_input, &usage_edge)); + g->RemoveEdge(usage_edge); + g->AddEdge(const_node, 0, replace_node, usages[i].dst_input); + + // Later entries in `usages` might have `usage_node` as dst node, but + // `usage_node` is removed. Replace such entries with `replace_node`. + for (int j = i + 1; j < usages.size(); j++) { + if (usages[j].dst_node_id == usages[i].dst_node_id) { + usages[j].dst_node_id = replace_node->id(); + } + } + } + } + return Status::OK(); +} + +// For a node's function attr (e.g. then/else branch for "If" nodes), rewrites +// the function to replace _Arg nodes in `const_input_index_to_node` with Const +// inputs. +Status PropagateConstIntoFuncAttr( + Node* n, const string& attr_name, + const std::unordered_map& const_input_index_to_node, + const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld) { + // Instantiate the function. + NameAttrList func_attr; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr)); + const FunctionDef* fdef = lookup_fld->Find(func_attr.name()); + if (!fdef) { + return errors::Internal("Cannot find function ", func_attr.name(), + " for node ", n->name()); + } + FunctionBody* fbody; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *fdef, AttrSlice(&func_attr.attr()), lookup_fld, + [lookup_fld](const string& op, const OpDef** sig) { + return lookup_fld->LookUpOpDef(op, sig); + }, + &fbody)); + std::unique_ptr fbody_deleter(fbody); + + // Rewrite _Arg usages with Const node. + Graph* func_graph = fbody->graph; + TF_RETURN_IF_ERROR( + ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node)); + + // Save rewritten function. 
+ FunctionDef replace_fdef; + string new_func_name = + fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_")); + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef)); + TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef)); + + // Change the node to use rewritten function. + func_attr.set_name(new_func_name); + n->ClearAttr(attr_name); + n->AddAttr(attr_name, func_attr); + + // Copy associated functions. + TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld)); + + return Status::OK(); +} + +// For an "If" node in graph `g`, if it has Const node inputs, rewrite its +// then/else branch function to replace _Arg nodes with those Const inputs. +Status PropagateConstIntoIfNode(Graph* g, Node* if_node, + const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld) { + // Notice that first input for If node is predicate; other inputs are function + // inputs. + std::unordered_map const_input_index_to_node; + for (int i = 1; i < if_node->num_inputs(); i++) { + const Node* input_node; + TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node)); + if (input_node->type_string() == "Const") { + const_input_index_to_node[i - 1] = input_node; + } + } + if (const_input_index_to_node.empty()) { + return Status::OK(); + } + + // Rewrite "then_branch" and "else_branch" function, replace usage of those + // _Arg nodes with corresponding const node. + for (const auto& attr_name : + std::vector{"then_branch", "else_branch"}) { + TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr( + if_node, attr_name, const_input_index_to_node, lookup_fld, fld)); + } + + return Status::OK(); +} + +// For a "While" node in graph `g`, if it has Const node inputs, rewrite its +// cond/body function to replace _Arg nodes with those Const inputs. 
+Status PropagateConstIntoWhileNode(Graph* g, Node* while_node, + const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld) { + // For "While" node, we should only replace _Arg nodes which are loop + // invariants. For such _Arg nodes, the return value's input will come + // directly from the corresponding arg. + std::unordered_map const_input_index_to_node; + NameAttrList body_attr; + TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr)); + const FunctionDef* body_func = lookup_fld->Find(body_attr.name()); + if (!body_func) { + return errors::Internal("Cannot find body function ", body_attr.name(), + " for While node ", while_node->name()); + } + for (int i = 0; i < while_node->num_inputs(); i++) { + const Node* input_node; + TF_RETURN_IF_ERROR(while_node->input_node(i, &input_node)); + if (input_node->type_string() != "Const") { + continue; + } + + // Check if i-th retval's input comes from i-th arg directly. + const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i); + auto output_arg_input = body_func->ret().find(output_arg.name()); + if (output_arg_input == body_func->ret().end()) { + return errors::Internal("Cannot find input for output arg ", + output_arg.name(), " in function ", + body_attr.name()); + } + const OpDef_ArgDef& input_arg = body_func->signature().input_arg(i); + if (output_arg_input->second != input_arg.name()) { + continue; + } + + const_input_index_to_node[i] = input_node; + } + if (const_input_index_to_node.empty()) { + return Status::OK(); + } + + // Rewrite "cond" and "body" function, replace usage of those _Arg nodes with + // corresponding const node. 
+ for (const auto& attr_name : std::vector{"cond", "body"}) { + TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr( + while_node, attr_name, const_input_index_to_node, lookup_fld, fld)); + } + return Status::OK(); +} + } // namespace const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation"; @@ -294,7 +510,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { return Status::OK(); } -void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, +void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype, KernelDef* kdef) { for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) { if (constraint.name() == name) { @@ -330,8 +546,8 @@ uint32 GetXLARandomSeed() { // TODO(b/77601805): add tests for associated function related stuff. bool HasAssociatedFunction(const NodeDef& node_def, - FunctionLibraryRuntime* flr) { - if (flr->GetFunctionLibraryDefinition()->Contains(node_def.op())) { + const FunctionLibraryDefinition* fld) { + if (fld->Contains(node_def.op())) { return true; } @@ -351,10 +567,10 @@ bool HasAssociatedFunction(const NodeDef& node_def, } std::vector GetAssociatedFunctions( - const Node& node, FunctionLibraryRuntime* flr) { + const Node& node, const FunctionLibraryDefinition* fld) { std::vector results; const string& op = node.type_string(); - if (flr->GetFunctionLibraryDefinition()->Contains(op)) { + if (fld->Contains(op)) { // This is a function call node. 
AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs)); @@ -441,4 +657,97 @@ Status RewriteAssociatedFunction( return Status::OK(); } +Status CachedFunctionHandles::GetOrInstantiate( + const string& func_name, AttrSlice attrs, + FunctionLibraryRuntime::Handle* handle) { + string canonicalized_name = Canonicalize(func_name, attrs); + auto iter = handles_.find(canonicalized_name); + if (iter != handles_.end()) { + *handle = iter->second; + return Status::OK(); + } + + TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle)); + handles_[canonicalized_name] = *handle; + return Status::OK(); +} + +Status CachedFunctionHandles::ReleaseAllHandles() { + Status result; + for (auto iter : handles_) { + result.Update(flr_->ReleaseHandle(iter.second)); + } + handles_.clear(); + return result; +} + +xla::StatusOr ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) { + // Create the replacement node. + Status s; + Node* new_node = g->AddNode(node_def, &s); + if (!s.ok()) { + return s; + } + + // Record original node's output edges and remove them first. This is to avoid + // multiple producers for dst nodes' input. + std::vector out_edge_info; + std::vector out_edges; + for (const Edge* edge : n->out_edges()) { + out_edges.push_back(edge); + out_edge_info.push_back( + {edge->dst(), edge->src_output(), edge->dst_input()}); + } + for (const Edge* edge : out_edges) { + g->RemoveEdge(edge); + } + + // Add original node's input and output edges to the replacement node. + for (const Edge* in_edge : n->in_edges()) { + g->AddEdge(in_edge->src(), in_edge->src_output(), new_node, + in_edge->dst_input()); + } + for (const OutEdgeInfo& out_edge : out_edge_info) { + g->AddEdge(new_node, out_edge.src_output, out_edge.dst, out_edge.dst_input); + } + + // Remove the original node. 
+ g->RemoveNode(n); + + return new_node; +} + +xla::StatusOr BuildIdentityNode( + Graph* graph, const string& node_name, DataType dtype, const Node* input, + absl::optional requested_device) { + // Create identity node. + NodeDef ndef; + ndef.set_name(node_name); + ndef.set_op("Identity"); + if (input) { + ndef.add_input(input->name()); + } + if (requested_device) { + ndef.set_device(*requested_device); + } + AddNodeAttr("T", dtype, &ndef); + Status s; + Node* id_node = graph->AddNode(ndef, &s); + TF_RETURN_IF_ERROR(s); + return id_node; +} + +Status PropagateConstIntoFunctionalNodes( + Graph* g, const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld) { + for (Node* n : g->op_nodes()) { + if (n->type_string() == "If") { + TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld)); + } else if (n->type_string() == "While") { + TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld)); + } + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 53eab8b63e2fc8aa3dfb0bacfe065897ca775bd0..cf3aa2f847c5ada8897110c7735b207f388f88d4 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/graph.pb.h" @@ -54,7 +55,7 @@ string TensorIdToString(const tf2xla::TensorId& id); Status SetNodeShardingFromNeighbors(Node* n, bool out_edges); // Add an allowed data type to the AttrConstraint with the given name. -void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, +void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype, KernelDef* kdef); // Returns the next random seed to use for seeding xla rng. 
@@ -120,7 +121,7 @@ class AssociatedFunctionInfo { // Returns if the NodeDef has associated function. bool HasAssociatedFunction(const NodeDef& node_def, - FunctionLibraryRuntime* flr); + const FunctionLibraryDefinition* fld); // Gets functions associated with the node. Current cases: // 1. For function call node, its function name; @@ -128,7 +129,7 @@ bool HasAssociatedFunction(const NodeDef& node_def, // and returned attrs will be this node's attributes; // 3. For nodes like XlaWhile/XlaIf, all their function attributes. std::vector GetAssociatedFunctions( - const Node& node, FunctionLibraryRuntime* flr); + const Node& node, const FunctionLibraryDefinition* fld); // Changes associated functions for the node. Current cases: // 1. For function call node, creates a new node with the new function name and @@ -144,6 +145,58 @@ Status RewriteAssociatedFunction( // Attribute to mark nodes to be executed on host. extern const char kXlaOutsideCompilationAttrName[]; +// Class to act as cache for FunctionLibraryRuntime::Handle objects. +class CachedFunctionHandles { + public: + CachedFunctionHandles(FunctionLibraryRuntime* flr) : flr_(flr) {} + + // Populates `handle` for requested function and attributes. If we have + // instantiated the function with the same attributes before, `handle` will be + // cached handle; otherwise instantiate the function and populate `handle`. + Status GetOrInstantiate(const string& func_name, AttrSlice attrs, + FunctionLibraryRuntime::Handle* handle); + + // Releases all handles in the cache. Returns first non-OK status if any; + // returns OK otherwise. + Status ReleaseAllHandles(); + + ~CachedFunctionHandles() { ReleaseAllHandles().IgnoreError(); } + + private: + FunctionLibraryRuntime* flr_; + std::map handles_; + + TF_DISALLOW_COPY_AND_ASSIGN(CachedFunctionHandles); +}; + +// Struct for node's output edge info. 
+struct OutEdgeInfo { + Node* dst; + int src_output, dst_input; +}; + +// Replaces node `n` with a new node whose NodeDef is `node_def`. +xla::StatusOr ReplaceNode(Graph* g, Node* n, const NodeDef& node_def); + +// Helper function that builds an Identity node. +xla::StatusOr BuildIdentityNode(Graph* graph, const string& node_name, + DataType dtype, const Node* input, + absl::optional requested_device); + +// For "If"/"While" nodes, if some of their inputs are Const nodes, rewrite +// body functions to use the Const nodes instead of original _Arg nodes. +// +// For example, say we have the following computation: +// shape = constant_op.constant([1]) +// return tf.cond(pred, lambda: tf.ones(shape), lambda: tf.zeros(shape)) +// If we do not rewrite then/else function, they will use _Arg node as shape +// input for tf.ones/tf.zeros. But XLA requires that shape input to be compile +// time constant, so XLA compilation will fail. This rewriting process will +// change the shape input to Const node. +Status PropagateConstIntoFunctionalNodes( + Graph* g, const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 68441b3d4790b17bd06accff3fcdc8ccee79bbb7..202e929315cacd4d6cdfc69d50639d8a427ec6c2 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -23,11 +23,15 @@ limitations under the License. 
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { @@ -255,5 +259,75 @@ TEST(SetNodeShardingFromNeighbors, Basic) { EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0)); } +REGISTER_OP("One") + .Output("y: T") + .Attr("T: {float, double, int32, int64}") + .Doc(R"doc( +Returns a tensor with a single element (1) of type T. + +y: A scalar in type T. + +)doc"); + +// Tests that CachedFunctionHandles class works. +TEST(CachedFunctionHandles, Basic) { + FunctionDef func = FunctionDefHelper::Define( + // Name + "TestFunc", + // Args + {}, + // Return values + {"y:T"}, + // Attr def + {"T:{float, double, int32, int64}"}, + // Nodes + { + {{"y"}, "One", {}, {{"T", "$T"}}}, + }); + FunctionDefLibrary proto; + *proto.add_function() = func; + FunctionLibraryDefinition fld(OpRegistry::Global(), proto); + std::unique_ptr pflr( + new ProcessFunctionLibraryRuntime( + /*device_mgr=*/nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, &fld, + OptimizerOptions())); + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + CachedFunctionHandles cached_function_handles(flr); + + // Tests that GetOrInstantiate() works. 
+ FunctionLibraryRuntime::Handle first_handle; + AttrValue attr; + attr.set_type(DT_FLOAT); + AttrValueMap attrs; + attrs["T"] = attr; + TF_ASSERT_OK(cached_function_handles.GetOrInstantiate( + "TestFunc", AttrSlice(&attrs), &first_handle)); + + // Tests that we can get FunctionBody. + const FunctionBody* body = flr->GetFunctionBody(first_handle); + EXPECT_NE(body, nullptr); + + // Tests that GetOrInstantiate() returns cached handle when called with same + // function name and attributes. + FunctionLibraryRuntime::Handle second_handle; + TF_ASSERT_OK(cached_function_handles.GetOrInstantiate( + "TestFunc", AttrSlice(&attrs), &second_handle)); + EXPECT_EQ(first_handle, second_handle); + + // Tests that GetOrInstantiate() returns new handle when called with same + // function name but different attributes. + attr.set_type(DT_INT32); + attrs["T"] = attr; + FunctionLibraryRuntime::Handle third_handle; + TF_ASSERT_OK(cached_function_handles.GetOrInstantiate( + "TestFunc", AttrSlice(&attrs), &third_handle)); + EXPECT_NE(first_handle, third_handle); + + // Tests that ReleaseAllHandles() works. 
+ TF_EXPECT_OK(cached_function_handles.ReleaseAllHandles()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 7f860500c75667a920505dbf498e3da4b388fb90..ddb284966eeb97cc7c9d3ed77fb313e567975e59 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -92,7 +92,7 @@ Allocator* XlaCompilationDevice::GetAllocator(AllocatorAttributes attr) { void XlaCompilationDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { VLOG(4) << "XlaCompilationDevice::Compute " - << SummarizeNodeDef(op_kernel->def()); + << FormatNodeDefForError(op_kernel->def()); auto* b = XlaContext::Get(context).builder(); xla::OpMetadata metadata; metadata.set_op_type(op_kernel->type_string()); @@ -124,13 +124,4 @@ Status XlaCompilationDevice::MakeTensorFromProto( "XLACompilationDevice::MakeTensorFromProto should not be called"); } -XlaExpression::XlaExpression() = default; - -void XlaExpression::set_handle(const xla::XlaOp& h) { handle_ = h; } - -void XlaExpression::set_constant_value(Tensor value) { - has_constant_value_ = true; - constant_value_ = std::move(value); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index a6e78825334fec748be5fee80669649df699d2fb..de6a3356e05d8ab45c269d7c6c653853d2c63a79 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -18,9 +18,6 @@ limitations under the License. 
#include -#include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/tensor.h" @@ -38,8 +35,8 @@ class XlaCompilationAllocator; // This is a 'dummy' TensorFlow device that is only used to execute a // subgraph of XLA compilation Ops to construct a compiled version // of the subgraph's computation. It has a 'dummy' allocator that -// backs each Tensor with metadata indicating the computation the -// Tensor represents. +// backs each Tensor with an XlaExpression. The shape of the Tensor +// matches the shape of XlaExpression. // // We deliberately don't register a device factory because we *never* // want placement to put Ops on a compilation device. The device is created @@ -67,40 +64,6 @@ class XlaCompilationDevice : public LocalDevice { std::unique_ptr allocator_; }; -// A XlaExpression wraps an XLA computation. Each Tensor on an -// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor -// matches the shape of the subcomputation in the XlaOp. Each -// expression is either a constant, or a function of previously-compiled -// expressions. -class XlaExpression { - public: - XlaExpression(); - - // handle() stores the XLA handle of the computation that the - // expression represents. - void set_handle(const xla::XlaOp& h); - const xla::XlaOp& handle() const { return handle_; } - - void set_constant_value(Tensor value); - bool has_constant_value() const { return has_constant_value_; } - const Tensor& constant_value() const { return constant_value_; } - - void set_resource(XlaResource* resource) { resource_ = resource; } - XlaResource* resource() const { return resource_; } - - private: - // The XLA handle of the expression's computation. 
- xla::XlaOp handle_; - - // If this expression is a constant with a known value, 'constant_value' is a - // host-memory Tensor containing the value. Used to avoid invoking XLA for - // expressions that are trivially constant. - bool has_constant_value_ = false; - Tensor constant_value_; - - XlaResource* resource_ = nullptr; // Not owned. -}; - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 425e769346ffcbc548495d93cb7adc779f860110..c7341cf8b9e8d7a06fd304ae8766420d20f0c16e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -26,7 +26,7 @@ limitations under the License. // Forward-declare, rather than include, to reduce code size for users that // never use this functionality. namespace xla { -class ProgramShape; +class ProgramShapeProto; class HloProfilePrinterData; } @@ -84,7 +84,7 @@ class XlaCompiledCpuFunction { void set_result_names(const char** result_names) { result_names_ = result_names; } - void set_program_shape(const xla::ProgramShape* program_shape) { + void set_program_shape(const xla::ProgramShapeProto* program_shape) { program_shape_ = program_shape; } const xla::HloProfilePrinterData* hlo_profile_printer_data() const { @@ -122,7 +122,7 @@ class XlaCompiledCpuFunction { const char** result_names_ = nullptr; // [Optional] Arg and result shapes. - const xla::ProgramShape* program_shape_ = nullptr; + const xla::ProgramShapeProto* program_shape_ = nullptr; // [Optional] Profile printer data. Null if profiling is disabled. const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; @@ -206,8 +206,14 @@ class XlaCompiledCpuFunction { // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. 
- void set_arg_data(size_t index, void* data) { - buffer_table_[arg_index_table_[index]] = data; + void set_arg_data(size_t index, const void* data) { + // The const_cast is safe because the generated code does not write to arg + // buffers. + // + // buffer_table_ contains pointers to buffers that _will_ be written to by + // generated code so it would be misleading to make buffer_table_ a `const + // void**`. + buffer_table_[arg_index_table_[index]] = const_cast(data); } // ------------------------------ @@ -264,7 +270,7 @@ class XlaCompiledCpuFunction { // Returns the shape of the args and results. May return nullptr if the // program shape isn't available. - const xla::ProgramShape* ProgramShape() const { return program_shape_; } + const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; } bool hlo_profiling_enabled() const { return hlo_profile_printer_data_ != nullptr; @@ -287,11 +293,6 @@ class XlaCompiledCpuFunction { // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]] // for XLA generated code to be able to find it. - // - // For now we need to keep around the args_ array because there is code that - // depends on args() returning a void**. However, in the future we may remove - // args_ in favor of using buffer_table_ as the sole storage for the - // arguments. const int32* const arg_index_table_; // The number of incoming arguments. @@ -310,7 +311,7 @@ class XlaCompiledCpuFunction { // Optional metadata. 
const char** arg_names_ = nullptr; const char** result_names_ = nullptr; - const xla::ProgramShape* program_shape_ = nullptr; + const xla::ProgramShapeProto* program_shape_ = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index b2c57e88803e0661a9a514f844dff97ff9edf2ea..ee461a3c07d4db514c7697e005a9371be4b54dd0 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -36,10 +36,13 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -48,7 +51,7 @@ namespace { // Checks that arguments `args` match types `types`. Status CheckSignature(const DataTypeVector& types, - const std::vector& args) { + absl::Span args) { if (args.size() != types.size()) { return errors::Internal("Compilation arguments have ", args.size(), " elements while function has ", types.size()); @@ -63,14 +66,270 @@ Status CheckSignature(const DataTypeVector& types, return Status::OK(); } +// Uses the _Arg and _Retval nodes in the graph to determine a core assignment +// for each argument and return value. 
+xla::StatusOr, std::map>> +ComputeArgAndRetvalCores(const Graph& graph) { + auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr { + TF_ASSIGN_OR_RETURN( + auto sharding, + ParseShardingFromDevice(*n, std::numeric_limits::max())); + if (sharding.has_value()) { + TF_RET_CHECK(sharding.value().type() == + xla::OpSharding::Type::OpSharding_Type_MAXIMAL); + return sharding.value().tile_assignment_devices(0); + } else { + return -1; + } + }; + std::map arg_cores; + std::map retval_cores; + for (const Node* n : graph.nodes()) { + if (n->type_string() == FunctionLibraryDefinition::kArgOp) { + TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); + if (core < 0) continue; + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + TF_RET_CHECK(index >= 0) << "Negative _Arg index"; + arg_cores[index] = core; + } else if (n->type_string() == FunctionLibraryDefinition::kRetOp) { + TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); + if (core < 0) continue; + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + TF_RET_CHECK(index >= 0) << "Negative _Retval index"; + TF_ASSIGN_OR_RETURN(retval_cores[index], get_sharding_for_node(n)); + retval_cores[index] = core; + } + } + return std::make_pair(std::move(arg_cores), std::move(retval_cores)); +} + +Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, + XlaCompilationDevice* device, FunctionLibraryRuntime* flib, + int64 step_id) { + // Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the + // resource manager takes ownership via Create, and unrefs via Cleanup. We + // explicitly add a reference to ensure the refcount at entry is maintained at + // all exit points; Create and Cleanup are always called in this function. + // + // The Executor requires us to use ScopedStepContainer. We wrap it in a + // unique_ptr so we can capture the cleanup status in the end. 
+ xla_context->Ref(); + Status status; + auto step_container = absl::make_unique( + step_id, [&status, device](const string& name) { + status = device->resource_manager()->Cleanup(name); + }); + TF_RETURN_IF_ERROR(device->resource_manager()->Create( + step_container->name(), XlaContext::kXlaContextResourceName, + xla_context)); + + GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); + TF_RETURN_IF_ERROR(graph_compiler.Compile()); + // Explicitly clean up the step container, to capture the cleanup status. + step_container.reset(); + return Status::OK(); +} + +// Builds the XLA computation. +// - `args` is the list of input arguments +// - `retvals` is the list of retvals produced by _Retval operators, in index +// order. +// - `args_core` and `retval_cores` are mapping from arg/return indices to core +// assignments. +// - If `return_updated_values_for_all_resources` is true, all resources will be +// included in `resource_updates`, regardless of whether their value changed. +// - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. +// - Sets `*resource_updates` to a description of resources whose values are +// written by the computation; the variable writes are the last +// - `resource_updates.size()` return values from the computation. Each entry in +// `resource_updates` is a ResourceUpdate, whose `index` is the index of a +// resource variable argument to the computation to be updated, and `type` is +// the type of the final output. 
+Status BuildComputation( + const std::vector& args, + const std::vector& retvals, + const std::map& arg_cores, const std::map& retval_cores, + const std::vector>& resources, + std::unique_ptr token_output, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + bool return_updated_values_for_all_resources, bool always_return_tuple, + xla::XlaBuilder* builder, xla::XlaComputation* computation, + int* num_computation_outputs, int* num_nonconst_outputs, + std::vector* outputs, + std::vector* resource_updates, + xla::Shape* output_shape) { + // Attach a common operator name as metadata. This has no semantic effect — it + // merely makes the HLO graph more readable when visualized via TensorBoard, + // since TensorBoard forms groups out of operators with similar names. + xla::OpMetadata retval_metadata; + retval_metadata.set_op_name("XLA_Retvals"); + builder->SetOpMetadata(retval_metadata); + auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); }); + + // Builds a no-op XLA computation. We need to set the sharding of outputs, but + // cannot change the sharding of the existing output op. To do this, we build + // a new identity op to which shardings can be applied. + auto identity_op = [builder](xla::XlaOp op) { + return xla::GetTupleElement(xla::Tuple(builder, {op}), 0); + }; + + std::vector elems; + elems.reserve(retvals.size()); + + // Keeps track of which retvals have layout to update. The first element is + // the output index, second element is the new layout. 
+ std::vector> retval_to_update_layout; + for (int i = 0; i < retvals.size(); ++i) { + XlaCompiler::OutputDescription& output = (*outputs)[i]; + const XlaExpression& retval = retvals[i]; + output.type = retval.dtype(); + switch (retval.kind()) { + case XlaExpression::Kind::kConstant: + output.is_constant = true; + output.constant_value = retval.constant_value(); + output.shape = output.constant_value.shape(); + break; + + case XlaExpression::Kind::kXlaOp: { + output.is_constant = false; + TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape()); + xla::XlaOp value = retval.handle(); + auto it = retval_cores.find(i); + xla::XlaScopedShardingAssignment assign_sharding( + builder, it == retval_cores.end() + ? absl::optional() + : xla::sharding_builder::AssignDevice(it->second)); + if (shape_representation_fn) { + // If there is a shape representation function, reshape the output + // tensor to the shape given by the representation shape function. + TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( + output.shape, output.type)); + value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); + retval_to_update_layout.emplace_back(elems.size(), shape.layout()); + } else if (it != retval_cores.end()) { + // Apply the sharding to the output, if there is a core assignment. + value = identity_op(value); + } + + elems.push_back(value); + break; + } + + case XlaExpression::Kind::kResource: + output.is_constant = false; + output.input_index = retval.resource()->arg_num(); + output.shape = retval.resource()->shape(); + break; + + case XlaExpression::Kind::kInvalid: + return errors::InvalidArgument( + "Invalid expression returned by computation. " + "This probably means a return value was not set."); + } + } + *num_nonconst_outputs = elems.size(); + + // Add return values for resources whose values have changed. 
+ std::vector arg_resources; + arg_resources.reserve(resources.size()); + for (const auto& resource : resources) { + if (resource->arg_num() >= 0) { + arg_resources.push_back(resource.get()); + } + } + std::sort(arg_resources.begin(), arg_resources.end(), + [](const XlaResource* a, const XlaResource* b) { + return a->arg_num() < b->arg_num(); + }); + + for (const XlaResource* resource : arg_resources) { + DCHECK_LT(resource->arg_num(), args.size()); + const XlaCompiler::Argument& arg = args[resource->arg_num()]; + auto it = arg_cores.find(resource->arg_num()); + const int core = it == arg_cores.end() ? -1 : it->second; + bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); + // TensorArray gradients were modified if their values changed or there are + // any newly created gradients. + for (const auto& grad : resource->tensor_array_gradients()) { + modified = + modified || + !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || + arg.tensor_array_gradients.count(grad.first) == 0; + } + if (return_updated_values_for_all_resources || modified) { + resource_updates->emplace_back(); + XlaCompiler::ResourceUpdate& update = resource_updates->back(); + update.input_index = resource->arg_num(); + update.type = resource->type(); + update.shape = resource->shape(); + update.modified = modified; + for (const auto& grad : resource->tensor_array_gradients()) { + update.tensor_array_gradients_accessed.insert(grad.first); + } + + // Request that the value be returned on a specific core. + xla::XlaScopedShardingAssignment assign_sharding( + builder, core == -1 ? absl::optional() + : xla::sharding_builder::AssignDevice(core)); + + xla::XlaOp handle; + TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); + + // Ensures the correct sharding is applied to the output. + handle = identity_op(handle); + + elems.push_back(handle); + } + } + + // If we have token output, append it as the last one. 
+ if (token_output) { + elems.push_back(*token_output); + } + + *num_computation_outputs = elems.size(); + + // Builds the XLA computation. We *always* form a tuple here to ensure that + // the output value is the last thing added into the XLA computation, even + // if there is only one output value. + auto tuple = xla::Tuple(builder, elems); + if (!always_return_tuple && elems.size() == 1) { + xla::GetTupleElement(tuple, 0); + } + + xla::StatusOr computation_status = builder->Build(); + if (!computation_status.ok()) { + return computation_status.status(); + } + *computation = computation_status.ConsumeValueOrDie(); + + TF_ASSIGN_OR_RETURN(const auto& program_shape, + computation->GetProgramShape()); + *output_shape = program_shape.result(); + // Update the output layout to the layout of retval. + for (auto& update : retval_to_update_layout) { + if (!always_return_tuple && elems.size() == 1) { + *output_shape->mutable_layout() = update.second; + continue; + } + + xla::Shape* output_sub_shape = + xla::ShapeUtil::GetMutableSubshape(output_shape, {update.first}); + *output_sub_shape->mutable_layout() = update.second; + } + return Status::OK(); +} + } // namespace bool XlaCompiler::Argument::operator==( const XlaCompiler::Argument& other) const { - if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size, + if (std::tie(kind, resource_kind, type, name, initialized, max_array_size, tensor_array_gradients) != std::tie(other.kind, other.resource_kind, other.type, other.name, - other.initialized, other.tensor_array_size, + other.initialized, other.max_array_size, other.tensor_array_gradients)) { return false; } @@ -83,12 +342,45 @@ bool XlaCompiler::Argument::operator==( return constant_value.tensor_data() == other.constant_value.tensor_data(); } +string XlaCompiler::Argument::HumanString() const { + string common; + if (!name.empty()) { + common = absl::StrCat(" name=", name); + } + absl::StrAppend(&common, " type=", DataTypeString(type), + " shape=", 
shape.DebugString()); + switch (kind) { + case kInvalid: + return "invalid"; + case kConstant: + return absl::StrCat("kind=constant", common, + " value=", constant_value.DebugString()); + case kResource: { + string output = absl::StrCat("kind=resource", common, " resource_kind=", + XlaResource::KindToString(resource_kind), + " initialized=", initialized); + if (max_array_size >= 0) { + absl::StrAppend(&output, " max_array_size=", max_array_size); + } + if (!tensor_array_gradients.empty()) { + absl::StrAppend(&output, " tensor_array_gradients=", + absl::StrJoin(tensor_array_gradients, ",")); + } + return output; + } + case kParameter: + return absl::StrCat("kind=parameter", common); + case kToken: + return absl::StrCat("token", common); + } +} + XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), next_step_id_(1), device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), - device_mgr_({device_}) { + device_mgr_(absl::WrapUnique(device_)) { CHECK(!options_.device_type.type_string().empty()); if (options_.populate_resource_manager) { initialization_status_ = @@ -110,8 +402,13 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) // The default shape representation function is the identity. 
if (!options_.shape_representation_fn) { - options_.shape_representation_fn = [](const TensorShape& shape, - DataType type) { return shape; }; + options_.shape_representation_fn = + [](const TensorShape& shape, + DataType dtype) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + }; } } @@ -171,15 +468,16 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { return graph; } -Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, - const NameAttrList& function, - std::vector args, - XlaCompiler::CompilationResult* result) { +Status XlaCompiler::CompileFunction( + const XlaCompiler::CompileOptions& options, const NameAttrList& function, + absl::Span args, + XlaCompiler::CompilationResult* result) { const string function_id = Canonicalize(function.name(), AttrSlice(&function.attr())); VLOG(1) << "XlaCompiler::CompileFunction " << function_id; - auto it = cache_.find({function_id, args}); + const std::vector arg_vector(args.begin(), args.end()); + auto it = cache_.find({function_id, arg_vector}); if (it != cache_.end()) { *result = it->second; return Status::OK(); @@ -212,14 +510,16 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // lowest-numbered core that consumes the argument. We choose the // lowest-numbered core so the assignment is deterministic. for (Node* n : graph->nodes()) { - if (absl::string_view(n->type_string()) == "_Arg") { + if (absl::string_view(n->type_string()) == + FunctionLibraryDefinition::kArgOp) { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); } } // Do _Retval as a second loop, in case the retval's input is an _Arg (which // may have gotten a device assignment from the first loop). 
for (Node* n : graph->nodes()) { - if (absl::string_view(n->type_string()) == "_Retval") { + if (absl::string_view(n->type_string()) == + FunctionLibraryDefinition::kRetOp) { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); } } @@ -235,7 +535,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, CompileGraph(options, function_id, std::move(graph), args, result)); VLOG(1) << "===================================================="; - cache_[{function_id, args}] = *result; + cache_[{function_id, arg_vector}] = *result; return Status::OK(); } @@ -247,33 +547,32 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, case XlaCompiler::Argument::kConstant: LOG(FATAL) << "Unreachable case"; case XlaCompiler::Argument::kParameter: { - TensorShape shape; if (is_entry_computation) { TF_ASSIGN_OR_RETURN( - shape, options_.shape_representation_fn(arg.shape, arg.type)); + *xla_shape, options_.shape_representation_fn(arg.shape, arg.type)); } else { - shape = arg.shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(arg.type, arg.shape, xla_shape)); } - return TensorShapeToXLAShape(arg.type, shape, xla_shape); + return Status::OK(); } case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); switch (arg.resource_kind) { case XlaResource::kVariable: { - TF_ASSIGN_OR_RETURN( - TensorShape representation_shape, - options_.shape_representation_fn(arg.shape, arg.type)); - return TensorShapeToXLAShape(arg.type, representation_shape, - xla_shape); + TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn( + arg.shape, arg.type)); + + return Status::OK(); } case XlaResource::kTensorArray: { - if (arg.tensor_array_size < 0) { + if (arg.max_array_size < 0) { return errors::InvalidArgument( - "Negative tensor_array_size in XLAShapeForArgument"); + "Negative max_array_size in XLAShapeForArgument"); } TensorShape shape; - shape.AddDim(arg.tensor_array_size); + shape.AddDim(arg.max_array_size); 
shape.AppendShape(arg.shape); TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape)); @@ -285,12 +584,12 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, return Status::OK(); } case XlaResource::kStack: { - if (arg.tensor_array_size < 0) { + if (arg.max_array_size < 0) { return errors::InvalidArgument( - "Negative tensor_array_size in XLAShapeForArgument"); + "Negative max_array_size in XLAShapeForArgument"); } TensorShape shape; - shape.AddDim(arg.tensor_array_size); + shape.AddDim(arg.max_array_size); shape.AppendShape(arg.shape); xla::Shape buffer_shape; TF_RETURN_IF_ERROR( @@ -314,169 +613,16 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, } } -namespace { - -Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, - XlaCompilationDevice* device, FunctionLibraryRuntime* flib, - int64 step_id) { - // Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the - // resource manager takes ownership via Create, and unrefs via Cleanup. We - // explicitly add a reference to ensure the refcount at entry is maintained at - // all exit points; Create and Cleanup are always called in this function. - // - // The Executor requires us to use ScopedStepContainer. We wrap it in a - // unique_ptr so we can capture the cleanup status in the end. - xla_context->Ref(); - Status status; - auto step_container = absl::make_unique( - step_id, [&status, device](const string& name) { - status = device->resource_manager()->Cleanup(name); - }); - TF_RETURN_IF_ERROR(device->resource_manager()->Create( - step_container->name(), XlaContext::kXlaContextResourceName, - xla_context)); - - GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); - TF_RETURN_IF_ERROR(graph_compiler.Compile()); - // Explicitly clean up the step container, to capture the cleanup status. - step_container.reset(); - return Status::OK(); -} - -// Builds the XLA computation. 
-// `args` is the list of input arguments, `retvals` is the list of retvals -// produced by _Retval operators, in index order. -// If `return_updated_values_for_all_resources` is true, all resources will be -// included in `resource_updates`, regardless of whether their value changed. -// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. -// Sets `*resource_updates` to a description of resources whose values are -// written by the computation; the variable writes are the last -// `resource_updates.size()` return values from the computation. Each entry in -// `resource_updates` is a (input_index, type) pair, where `input_index` is the -// index of a resource variable argument to the computation, and `type` is the -// type of the final output. -Status BuildComputation( - const std::vector& args, - const std::vector& arg_cores, - const std::vector& retvals, - const std::vector>& resources, - bool return_updated_values_for_all_resources, bool always_return_tuple, - xla::XlaBuilder* builder, xla::XlaComputation* computation, - int* num_computation_outputs, int* num_nonconst_outputs, - std::vector* outputs, - std::vector* resource_updates) { - std::vector elems; - elems.reserve(retvals.size()); - for (int i = 0; i < retvals.size(); ++i) { - XlaCompiler::OutputDescription& output = (*outputs)[i]; - output.type = retvals[i].type; - output.shape = retvals[i].shape; - const XlaExpression& retval = retvals[i].expression; - if (retval.has_constant_value()) { - output.is_constant = true; - output.constant_value = retval.constant_value(); - } else if (retval.resource() != nullptr) { - output.is_constant = false; - output.input_index = retval.resource()->arg_num(); - } else { - output.is_constant = false; - elems.push_back(retval.handle()); - } - } - *num_nonconst_outputs = elems.size(); - - // Add return values for resources whose values have changed. 
- std::vector arg_resources; - arg_resources.reserve(resources.size()); - for (const auto& resource : resources) { - if (resource->arg_num() >= 0) { - arg_resources.push_back(resource.get()); - } - } - std::sort(arg_resources.begin(), arg_resources.end(), - [](const XlaResource* a, const XlaResource* b) { - return a->arg_num() < b->arg_num(); - }); - - // Attach a common operator name as metadata. This has no semantic effect — it - // merely makes the HLO graph more readable when visualized via TensorBoard, - // since TensorBoard forms groups out of operators with similar names. - xla::OpMetadata retval_metadata; - retval_metadata.set_op_name("XLA_Retvals"); - builder->SetOpMetadata(retval_metadata); - - for (const XlaResource* resource : arg_resources) { - const XlaCompiler::Argument& arg = args[resource->arg_num()]; - const int core = arg_cores[resource->arg_num()]; - DCHECK_LT(resource->arg_num(), arg_cores.size()); - bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); - // TensorArray gradients were modified if their values changed or there are - // any newly created gradients. - for (const auto& grad : resource->tensor_array_gradients()) { - modified = - modified || - !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || - arg.tensor_array_gradients.count(grad.first) == 0; - } - if (return_updated_values_for_all_resources || modified) { - resource_updates->emplace_back(); - XlaCompiler::ResourceUpdate& update = resource_updates->back(); - update.input_index = resource->arg_num(); - update.type = resource->type(); - update.shape = resource->shape(); - update.modified = modified; - for (const auto& grad : resource->tensor_array_gradients()) { - update.tensor_array_gradients_accessed.insert(grad.first); - } - - // Request that the value be returned on a specific core. - xla::XlaScopedShardingAssignment assign_sharding( - builder, core == -1 ? 
absl::optional() - : xla::sharding_builder::AssignDevice(core)); - - xla::XlaOp handle; - TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); - - // Since we can't change the sharding metadata of as this point, - // create a tuple/get-tuple-element combination so that sharding - // assignment will be placed on this value, which will cause the resource - // update to be returned from the same device that provided the resource. - handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0); - elems.push_back(handle); - } - } - - *num_computation_outputs = elems.size(); - - // Builds the XLA computation. We *always* form a tuple here to ensure that - // the output value is the last thing added into the XLA computation, even - // if there is only one output value. - auto tuple = xla::Tuple(builder, elems); - if (!always_return_tuple && elems.size() == 1) { - xla::GetTupleElement(tuple, 0); - } - builder->ClearOpMetadata(); - - xla::StatusOr computation_status = builder->Build(); - if (!computation_status.ok()) { - return computation_status.status(); - } - *computation = computation_status.ConsumeValueOrDie(); - return Status::OK(); -} - -} // namespace - // Builds XLA computations for each of the arguments to the computation. // `args` are the arguments to the computation. Status XlaCompiler::BuildArguments( const Graph& graph, const std::vector& args, bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, - std::vector* arg_cores, std::vector* arg_expressions, + const std::map& arg_cores, + std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes, bool is_entry_computation) { arg_expressions->resize(args.size()); - *arg_cores = std::vector(args.size(), -1); // Argument numbers of arguments and resources that are to be passed to the // XLA computation as runtime parameters. 
@@ -489,28 +635,30 @@ Status XlaCompiler::BuildArguments( const XlaCompiler::Argument& arg = args[i]; XlaExpression& arg_expression = (*arg_expressions)[i]; switch (arg.kind) { - case XlaCompiler::Argument::kResource: + case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid); // TODO(phawkins): this code assumes that resource arguments do not // alias. - XlaResource* resource; - TF_RETURN_IF_ERROR(context->CreateResource( - arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(), - /*tensor_array_size=*/arg.tensor_array_size, - /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); - arg_expression.set_resource(resource); + XlaResource* resource = + context->AddResource(absl::make_unique( + arg.resource_kind, i, arg.name, arg.type, arg.shape, + xla::XlaOp(), + /*max_array_size=*/arg.max_array_size, + /*tensor_array_gradients=*/arg.tensor_array_gradients, + /*tensor_array_multiple_writes_aggregate=*/true)); + arg_expression = XlaExpression::Resource(resource); if (arg.initialized) { input_mapping->push_back(i); } - break; + } case XlaCompiler::Argument::kParameter: case XlaCompiler::Argument::kToken: { input_mapping->push_back(i); break; } case XlaCompiler::Argument::kConstant: - arg_expression.set_constant_value(arg.constant_value); + arg_expression = XlaExpression::Constant(arg.constant_value); break; case XlaCompiler::Argument::kInvalid: return errors::Internal( @@ -535,26 +683,6 @@ Status XlaCompiler::BuildArguments( *input_shapes = arg_shapes; } - // Use the _Arg nodes in the graph to resolve core assignments. 
- for (const Node* n : graph.nodes()) { - if (absl::string_view(n->type_string()) != "_Arg") continue; - int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); - TF_RET_CHECK(index >= 0 && index < args.size()) - << "_Arg out of bounds: " << index << " vs " << args.size(); - TF_ASSIGN_OR_RETURN( - auto sharding, - ParseShardingFromDevice(*n, std::numeric_limits::max())); - if (sharding.has_value()) { - TF_RET_CHECK(sharding.value().type() == - xla::OpSharding::Type::OpSharding_Type_MAXIMAL); - const int core = sharding.value().tile_assignment_devices(0); - if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) { - (*arg_cores)[index] = core; - } - } - } - // Attach a common operator name as metadata. This has no semantic effect — it // merely makes the HLO graph more readable when visualized via TensorBoard, // since TensorBoard forms groups out of operators with similar names. @@ -570,11 +698,10 @@ Status XlaCompiler::BuildArguments( xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); for (int64 parameter : *input_mapping) { - const int core = (*arg_cores)[parameter]; - const int root_device = 0; + auto it = arg_cores.find(parameter); + const int core = it == arg_cores.end() ? 0 : it->second; *tuple_sharding.add_tuple_shardings() = - core == -1 ? xla::sharding_builder::AssignDevice(root_device) - : xla::sharding_builder::AssignDevice(core); + xla::sharding_builder::AssignDevice(core); } xla::XlaScopedShardingAssignment assign_tuple_sharding(builder, tuple_sharding); @@ -583,7 +710,8 @@ Status XlaCompiler::BuildArguments( tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { - const int core = (*arg_cores)[input_mapping->at(i)]; + auto it = arg_cores.find(i); + const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? 
absl::optional() : xla::sharding_builder::AssignDevice(core)); @@ -591,7 +719,8 @@ Status XlaCompiler::BuildArguments( } } else { for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { - const int core = (*arg_cores)[input_mapping->at(i)]; + auto it = arg_cores.find(i); + const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); @@ -626,14 +755,14 @@ Status XlaCompiler::BuildArguments( // TODO(b/76097077): propagate device assignments onto arguments and // return values of functions, and then reshape unconditionally. if (is_entry_computation) { - arg_expression.set_handle( - xla::Reshape(arg_handles[i], arg.shape.dim_sizes())); + arg_expression = XlaExpression::XlaOp( + xla::Reshape(arg_handles[i], arg.shape.dim_sizes()), arg.type); } else { - arg_expression.set_handle(arg_handles[i]); + arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); } break; case XlaCompiler::Argument::kToken: { - arg_expression.set_handle(arg_handles[i]); + arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); break; } case XlaCompiler::Argument::kConstant: @@ -647,46 +776,48 @@ Status XlaCompiler::BuildArguments( } Status XlaCompiler::CompileSingleOp( - const XlaCompiler::CompileOptions& options, string const& name, - OpKernelContext* ctx, const std::vector& args, - CompilationResult* result) { + const XlaCompiler::CompileOptions& options, const NodeDef& node_def, + absl::Span args, + absl::Span result_types, CompilationResult* result) { // TODO(b/74182462): We implement this by creating a new dummy Graph including // _Arg nodes, and let CompileGraph walk it. This could be optimized. std::unique_ptr graph(new Graph(OpRegistry::Global())); Status status; // First create the actual node we care about computing. 
- Node* main_node = graph->AddNode(ctx->op_kernel().def(), &status); + Node* main_node = graph->AddNode(node_def, &status); TF_RETURN_IF_ERROR(status); // Create dummy _Arg nodes. Link these to `node` and also via a control // dependency edge to the _SOURCE node. - for (int64 i = 0; i < ctx->num_inputs(); ++i) { + for (int64 i = 0; i < args.size(); ++i) { Node* node; - string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); - Status status = NodeBuilder(name, "_Arg") - .ControlInput(graph->source_node()) - .Attr("T", ctx->input_dtype(i)) - .Attr("index", i) - .Finalize(graph.get(), &node); + string arg_name = absl::StrCat("_arg", i); + Status status = + NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp) + .ControlInput(graph->source_node()) + .Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE + : args[i].type) + .Attr("index", i) + .Finalize(graph.get(), &node); TF_RETURN_IF_ERROR(status); graph->AddEdge(node, 0, main_node, i); } // Similarly with return values, create dummy _Retval nodes fed by `node`. - for (int64 i = 0; i < ctx->num_outputs(); ++i) { + for (int64 i = 0; i < result_types.size(); ++i) { Node* node; - string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); - Status status = NodeBuilder(name, "_Retval") + string retval_name = absl::StrCat("_retval", i); + Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp) .Input(main_node, i) - .Attr("T", ctx->expected_output_dtype(i)) + .Attr("T", result_types[i]) .Attr("index", i) .Finalize(graph.get(), &node); TF_RETURN_IF_ERROR(status); } FixupSourceAndSinkEdges(graph.get()); - return CompileGraph(options, name, std::move(graph), args, result); + return CompileGraph(options, node_def.name(), std::move(graph), args, result); } namespace { @@ -741,15 +872,43 @@ Status ValidateGraph(const Graph* graph, return Status::OK(); } +// Converts the value of any expressions whose values are known at compile-time +// to constants. 
+Status ResolveConstantExpressionsToConstants( + xla::Client* client, absl::Span expressions) { + for (XlaExpression& expression : expressions) { + if (expression.kind() == XlaExpression::Kind::kXlaOp) { + TF_ASSIGN_OR_RETURN(absl::optional constant, + expression.ResolveConstant(client)); + if (constant.has_value()) { + expression = XlaExpression::Constant(*constant); + } + } + } + return Status::OK(); +} + +void ConvertConstantsToExpressions(xla::XlaBuilder* builder, + absl::Span expressions) { + for (XlaExpression& expression : expressions) { + if (expression.kind() == XlaExpression::Kind::kConstant) { + expression = + XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype()); + } + } +} + } // namespace Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, - const std::vector& args, + absl::Span args, CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate XlaBuilder."; + TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes( + graph.get(), options_.flib_def, local_flib_def_.get())); if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " << dump_graph::DumpGraphToFile( @@ -766,14 +925,12 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options_.device_type, name)); xla::XlaBuilder builder(name); - XlaContext* context = new XlaContext( - this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants, options.is_entry_computation, - &options_.shape_representation_fn); + XlaContext* context = new XlaContext(this, &builder); core::ScopedUnref context_unref(context); - std::vector real_args(args); + std::vector real_args(args.begin(), args.end()); int token_input_index = -1; + std::unique_ptr token_output; if (options.add_token_input_output) { // Add extra token input. 
token_input_index = real_args.size(); @@ -783,10 +940,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, real_args.push_back(token_arg); } + std::map arg_cores; + std::map retval_cores; + TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores), + ComputeArgAndRetvalCores(*graph)); + std::vector arg_expressions; - std::vector arg_cores; TF_RETURN_IF_ERROR(BuildArguments( - *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores, + *graph, real_args, options.use_tuple_arg, &builder, context, arg_cores, &arg_expressions, &result->input_mapping, &result->xla_input_shapes, options.is_entry_computation)); context->set_args(std::move(arg_expressions)); @@ -826,8 +987,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, TF_RETURN_IF_ERROR(token_or.status()); token_inputs.push_back(token_or.ValueOrDie()); } - TF_RETURN_IF_ERROR( - context->AppendTokenRetval(xla::AfterAll(&builder, token_inputs))); + token_output.reset(new xla::XlaOp(xla::AfterAll(&builder, token_inputs))); } TF_RETURN_IF_ERROR(PopNodeTokenMapping()); @@ -835,28 +995,27 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_computation_outputs; result->computation = std::make_shared(); result->outputs.resize(context->retvals().size()); + std::vector retvals = context->retvals(); + if (options.resolve_compile_time_constants) { + TF_RETURN_IF_ERROR(ResolveConstantExpressionsToConstants( + client(), absl::Span(retvals))); + } else { + ConvertConstantsToExpressions(&builder, absl::Span(retvals)); + } TF_RETURN_IF_ERROR(BuildComputation( - real_args, arg_cores, context->retvals(), context->resources(), + real_args, retvals, arg_cores, retval_cores, context->resources(), + std::move(token_output), + options.is_entry_computation ? 
options_.shape_representation_fn + : ShapeRepresentationFn{}, options.return_updated_values_for_all_resources, options.always_return_tuple, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, - &result->resource_updates)); + &result->resource_updates, &result->xla_output_shape)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; - - // Compute the XLA output shape, if there is a computation with non-constant - // outputs. - TF_ASSIGN_OR_RETURN(std::unique_ptr computation_shape, - client()->GetComputationShape(*result->computation)); - - result->xla_output_shape.Swap(computation_shape->mutable_result()); VLOG(2) << "XLA output shape: " - << xla::ShapeUtil::HumanString(result->xla_output_shape); - - // Tensorflow expects a major-to-minor order of results. - xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); - + << xla::ShapeUtil::HumanStringWithLayout(result->xla_output_shape); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 2cc603a58016a509fafdf6f95423dd6c0864cce3..0d801b73a8c2651305328384377751254ecaa41d 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -18,10 +18,13 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device.h" @@ -118,7 +121,7 @@ class XlaCompiler { // The type of the argument. 
If the argument is a resource, this // is the type of the variable's value, not DT_RESOURCE. - DataType type; + DataType type = DT_INVALID; // The shape of the argument. For: // * a parameter: the shape of the parameter. @@ -147,7 +150,7 @@ class XlaCompiler { // For a TensorArray or Stack resource, what is the array's declared size? // (Used for lazy initialization.) - int64 tensor_array_size = -1; + int64 max_array_size = -1; // TensorArray resource parameters are passed as (array, gradient array 0, // ..., gradient array k), where the gradient arrays are in the same order @@ -155,6 +158,9 @@ class XlaCompiler { std::set tensor_array_gradients; bool operator==(const Argument& other) const; + + // Returns a human-readable summary of the argument. + string HumanString() const; }; // Options pertaining to an individual call to CompileGraph() or @@ -259,8 +265,7 @@ class XlaCompiler { std::shared_ptr computation; }; - typedef std::function(const TensorShape&, - DataType)> + typedef std::function(const TensorShape&, DataType)> ShapeRepresentationFn; struct Options { // Name of the compilation device to use. It must be set by the caller. @@ -316,22 +321,23 @@ class XlaCompiler { Status CompileFunction(const CompileOptions& options, const NameAttrList& fn_name_attrs, - std::vector args, CompilationResult* result); + absl::Span args, + CompilationResult* result); // Compiles a tensorflow::Graph into an xla::XlaComputation. // Similar to CompileFunction, but takes a Graph as input rather than a // function. Status CompileGraph(const CompileOptions& options, string const& name, std::unique_ptr graph, - const std::vector& args, + absl::Span args, CompilationResult* result); - // Compiles a single Op, given by an OpKernelContext, into an + // Compiles a single Op, given by `node_def`, into an // xla::XlaComputation. Similar to CompileFunction but takes a single Op as // input. 
- Status CompileSingleOp(const CompileOptions& options, string const& name, - OpKernelContext* ctx, - const std::vector& args, + Status CompileSingleOp(const CompileOptions& options, const NodeDef& node_def, + absl::Span args, + absl::Span result_types, CompilationResult* result); // Returns the shape of the XLA parameter for an argument 'arg'. @@ -411,7 +417,8 @@ class XlaCompiler { Status BuildArguments(const Graph& graph, const std::vector& args, bool use_tuple_arg, xla::XlaBuilder* builder, - XlaContext* context, std::vector* arg_cores, + XlaContext* context, + const std::map& arg_cores, std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes, diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 72b17d04fc42eb00781e96b412465b73fb29a5c2..fe2a5f5b0c9ea6b5f2bb71df836fdcabf9a0cf23 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -20,7 +20,9 @@ limitations under the License. 
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/side_effect_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -354,8 +356,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { EXPECT_TRUE( absl::StrContains(status.error_message(), "depends on a parameter")) << status.error_message(); - EXPECT_TRUE( - absl::StrContains(status.error_message(), "[[{{node C}} = Reshape")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node C}}")) + << status.error_message(); + EXPECT_TRUE(absl::StrContains(status.error_message(), + "must be a compile-time constant")) << status.error_message(); } @@ -646,7 +650,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { args[0].initialized = true; args[0].type = DT_INT32; args[0].shape = TensorShape({}); - args[0].tensor_array_size = 2; + args[0].max_array_size = 2; args[0].tensor_array_gradients = {"grad2"}; // Compiles the graph. @@ -705,7 +709,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) { args[0].initialized = true; args[0].type = DT_INT32; args[0].shape = TensorShape({}); - args[0].tensor_array_size = 2; + args[0].max_array_size = 2; args[0].tensor_array_gradients = {"grad1"}; // Compiles the graph. @@ -737,7 +741,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { args[0].initialized = true; args[0].type = DT_INT32; args[0].shape = TensorShape({}); - args[0].tensor_array_size = 2; + args[0].max_array_size = 2; args[0].tensor_array_gradients = {"grad1"}; // Compiles the graph. 
@@ -907,6 +911,82 @@ TEST_F(XlaCompilerTest, Variables) { RunAndCheckVariablesComputation(client_, result); } +TEST_F(XlaCompilerTest, ResultLayoutSingle) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Retval(scope.WithOpName("RET"), a, 0); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + + auto options = DefaultOptions(); + // Sets the representation function to return a non-default layout. + options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); + *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); + return xla_shape; + }; + + // Compiles the graph. + XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + auto compile_options = XlaCompiler::CompileOptions(); + compile_options.always_return_tuple = false; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "id", std::move(graph), + args, &result)); + EXPECT_TRUE(xla::ShapeUtil::Equal( + result.xla_output_shape, + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}))); +} + +TEST_F(XlaCompilerTest, ResultLayoutMultiple) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Retval(scope.WithOpName("RET1"), a, 0); + auto c = ops::_Retval(scope.WithOpName("RET2"), a, 1); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. 
+ std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 3}); + + auto options = DefaultOptions(); + // Sets the representation function to return a non-default layout. + options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); + *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); + return xla_shape; + }; + + // Compiles the graph. + XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "id", + std::move(graph), args, &result)); + xla::Shape result_shape = + xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}); + + EXPECT_TRUE(xla::ShapeUtil::Equal( + result.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({result_shape, result_shape}))); +} + // Tests a simple graph that reads and writes a variable. TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { Scope scope = Scope::NewRootScope().ExitOnError(); @@ -1016,9 +1096,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Compiles the graph. XlaCompiler::Options options = DefaultOptions(); - options.shape_representation_fn = [](const TensorShape& shape, - DataType type) { - return TensorShape({shape.num_elements()}); + options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::PrimitiveType ptype; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype)); + return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()}); }; XlaCompiler compiler(options); @@ -1084,9 +1166,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { // Compiles the graph. 
XlaCompiler::Options options = DefaultOptions(); - options.shape_representation_fn = [](const TensorShape& shape, - DataType type) { - return TensorShape({shape.num_elements()}); + options.shape_representation_fn = + [](const TensorShape& shape, DataType type) -> xla::StatusOr { + xla::PrimitiveType ptype; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype)); + return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()}); }; XlaCompiler compiler(options); @@ -1256,23 +1340,30 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) { TF_ASSERT_OK(status); EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get())); - const std::vector empty_args; + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kResource; + args[0].resource_kind = XlaResource::kVariable; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 2}); + { // The case for entry computation: we don't add token input/output. Instead, // we use CreateToken HLO to create the entry token. XlaCompiler::CompileOptions options; options.is_entry_computation = true; options.add_token_input_output = false; + options.return_updated_values_for_all_resources = true; XlaCompiler compiler(DefaultOptions()); std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), - empty_args, &result)); - EXPECT_EQ(result.xla_input_shapes.size(), 0); + args, &result)); + EXPECT_EQ(result.xla_input_shapes.size(), 1); EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); - EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 0); + EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1); } { // The case for non-entry computation (e.g. while loop body). 
We add token @@ -1280,19 +1371,20 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) { XlaCompiler::CompileOptions options; options.is_entry_computation = false; options.add_token_input_output = true; + options.return_updated_values_for_all_resources = true; XlaCompiler compiler(DefaultOptions()); std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), - empty_args, &result)); - EXPECT_EQ(result.xla_input_shapes.size(), 1); - EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[0])); + args, &result)); + EXPECT_EQ(result.xla_input_shapes.size(), 2); + EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[1])); EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); - EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1); + EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 2); EXPECT_TRUE(xla::ShapeUtil::IsToken( - xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 0))); + xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 1))); } } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index f247570d72c0287a33695de3d778cce2a2418921..a69af70503376b6c0905deb8980abdc3254a6e47 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -54,98 +54,25 @@ const char XlaContext::kXlaContextResourceName[] = "_xla_context"; return *context; } -/* static */ XlaContext& XlaContext::Get(const XlaOpKernelContext* ctx) { - return Get(ctx->op_kernel_context()); -} - void XlaContext::set_args(std::vector args) { args_ = std::move(args); } -XlaContext::XlaContext( - XlaCompiler* compiler, xla::XlaBuilder* builder, - bool allow_cpu_custom_calls, bool resolve_compile_time_constants, - bool is_entry_computation, - const std::function( - const TensorShape&, DataType)>* 
shape_representation_fn) - : compiler_(compiler), - builder_(builder), - allow_cpu_custom_calls_(allow_cpu_custom_calls), - resolve_compile_time_constants_(resolve_compile_time_constants), - is_entry_computation_(is_entry_computation), - shape_representation_fn_(shape_representation_fn) {} - -string XlaContext::DebugString() { return "TLA JIT context"; } - -// This is called by the Retval Op to associate a computed value -// with a specific return value of the subgraph. -void XlaContext::AddRetval(int retval_index, DataType type, - const TensorShape& shape, const xla::XlaOp& handle) { - VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; - // Add the return value to the list being built up. - if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); - } - XlaExpression e; - e.set_handle(handle); - retvals_[retval_index] = Retval{type, shape, e}; -} +XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder) + : compiler_(compiler), builder_(builder) {} -Status XlaContext::AddConstRetval(int retval_index, DataType dtype, - const xla::LiteralSlice& literal) { - VLOG(1) << "Adding retval index " << retval_index - << " with non-data-dependent tensor to XLA computation"; - if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); - } - Tensor value; - TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); - XlaExpression e; - e.set_constant_value(value); - retvals_[retval_index] = Retval{dtype, value.shape(), e}; - return Status::OK(); -} +string XlaContext::DebugString() { return "XLA JIT context"; } -Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) { - VLOG(1) << "Adding retval index " << retval_index << " with resource " - << resource->name() << ":" << resource->shape().DebugString() - << " to XLA computation"; - if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); +void XlaContext::SetRetval(int index, const XlaExpression& expression) 
{ + if (retvals_.size() <= index) { + retvals_.resize(index + 1); } - XlaExpression e; - e.set_resource(resource); - retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e}; - return Status::OK(); -} - -Status XlaContext::AppendTokenRetval(const xla::XlaOp& token) { - VLOG(1) << "Adding retval index " << retvals_.size() - << " with token to XLA computation"; - XlaExpression e; - e.set_handle(token); - // We use DT_INVALID because there is no TF DataType which corresponds to XLA - // token. XlaCompiler handles this case separately, so putting it here is OK. - retvals_.push_back(Retval{DT_INVALID, TensorShape(), e}); - return Status::OK(); -} - -xla::XlaBuilder* XlaContext::builder() { return builder_; } - -Status XlaContext::CreateResource( - XlaResource::Kind kind, int arg_num, string name, DataType type, - TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size, - const std::set& tensor_array_gradients, XlaResource** resource) { - resources_.emplace_back( - new XlaResource(kind, arg_num, std::move(name), type, std::move(shape), - handle, tensor_array_size, tensor_array_gradients)); - *resource = resources_.back().get(); - return Status::OK(); + retvals_[index] = expression; } -xla::StatusOr XlaContext::RepresentationShape( - const TensorShape& shape, DataType type) const { - return (*shape_representation_fn_)(shape, type); +XlaResource* XlaContext::AddResource(std::unique_ptr resource) { + resources_.push_back(std::move(resource)); + return resources_.back().get(); } const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index d7dbdc957f0e7969db5098b815381866cdc71ab6..0767d1faac14cedb8666f6cc37175eb7b55f6158 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -20,8 +20,8 @@ limitations under the License. 
#include -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -41,15 +41,10 @@ class XlaContext : public ResourceBase { public: // Retrieves the XlaContext of the current compilation. static XlaContext& Get(const OpKernelContext* ctx); - static XlaContext& Get(const XlaOpKernelContext* ctx); // Creates a new XlaContext. See the documentation on the class data fields // for descriptions of the arguments. - XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, - bool allow_cpu_custom_calls, bool resolve_compile_time_constants, - bool is_entry_computation, - const std::function( - const TensorShape&, DataType)>* shape_representation_fn); + XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder); // Virtual method defined by ResourceBase. string DebugString() override; @@ -57,60 +52,25 @@ class XlaContext : public ResourceBase { XlaCompiler* compiler() const { return compiler_; } // Returns the XlaBuilder that Ops use for compiling new expressions. - xla::XlaBuilder* builder(); - - bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } - - bool resolve_compile_time_constants() const { - return resolve_compile_time_constants_; - } - bool is_entry_computation() const { return is_entry_computation_; } + xla::XlaBuilder* builder() { return builder_; } const std::vector& args() const { return args_; } void set_args(std::vector args); - struct Retval { - DataType type; - TensorShape shape; - // An XlaExpression representing the Retval's value. - XlaExpression expression; - }; - const std::vector& retvals() { return retvals_; } - - // This is called by the Retval Op to associate a computed value - // with a specific return value of the subgraph. 
- void AddRetval(int retval_index, DataType type, const TensorShape& shape, - const xla::XlaOp& handle); - - // As for Retval, but for return values that are compile-time constants. - Status AddConstRetval(int retval_index, DataType dtype, - const xla::LiteralSlice& literal); - - // As for Retval, but for return values that are resource handles. - Status AddResourceRetval(int retval_index, XlaResource* resource); - - // As for Retval, but for return values that are XLA tokens. - Status AppendTokenRetval(const xla::XlaOp& token); - - // Creates a resource with resource `kind` and initial value `handle`. `name` - // is a descriptive name for use in error messages. See the `XlaResource` - // constructor for a description of the remaining arguments. - // Fails if the resource already exists. - Status CreateResource(XlaResource::Kind kind, int arg_num, string name, - DataType type, TensorShape shape, - const xla::XlaOp& handle, int64 tensor_array_size, - const std::set& tensor_array_gradients, - XlaResource** resource); + const std::vector& retvals() { return retvals_; } + + // Sets a return value. + // Since we do not always know in advance how many return values there are, + // grows the return values vector to size index+1 if it is smaller. + void SetRetval(int index, const XlaExpression& expression); + + // Adds 'resource' to the set of resources owned by the context. + XlaResource* AddResource(std::unique_ptr resource); const std::vector>& resources() { return resources_; } - // Returns the XLA shape to be used to represent a variable of TF `shape` - // and `type`, or of an argument or return value of a top-level computation. - xla::StatusOr RepresentationShape(const TensorShape& shape, - DataType type) const; - // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. 
@@ -140,36 +100,16 @@ class XlaContext : public ResourceBase { // The XlaBuilder used to construct the subgraph's compiled representation. xla::XlaBuilder* builder_; - // Allow ops to emit CustomCall operations for CPU. - const bool allow_cpu_custom_calls_; - - // If true, constant return values are returned as Tensors instead of - // run-time computation outputs. - const bool resolve_compile_time_constants_; - // Arguments to the Tensorflow graph, indexed by _Arg index. // Includes both compile-time constant arguments and runtime parameters. std::vector args_; // Return values of the Tensorflow graph, indexed by _Retval index. - std::vector retvals_; + std::vector retvals_; // Holds ownership of resources. The resources are not ordered. std::vector> resources_; - // Is this a top-level computation, or an inner computation (e.g., a while - // body)? - const bool is_entry_computation_; - - // A function that describes how the shapes of - // a) argument and return value, for entry computations - // b) variables, for all computations, - // should be represented in XLA. Parameters/return values will be shaped - // according to this function, and reshaped back to/from their declared shapes - // for computations. Must be non-null. - const std::function(const TensorShape&, DataType)>* - shape_representation_fn_; - // Cache of prebuilt computations indexed by their type. 
using ComputationMap = std::map; diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc index bc44301d405102921de21da4bd9407032783838c..9bb785842d061e5892ba9da0a902eef50d21f55d 100644 --- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc @@ -21,10 +21,10 @@ namespace tensorflow { bool CpuOpFilter(KernelDef* kdef) { if (kdef->op() == "Const") { - AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); + AddDtypeToKernelDefConstraint("dtype", DT_STRING, kdef); } if (kdef->op() == "Assert") { - AddDtypeToKernalDefConstraint("T", DT_STRING, kdef); + AddDtypeToKernelDefConstraint("T", DT_STRING, kdef); } return true; } diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca0309166b7c73d1a5a818091e2a30fa112a4de4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -0,0 +1,145 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_expression.h" + +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +XlaExpression::XlaExpression() = default; + +XlaExpression XlaExpression::Invalid() { + XlaExpression e; + e.kind_ = Kind::kInvalid; + return e; +} + +XlaExpression XlaExpression::Constant(Tensor value) { + XlaExpression e; + e.kind_ = Kind::kConstant; + e.dtype_ = value.dtype(); + e.constant_value_ = value; + return e; +} + +XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) { + XlaExpression e; + e.kind_ = Kind::kXlaOp; + e.dtype_ = dtype; + e.handle_ = value; + return e; +} + +XlaExpression XlaExpression::Resource(XlaResource* resource) { + XlaExpression e; + e.kind_ = Kind::kResource; + e.dtype_ = DT_RESOURCE; + e.resource_ = resource; + return e; +} + +string XlaExpression::HumanString() const { + switch (kind_) { + case Kind::kInvalid: + return "invalid"; + case Kind::kConstant: + return "constant"; + case Kind::kXlaOp: + return "xla_op"; + case Kind::kResource: + return "resource"; + } +} + +xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + switch (kind_) { + case Kind::kConstant: { + xla::BorrowingLiteral literal; + TF_RETURN_IF_ERROR( + HostTensorToBorrowingLiteral(constant_value_, &literal)); + return xla::ConstantLiteral(builder, literal); + } + case Kind::kXlaOp: + if (builder != handle_.builder()) { + return errors::InvalidArgument( + "Mismatched builders in XlaExpression::AsXlaOp"); + } + return handle_; + default: + return errors::InvalidArgument("AsXlaOp called on XlaExpression: ", + HumanString()); + } + }); +} + +xla::StatusOr> XlaExpression::ResolveConstant( + xla::Client* client) 
const { + switch (kind()) { + case Kind::kConstant: + return {constant_value()}; + case Kind::kXlaOp: + break; + case Kind::kResource: + case Kind::kInvalid: + return errors::InvalidArgument( + "ResolveConstant called on XlaExpression: ", HumanString()); + } + + TF_ASSIGN_OR_RETURN(bool is_constant, + handle().builder()->IsConstant(handle())); + if (!is_constant) return {absl::nullopt}; + + TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph, + handle().builder()->BuildConstantSubGraph(handle())); + + TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape()); + + // The XLA layout is specified minor to major, and TensorFlow uses a major to + // minor order. + std::vector layout_indices(shape.dims()); + std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); + xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); + TF_ASSIGN_OR_RETURN(xla::Literal literal, + client->ComputeConstant(constant_graph, &layout)); + Tensor tensor; + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype(), &tensor)); + return {tensor}; +} + +xla::StatusOr XlaExpression::GetShape() const { + switch (kind_) { + case Kind::kConstant: + return constant_value().shape(); + case Kind::kXlaOp: { + TF_ASSIGN_OR_RETURN(xla::Shape xla_shape, + handle().builder()->GetShape(handle())); + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); + return shape; + } + case Kind::kResource: + return TensorShape({}); + case Kind::kInvalid: + return errors::InvalidArgument( + "GetShape() called on invalid XlaExpression"); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h new file mode 100644 index 0000000000000000000000000000000000000000..bed6761d362a98d344003c1edea342e68c31ef07 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// A XlaExpression represents a symbolic TensorFlow value in a TF->XLA +// compilation. +// An expression is one of: +// * a constant tensor. +// * an xla::XlaOp, representing a symbolic XLA value. +// * a resource, e.g., a variable, represented as an XlaResource pointer. +// +// Constant tensors are mostly an optimization to avoid passing large constants +// to XLA, but are also sometimes used to represent tensors that have no XLA +// representation, for example, DT_STRING tensors. A canonical use case might be +// an error message string. +class XlaExpression { + public: + enum class Kind { + kInvalid, + kConstant, + kXlaOp, + kResource, + }; + + XlaExpression(); + XlaExpression(const XlaExpression&) = default; + XlaExpression& operator=(const XlaExpression&) = default; + + // Builds an invalid expression. (Same as the default constructor, but makes + // the intent clearer.) 
+ static XlaExpression Invalid(); + + // Builds a constant XLA expression. + static XlaExpression Constant(Tensor value); + + // Builds a XlaOp expression. Since the mapping from TF data types to XLA + // types is not 1-1, the TF type must also be provided; in general it cannot + // be derived from the XLA type. + static XlaExpression XlaOp(xla::XlaOp value, DataType dtype); + + // Builds a resource expression. + static XlaExpression Resource(XlaResource* resource); + + Kind kind() const { return kind_; } + + DataType dtype() const { return dtype_; } + + // handle() returns the XlaOp that backs a kXlaOp expression. + const xla::XlaOp& handle() const { return handle_; } + + const Tensor& constant_value() const { return constant_value_; } + + XlaResource* resource() const { return resource_; } + + // Returns a human-readable summary of the expression. + string HumanString() const; + + // Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns + // an erroneous XlaOp if the expression is not a constant or an expression. + xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const; + + // If a kXlaOp or kConstant expression can be resolved to a compile-time + // constant, returns the value as a host-memory Tensor. Returns an empty + // optional if it cannot be resolved. Returns an error if passed a resource + // expression. + xla::StatusOr> ResolveConstant( + xla::Client* client) const; + + // Returns the shape of the tensor. + // The shape of a resource is the shape of a resource handle (i.e., a scalar), + // not the shape of the resource's value. + xla::StatusOr GetShape() const; + + private: + Kind kind_ = Kind::kInvalid; + + DataType dtype_ = DT_INVALID; + + // The XLA handle of the expression's computation, if kind_ == kXlaOp. + xla::XlaOp handle_; + + // The value of the constant, if kind_ == kConstant. + Tensor constant_value_; + + // The resource, if kind_ == kResource. Not owned. 
+ XlaResource* resource_ = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ diff --git a/tensorflow/compiler/tf2xla/xla_expression_test.cc b/tensorflow/compiler/tf2xla/xla_expression_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..84202c931390f2d68f6d381aef0752bfff00a53d --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression_test.cc @@ -0,0 +1,135 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class XlaExpressionTest : public ::testing::Test { + protected: + void SetUp() override { + client_ = xla::ClientLibrary::LocalClientOrDie(); + builder_ = absl::make_unique("acomputation"); + constant_ = test::AsScalar(42); + op_ = xla::ConstantR0(builder_.get(), 7); + non_constant_op_ = xla::Parameter( + builder_.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x"); + resource_ = absl::make_unique( + XlaResource::kVariable, /*arg_num=*/0, /*name=*/string("avariable"), + DT_INT32, TensorShape({17, 3}), op_, /*tensor_array_size=*/-1, + /*tensor_array_gradients=*/std::set(), + /*tensor_array_multiple_writes_aggregate=*/false); + } + + xla::Client* client_; + std::unique_ptr builder_; + Tensor constant_; + xla::XlaOp op_; + xla::XlaOp non_constant_op_; + std::unique_ptr resource_; +}; + +TEST_F(XlaExpressionTest, Kind) { + EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression().kind()); + EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression::Invalid().kind()); + EXPECT_TRUE(XlaExpression::Kind::kConstant == + XlaExpression::Constant(constant_).kind()); + EXPECT_TRUE(XlaExpression::Kind::kXlaOp == + XlaExpression::XlaOp(op_, DT_INT32).kind()); + 
EXPECT_TRUE(XlaExpression::Kind::kResource == + XlaExpression::Resource(resource_.get()).kind()); +} + +TEST_F(XlaExpressionTest, HumanString) { + EXPECT_EQ("invalid", XlaExpression().HumanString()); + EXPECT_EQ("invalid", XlaExpression::Invalid().HumanString()); + EXPECT_EQ("constant", XlaExpression::Constant(constant_).HumanString()); + EXPECT_EQ("xla_op", XlaExpression::XlaOp(op_, DT_INT32).HumanString()); + EXPECT_EQ("resource", XlaExpression::Resource(resource_.get()).HumanString()); +} + +TEST_F(XlaExpressionTest, AsXlaOp) { + xla::XlaOp op_as_op = + XlaExpression::XlaOp(op_, DT_INT32).AsXlaOp(builder_.get()); + EXPECT_TRUE(op_.IsIdenticalTo(op_as_op)); + + xla::XlaOp const_as_op = + XlaExpression::Constant(constant_).AsXlaOp(builder_.get()); + TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation computation, + builder_->BuildConstantSubGraph(const_as_op)); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal value, + client_->ComputeConstant(computation)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(xla::LiteralUtil::CreateR0(42), + value)); +} + +TEST_F(XlaExpressionTest, GetShape) { + EXPECT_FALSE(XlaExpression().GetShape().ok()); + EXPECT_FALSE(XlaExpression::Invalid().GetShape().ok()); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape resource_shape, + XlaExpression::Resource(resource_.get()).GetShape()); + EXPECT_EQ(TensorShape({}), resource_shape); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape op_shape, + XlaExpression::XlaOp(op_, DT_INT32).GetShape()); + EXPECT_EQ(TensorShape({}), op_shape); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape constant_shape, + XlaExpression::Constant(constant_).GetShape()); + EXPECT_EQ(TensorShape({}), constant_shape); +} + +TEST_F(XlaExpressionTest, ResolveConstant) { + EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok()); + EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok()); + EXPECT_FALSE( + XlaExpression::Resource(resource_.get()).ResolveConstant(client_).ok()); + + TF_ASSERT_OK_AND_ASSIGN( + absl::optional op_constant, + 
XlaExpression::XlaOp(op_, DT_INT32).ResolveConstant(client_)); + ASSERT_TRUE(op_constant.has_value()); + test::ExpectTensorEqual(test::AsScalar(7), *op_constant); + + TF_ASSERT_OK_AND_ASSIGN(absl::optional op_nonconstant, + XlaExpression::XlaOp(non_constant_op_, DT_FLOAT) + .ResolveConstant(client_)); + EXPECT_FALSE(op_nonconstant.has_value()); + + TF_ASSERT_OK_AND_ASSIGN( + absl::optional constant_constant, + XlaExpression::Constant(constant_).ResolveConstant(client_)); + ASSERT_TRUE(constant_constant.has_value()); + test::ExpectTensorEqual(constant_, *constant_constant); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc index 1398e9ee536a9675e5b703ec3fabf4a8b9d89cbf..5e8006b8d8f63d67e8409cd89d182f8fe61a7441 100644 --- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc @@ -21,10 +21,10 @@ namespace tensorflow { bool GpuOpFilter(KernelDef* kdef) { if (kdef->op() == "Const") { - AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); + AddDtypeToKernelDefConstraint("dtype", DT_STRING, kdef); } if (kdef->op() == "Assert") { - AddDtypeToKernalDefConstraint("T", DT_STRING, kdef); + AddDtypeToKernelDefConstraint("T", DT_STRING, kdef); } return true; } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 9a34cd8c6ae2dc6d52a3cc69168df96f5322c6da..c2c0751211180c3715a19d6c78e34659fd18914e 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -26,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/types.h" @@ -216,8 +215,7 @@ DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { return dtype; } -xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder, - const xla::XlaOp& operand, +xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand, const DataType new_element_type) { xla::PrimitiveType convert_to; TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to)); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 39578144caaadf293d24ea91aa874e56e27ecc01..4858dfee55a393d04cd2af83916eeb40820ee368 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -80,8 +80,7 @@ class XlaHelpers { // A helper for creating a ConvertElementType xla op given a DataType rather // than the xla::PrimitiveType. 
- static xla::XlaOp ConvertElementType(xla::XlaBuilder* const builder, - const xla::XlaOp& operand, + static xla::XlaOp ConvertElementType(const xla::XlaOp& operand, const DataType new_element_type); }; diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 86a78ee429e8913edb4a948727fa692083c472f4..fabbcd04fed96ad814d04c2df9394f43bfe0cf99 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -133,7 +133,8 @@ XlaJitCompiledCpuFunction::Compile( jit->executable_ = std::move(executable); jit->buffer_infos_ = std::move(buffer_infos); jit->arg_index_table_ = std::move(arg_index_table); - jit->program_shape_ = std::move(program_shape); + jit->program_shape_ = + absl::make_unique(program_shape->ToProto()); jit->static_data_.set_raw_function(raw_function); jit->static_data_.set_buffer_infos(jit->buffer_infos_.data()); jit->static_data_.set_num_buffers(jit->buffer_infos_.size()); diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h index d3c8f22a8078d03d15447ed200c914390f40b04f..a5392057177e983e11787c31bb496a8947add1e6 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -80,8 +80,10 @@ class XlaJitCompiledCpuFunction { std::vector arg_names_; std::vector result_names_; - // The backing data for the program shape. - std::unique_ptr program_shape_; + // The backing data for the program shape. The proto form of program shape is + // used because the program shape is serialized and embedded in the object + // file. 
+ std::unique_ptr program_shape_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index 6d49298a6f3e8a726695fafc42f3c5341fe98b5f..8846088678b53f6b9ecff0de732d6b5c82392b5a 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -116,13 +116,13 @@ TEST(XlaJitCompiledCpuFunction, Sum) { // Check program shape. using xla::ShapeUtil; const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); - const xla::ProgramShape* program_shape = function.ProgramShape(); - ASSERT_TRUE(program_shape != nullptr); - ASSERT_EQ(program_shape->parameters_size(), 2); - EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32)); - EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(1), s32)); + ASSERT_TRUE(function.ProgramShape() != nullptr); + const xla::ProgramShape program_shape(*function.ProgramShape()); + ASSERT_EQ(program_shape.parameters_size(), 2); + EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(0), s32)); + EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(1), s32)); - const xla::Shape& result = program_shape->result(); + const xla::Shape& result = program_shape.result(); ASSERT_EQ(result.element_type(), xla::TUPLE); ASSERT_EQ(ShapeUtil::TupleElementCount(result), 1); const xla::Shape& result0 = ShapeUtil::GetTupleElementShape(result, 0); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index dd3498ef7aa242d3ad946cae5f60bc2c8853a342..58808c76de6330a6b28e21dbdead03dea25847f6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -35,40 +36,52 @@ bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { return context_->ValidateInputsAreSameShape(op); } +XlaContext* XlaOpKernelContext::xla_context() const { + return &XlaContext::Get(context_); +} + xla::XlaBuilder* XlaOpKernelContext::builder() const { - return XlaContext::Get(this).builder(); + return xla_context()->builder(); +} + +XlaCompiler* XlaOpKernelContext::compiler() const { + return xla_context()->compiler(); } // Retrieves an XlaExpression that was allocated by a previous Op. static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); - CHECK(expression->handle().valid() || expression->resource() != nullptr); - VLOG(1) << "Fetched T" << expression->handle(); + CHECK(expression->kind() != XlaExpression::Kind::kInvalid) + << expression->HumanString(); return expression; } -// Retrieves an uninitialized XlaExpression from a newly-allocated tensor. -static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) { +// Assigns an XlaExpression to a tensor on an XLA compilation device. 
+static void AssignExpressionToTensor(Tensor* tensor, + const XlaExpression& value) { const XlaExpression* expression = reinterpret_cast(tensor->tensor_data().data()); - CHECK(!expression->handle().valid()); - return const_cast(expression); + CHECK(expression->kind() == XlaExpression::Kind::kInvalid) + << expression->HumanString(); + *const_cast(expression) = value; +} + +const XlaExpression& XlaOpKernelContext::InputExpression(int index) { + return *CastExpressionFromTensor(context_->input(index)); } -// Retrieves the XlaOp from an input Tensor to an Op. This computation was -// constructed by an Op that executed previously and created the output Tensor -// using CreateOutputTensorFromComputation or CreateConstantOutputTensor. -static const xla::XlaOp& GetComputationFromTensor(const Tensor& tensor) { - return CastExpressionFromTensor(tensor)->handle(); +const XlaExpression& XlaOpKernelContext::InputExpression( + absl::string_view name) { + return *CastExpressionFromTensor(GetInputTensorByName(name)); } -const xla::XlaOp& XlaOpKernelContext::Input(int index) { - return GetComputationFromTensor(context_->input(index)); +xla::XlaOp XlaOpKernelContext::Input(int index) { + return InputExpression(index).AsXlaOp(builder()); } -const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) { - return GetComputationFromTensor(GetInputTensorByName(name)); +xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) { + return InputExpression(name).AsXlaOp(builder()); } TensorShape XlaOpKernelContext::InputShape(int index) { @@ -125,77 +138,18 @@ Status XlaOpKernelContext::ConstantInput(absl::string_view name, Status XlaOpKernelContext::ConstantInputReshaped( int index, absl::Span new_dims, xla::Literal* constant_literal) { - const Tensor& tensor = context_->input(index); - TensorShape new_shape(new_dims); - if (tensor.NumElements() != new_shape.num_elements()) { - return errors::InvalidArgument( - context_->op_kernel().name(), " input ", index, " has shape ", - 
tensor.shape().DebugString(), - " but was asked to be reshaped to incompatible shape ", - new_shape.DebugString()); - } - const XlaExpression* expression = CastExpressionFromTensor(tensor); - - auto copy_tensor_to_literal = [](const Tensor& tensor, - xla::Literal* literal) { - xla::Shape literal_shape; - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape)); - - *literal = xla::Literal(literal_shape); - - // memcpy over the payload ... - // TODO(phawkins): handle string types. - size_t total_bytes = tensor.TotalBytes(); - if (total_bytes > 0) { - void* dst_ptr = literal->untyped_data(); - const void* src_ptr = DMAHelper::base(&tensor); - memcpy(dst_ptr, src_ptr, total_bytes); - } - return Status::OK(); - }; - - // If the tensor has a known constant value, there is no need to invoke XLA. - if (expression->has_constant_value()) { - Tensor temp(tensor.dtype()); - if (!temp.CopyFrom(expression->constant_value(), new_shape)) { - // This should never happen. The constant should have a shape compatible - // with the enclosing Tensor. - return errors::Internal("Incompatible shapes in ConstantInputReshaped."); - } - - return copy_tensor_to_literal(temp, constant_literal); - } - - // Make sure we treat zero-element tensors as constant. - if (new_shape.num_elements() == 0) { - Tensor temp(tensor.dtype(), new_shape); - - return copy_tensor_to_literal(temp, constant_literal); - } - - xla::XlaOp handle = expression->handle(); - if (new_shape != tensor.shape()) { - // Reshape the handle to the desired shape. - handle = xla::Reshape(handle, new_shape.dim_sizes()); - } - - // The XLA layout is specified minor to major, and TensorFlow's minor - // dimension is the last one. 
- std::vector layout_indices(new_shape.dims()); - std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); - xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); - - xla::StatusOr is_constant = builder()->IsConstant(handle); - if (!is_constant.ok()) { - Status status = is_constant.status(); + XlaExpression e = InputExpression(index); + xla::StatusOr> constant_or_status = + e.ResolveConstant(compiler()->client()); + if (!constant_or_status.ok()) { + Status status = constant_or_status.status(); errors::AppendToMessage(&status, "while evaluating input ", index, " of ", context_->op_kernel().type_string(), " operator as a compile-time constant."); return status; } - - if (!is_constant.ValueOrDie()) { + absl::optional constant = constant_or_status.ValueOrDie(); + if (!constant.has_value()) { return errors::InvalidArgument( "Input ", index, " to ", context_->op_kernel().type_string(), " operator must be a compile-time constant.\n" @@ -208,25 +162,16 @@ Status XlaOpKernelContext::ConstantInputReshaped( "stateful operation such as a random number generator."); } - // Ask the XLA compiler to evaluate the data handle to a literal. 
- xla::StatusOr constant_graph = - builder()->BuildConstantSubGraph(handle); - if (!constant_graph.ok()) { - return errors::Internal( - "Error getting a compile-time constant graph for ", - context_->op_kernel().name(), " input ", index, - ".\nError: ", constant_graph.status().error_message()); - } - xla::StatusOr computed = compiler()->client()->ComputeConstant( - constant_graph.ValueOrDie(), &layout); - if (!computed.ok()) { - return errors::Internal("Error evaluating ", context_->op_kernel().name(), - " input ", index, - " as a compile-time constant.\nError: ", - computed.status().error_message()); + Tensor temp(constant->dtype()); + if (!temp.CopyFrom(*constant, TensorShape(new_dims))) { + return errors::InvalidArgument( + context_->op_kernel().name(), " input ", index, " has shape ", + constant->shape().DebugString(), + " but was asked to be reshaped to incompatible shape ", + TensorShape(new_dims).DebugString()); } - *constant_literal = std::move(computed).ValueOrDie(); + TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp)); return Status::OK(); } @@ -322,6 +267,15 @@ Status XlaOpKernelContext::ConstantInputReshapedToIntVector( return LiteralToInt64Vector(literal, out); } +Status XlaOpKernelContext::ConstantInputReshapedToIntVector( + absl::string_view name, std::vector* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInputReshaped( + index, {InputShape(index).num_elements()}, &literal)); + return LiteralToInt64Vector(literal, out); +} + Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, xla::Literal* out) { xla::Literal literal; @@ -372,7 +326,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name, handles->clear(); shapes->clear(); for (const Tensor& input : inputs) { - handles->push_back(GetComputationFromTensor(input)); + handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder())); shapes->push_back(input.shape()); } return Status::OK(); 
@@ -392,8 +346,8 @@ Status XlaOpKernelContext::ConstantInputList( namespace { Status ReadVariableInputTensor(const Tensor& tensor, DataType type, - const OpKernelContext* ctx, TensorShape* shape, - xla::XlaOp* value) { + const XlaOpKernelContext* ctx, + TensorShape* shape, xla::XlaOp* value) { const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); @@ -411,11 +365,13 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, *shape = variable->shape(); } - XlaContext& xla_context = XlaContext::Get(ctx); - TF_ASSIGN_OR_RETURN( - TensorShape representation_shape, - xla_context.RepresentationShape(variable->shape(), variable->type())); - if (representation_shape == variable->shape()) { + TF_ASSIGN_OR_RETURN(xla::Shape representation_shape, + ctx->compiler()->options().shape_representation_fn( + variable->shape(), variable->type())); + xla::Shape xla_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape)); + if (xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { *value = variable->value(); } else { *value = xla::Reshape(variable->value(), variable->shape().dim_sizes()); @@ -428,15 +384,15 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, TensorShape* shape, xla::XlaOp* value) { - return ReadVariableInputTensor(context_->input(index), type, context_, shape, + return ReadVariableInputTensor(context_->input(index), type, this, shape, value); } Status XlaOpKernelContext::ReadVariableInput(absl::string_view name, DataType type, TensorShape* shape, xla::XlaOp* value) { - return ReadVariableInputTensor(GetInputTensorByName(name), type, context_, - shape, value); + return ReadVariableInputTensor(GetInputTensorByName(name), type, this, shape, + value); } Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* 
type, @@ -455,90 +411,53 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, return Status::OK(); } -Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape, - Tensor** output) { - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. - if (expected_output_dtype(index) == DT_VARIANT) { - // tensor_data() is not supported for variant Tensor (i.e., - // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the - // XlaExpression inside the Tensor's tensor_data() does not work for - // variant. Instead construct a uint8 tensor and store the expression in its - // value. - // TODO(jpienaar): This should be refactored to stop masquerading - // XlaExpressions as Tensors. - *output = new Tensor(); - TensorShape tensor_shape; - TF_RETURN_IF_ERROR( - context_->allocate_temp(DT_UINT8, tensor_shape, *output)); - context_->set_output(index, **output); - } else { - TensorShape tensor_shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape)); - TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output)); +void XlaOpKernelContext::SetOutputExpression(int index, + const XlaExpression& expression) { + Status status = [&] { + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. + Tensor* output = nullptr; + // Provides a special behavior for DT_VARIANT: a variant is treated as + // DT_UINT8 scalar as the type to allow mapping for variant to more generic + // types. + if (expression.dtype() == DT_VARIANT) { + // tensor_data() is not supported for variant Tensor (i.e., + // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the + // XlaExpression inside the Tensor's tensor_data() does not work for + // variant. 
Instead construct a uint8 tensor and store the expression in + // its value. + // TODO(jpienaar): This should be refactored to stop masquerading + // XlaExpressions as Tensors. + output = new Tensor(); + TensorShape tensor_shape; + TF_RETURN_IF_ERROR( + context_->allocate_temp(DT_UINT8, tensor_shape, output)); + context_->set_output(index, *output); + } else { + TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape()); + TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output)); + } + AssignExpressionToTensor(output, expression); + return Status::OK(); + }(); + if (!status.ok()) { + SetStatus(status); } - return Status::OK(); } void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { - // Makes the host Tensor that will refer to the expression. - Tensor* output = nullptr; - auto shape_or = builder()->GetShape(handle); - if (!shape_or.ok()) { - SetStatus(shape_or.status()); - return; - } - - OP_REQUIRES_OK(context_, - allocate_output(index, shape_or.ValueOrDie(), &output)); - - // The expression is stored in the tensor's data buffer. Fill in the - // fields now. - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_handle(handle); + SetOutputExpression( + index, + XlaExpression::XlaOp(handle, context_->expected_output_dtype(index))); } void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { - const TensorShape& shape = constant.shape(); - - xla::BorrowingLiteral literal; - OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal)); - - xla::XlaOp handle = xla::ConstantLiteral(builder(), literal); - CHECK(handle.valid()); - - // Make the Tensor that will refer to the expression. - Tensor* output = nullptr; - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. 
- OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output)); - - // The expression is stored in the tensor's data buffer. Fill in the - // fields now. - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_handle(handle); - expression->set_constant_value(constant); -} - -void XlaOpKernelContext::SetInvalidOutput(int index) { - Tensor* output = nullptr; - OP_REQUIRES_OK(context_, - context_->allocate_output(index, TensorShape({}), &output)); - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - xla::XlaOp handle; - expression->set_handle(handle); + SetOutputExpression(index, XlaExpression::Constant(constant)); } void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { - Tensor* output = nullptr; - // The shape of the output tensor is the shape of the resource itself - // (i.e., a scalar), not the shape of the resource's value. - OP_REQUIRES_OK(context_, - context_->allocate_output(index, TensorShape(), &output)); - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_resource(resource); + SetOutputExpression(index, XlaExpression::Resource(resource)); } Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { @@ -552,7 +471,7 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { namespace { Status AssignVariableTensor(const Tensor& tensor, DataType type, - const OpKernelContext* ctx, xla::XlaOp handle, + const XlaOpKernelContext* ctx, xla::XlaOp handle, xla::XlaBuilder* builder) { const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); @@ -569,11 +488,14 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); - XlaContext& xla_context = XlaContext::Get(ctx); - TF_ASSIGN_OR_RETURN(TensorShape representation_shape, - 
xla_context.RepresentationShape(shape, type)); - if (shape != representation_shape) { - handle = xla::Reshape(handle, representation_shape.dim_sizes()); + TF_ASSIGN_OR_RETURN( + xla::Shape representation_shape, + ctx->compiler()->options().shape_representation_fn(shape, type)); + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); + if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { + handle = xla::Reshape(handle, + xla::AsInt64Slice(representation_shape.dimensions())); } return variable->SetValue(handle); } @@ -583,19 +505,15 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, xla::XlaOp handle) { TF_RET_CHECK(handle.valid()); - return AssignVariableTensor(context_->input(input_index), type, context_, - handle, builder()); + return AssignVariableTensor(context_->input(input_index), type, this, handle, + builder()); } Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type, xla::XlaOp handle) { TF_RET_CHECK(handle.valid()); - return AssignVariableTensor(GetInputTensorByName(name), type, context_, - handle, builder()); -} - -XlaCompiler* XlaOpKernelContext::compiler() const { - return XlaContext::Get(context_).compiler(); + return AssignVariableTensor(GetInputTensorByName(name), type, this, handle, + builder()); } void XlaOpKernelContext::CtxFailure(const Status& s) { @@ -615,22 +533,22 @@ void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line, const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax( const DataType type) { - return XlaContext::Get(context_).GetOrCreateMax(type); + return xla_context()->GetOrCreateMax(type); } const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin( const DataType type) { - return XlaContext::Get(context_).GetOrCreateMin(type); + return xla_context()->GetOrCreateMin(type); } const xla::XlaComputation* 
XlaOpKernelContext::GetOrCreateAdd( const DataType type) { - return XlaContext::Get(context_).GetOrCreateAdd(type); + return xla_context()->GetOrCreateAdd(type); } const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( const DataType type) { - return XlaContext::Get(context_).GetOrCreateMul(type); + return xla_context()->GetOrCreateMul(type); } const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index aa00a454968ad29495e34dc080e55b62bb0b5f7b..1858844bc05a6e12abbf07af83cad816590ddd03 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -60,6 +60,8 @@ class XlaOpKernelContext { public: explicit XlaOpKernelContext(OpKernelContext* context); + XlaContext* xla_context() const; + // Returns the XLA XlaBuilder containing the output of compilation. xla::XlaBuilder* builder() const; @@ -88,9 +90,9 @@ class XlaOpKernelContext { // Returns input `index` as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. - const xla::XlaOp& Input(int index); + xla::XlaOp Input(int index); // Returns input `name` as a XlaOp. - const xla::XlaOp& Input(absl::string_view name); + xla::XlaOp Input(absl::string_view name); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. @@ -111,14 +113,6 @@ class XlaOpKernelContext { Status ConstantInput(int index, xla::Literal* constant_literal); Status ConstantInput(absl::string_view name, xla::Literal* constant_literal); - // Evaluates input `index`, reshapes it to `new_shape` if new_shape != - // InputShape(index), and stores it in `*constant_literal`. If the input - // cannot be evaluated, e.g., because it depends on unbound parameters, - // returns a non-Ok status. 
If InputShape(index).num_elements() != - // new_shape.num_elements(), returns an error status. - Status ConstantInputReshaped(int index, absl::Span new_dims, - xla::Literal* constant_literal); - // Converts a constant scalar int32 or int64 tensor into an int64. Status ConstantInputAsIntScalar(int index, int64* out); Status ConstantInputAsIntScalar(absl::string_view name, int64* out); @@ -134,6 +128,8 @@ class XlaOpKernelContext { // Reshapes and converts a constant int32 or int64 tensor into a vector of // int64s. Status ConstantInputReshapedToIntVector(int index, std::vector* out); + Status ConstantInputReshapedToIntVector(absl::string_view name, + std::vector* out); // Converts a constant int32 or int64 Tensor into an xla int64 Literal. Status ConstantInputAsInt64Literal(int index, xla::Literal* out); @@ -148,6 +144,10 @@ class XlaOpKernelContext { Status ConstantInputList(absl::string_view name, std::vector* literals); + // Returns an XlaExpression describing the value of 'index'. + const XlaExpression& InputExpression(int index); + const XlaExpression& InputExpression(absl::string_view name); + // Outputs int num_outputs() const { return context_->num_outputs(); } @@ -165,9 +165,8 @@ class XlaOpKernelContext { // SetConstantOutput where possible. void SetConstantOutput(int index, const Tensor& host_tensor); - // Sets output `index` to an invalid value. - // Any subsequent attempt to consume this output will cause an error. - void SetInvalidOutput(int index); + // Returns an XlaExpression describing the value of 'index'. + void SetOutputExpression(int index, const XlaExpression& expression); // Status handling. void SetStatus(const Status& status) { context_->SetStatus(status); } @@ -255,10 +254,13 @@ class XlaOpKernelContext { // Returns the tensor of input `name`. 
const Tensor& GetInputTensorByName(absl::string_view name); - // Wraps OpKernelContext's allocate_output method while providing special - // behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the - // type to allow mapping for variant to more generic types. - Status allocate_output(int index, const xla::Shape& shape, Tensor** output); + // Evaluates input `index`, reshapes it to `new_shape` if new_shape != + // InputShape(index), and stores it in `*constant_literal`. If the input + // cannot be evaluated, e.g., because it depends on unbound parameters, + // returns a non-Ok status. If InputShape(index).num_elements() != + // new_shape.num_elements(), returns an error status. + Status ConstantInputReshaped(int index, absl::Span new_dims, + xla::Literal* constant_literal); OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 91d48125f1d21092db7e5f9307e44af9c16e4e2b..14237df69081016817fbd1a5332f22996e7f264d 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -128,21 +130,26 @@ XlaOpRegistry::~XlaOpRegistry() = default; // Lazily register the CPU and GPU JIT devices the first time // GetCompilationDevice is called. 
static void* registration_init = [®istry]() { + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + bool cpu_global_jit = flags->tf_xla_cpu_global_jit; + mutex_lock lock(registry.mutex_); if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_CPU]; registration.compilation_device_name = DEVICE_CPU_XLA_JIT; - registration.requires_compilation = false; - registration.enable_jit_by_default = false; + registration.autoclustering_policy = + cpu_global_jit + ? XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally + : XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested; registration.compile_resource_ops = false; } if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_GPU]; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; - registration.requires_compilation = false; - registration.enable_jit_by_default = true; + registration.autoclustering_policy = + XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally; registration.compile_resource_ops = false; } return nullptr; @@ -341,18 +348,69 @@ std::vector XlaOpRegistry::DeviceKernels( return ops; } -/* static */ const std::unordered_set* -XlaOpRegistry::CompileTimeConstantInputs(const string& op) { - XlaOpRegistry& registry = Instance(); - mutex_lock lock(registry.mutex_); - auto it = registry.ops_.find(op); - if (it == registry.ops_.end() || it->second.empty()) { - return nullptr; +/* static */ Status XlaOpRegistry::CompileTimeConstantInputs( + const NodeDef& node_def, const OpKernel* op_kernel, const OpDef* op_def, + std::vector* result) { + result->clear(); + + DCHECK(op_def != nullptr || op_kernel != nullptr); + + std::unordered_set compile_time_constant_inputs_from_attr; + std::vector compile_time_constant_inputs_vect_from_attr; + + const std::unordered_set* compile_time_constant_inputs; + + if (GetNodeAttr(node_def, 
kXlaCompileTimeConstantInputsAttr, + &compile_time_constant_inputs_vect_from_attr) + .ok()) { + absl::c_copy(compile_time_constant_inputs_vect_from_attr, + std::inserter(compile_time_constant_inputs_from_attr, + compile_time_constant_inputs_from_attr.end())); + compile_time_constant_inputs = &compile_time_constant_inputs_from_attr; + } else { + const string& op = node_def.op(); + + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto it = registry.ops_.find(op); + if (it == registry.ops_.end() || it->second.empty()) { + return Status::OK(); + } else { + // The test in IsCompatible ensures that if there are multiple matching + // registrations for this op name, they all have the same value of + // compile_time_constant_inputs, so only the first match is returned. + // + // TODO(sanjoy): This can probably be a std::vector. + compile_time_constant_inputs = + &it->second.front()->compile_time_constant_inputs; + } } - // The test in IsCompatible ensures that if there are multiple matching - // registrations for this op name, they all have the same value of - // compile_time_constant_inputs, so only the first match is returned. 
- return &it->second.front()->compile_time_constant_inputs; + + for (const string& input : *compile_time_constant_inputs) { + if (op_def) { + NameRangeMap input_name_ranges; + TF_RETURN_IF_ERROR( + NameRangesForNode(node_def, *op_def, &input_name_ranges, nullptr)); + auto name_range = input_name_ranges.find(input); + if (name_range == input_name_ranges.end()) { + continue; + } + + for (int i = name_range->second.first; i < name_range->second.second; + i++) { + result->push_back(i); + } + } else { + int start, stop; + TF_CHECK_OK(op_kernel->InputRange(input, &start, &stop)); + for (int i = start; i < stop; ++i) { + result->push_back(i); + } + } + } + + absl::c_sort(*result); + return Status::OK(); } /*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) { @@ -445,7 +503,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( return *this; } -XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstantInput( absl::string_view input_name) { registration_->compile_time_constant_inputs.emplace(input_name); return *this; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 4b2c2bacd647b3e6fe500a942b116772550195ce..0bdd4a1085445420a5147756daac4a54f4725f11 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -66,19 +66,26 @@ class XlaOpRegistry { public: typedef OpKernel* (*Factory)(OpKernelConstruction*); + enum class AutoclusteringPolicy { + // Enable autoclustering if the user requests it, e.g., via + // experimental_jit_scope. Does not autocluster if the JIT is enabled + // globally (e.g., via the OptimizerOptions in the TF session + // configuration.) + kIfExplicitlyRequested, + // Enable autoclustering if explicitly requested, or if the JIT is enabled + // globally in the session options, or via TF_XLA_FLAGS=--tf_xla_auto_jit=N. 
+ kIfEnabledGlobally, + // Always try to autocluster ops placed on this device. + kAlways, + }; + // Describes how to compile operators assigned to a device. struct DeviceRegistration { // The name of the an XLA compilation device to use to compile code. string compilation_device_name; - // Do operators assigned to this device require compilation? - bool requires_compilation; - - // If !requires_compilation, should we try to JIT operators on this device - // when XLA JIT compilation is enabled globally via the SessionOptions? - // (It is still possible to explicitly mark operators to JIT compile, even - // if enable_jit_by_default is false.) - bool enable_jit_by_default; + // When should we autocluster operators assigned to this device? + AutoclusteringPolicy autoclustering_policy; // Enable compilation of operators that use DT_RESOURCE types? bool compile_resource_ops = false; @@ -106,6 +113,7 @@ class XlaOpRegistry { // Registers `device_name` for XLA compilation, using information from // `registration`. + // Does nothing if a registration for `device_name` already exists. static void RegisterCompilationDevice(const string& device_name, const DeviceRegistration& registration); @@ -132,10 +140,27 @@ class XlaOpRegistry { // Returns all operations for which there are XLA kernels on any device. static std::vector GetAllRegisteredOps(); - // Returns the set of compile-time constant inputs to 'op'. Returns nullptr - // if the op is not registered. - static const std::unordered_set* CompileTimeConstantInputs( - const string& op); + // Returns (via `result`) the indices of inputs to `node_def` that must be + // compile-time constants. Returns an empty vector if the op is not + // registered. + // + // `result` is sorted. 
+ static Status CompileTimeConstantInputs(const NodeDef& node_def, + const OpDef& op_def, + std::vector* result) { + return CompileTimeConstantInputs(node_def, /*op_kernel=*/nullptr, &op_def, + result); + } + + // Returns (via `result`) the indices of inputs to `op_kernel` that must be + // compile-time constants. + // + // `result` is sorted. + static Status CompileTimeConstantInputs(const OpKernel& op_kernel, + std::vector* result) { + return CompileTimeConstantInputs(op_kernel.def(), /*op_kernel=*/&op_kernel, + /*op_def=*/nullptr, result); + } // Returns true if `op` is a "metadata" op, one that only looks at the shapes // of its operands and not their values. @@ -212,6 +237,11 @@ class XlaOpRegistry { // whitelists must not intersect. static bool IsCompatible(const OpRegistration& x, const OpRegistration& y); + static Status CompileTimeConstantInputs(const NodeDef& node_def, + const OpKernel* op_kernel, + const OpDef* op_def, + std::vector* result); + // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP. // Registrations present under the same key must satisfy IsCompatible above, // and this is checked during registration. @@ -263,7 +293,8 @@ class XlaOpRegistrationBuilder { XlaOpRegistrationBuilder& AllowResourceTypes(); // Mark 'input_name' as an argument whose value must be known at compile-time. - XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name); + XlaOpRegistrationBuilder& CompileTimeConstantInput( + absl::string_view input_name); // Mark this op as a "metadata" op, one that only looks at the shapes of its // operands and not their values. diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 56c2e01055665954b99ea635e56666fbd8b96026..48a3c012727acd8472d3d5d4072ae700f5497d96 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -26,10 +27,44 @@ limitations under the License. namespace tensorflow { +/*static*/ absl::string_view XlaResource::KindToString(XlaResource::Kind kind) { + switch (kind) { + case XlaResource::kInvalid: + return "invalid"; + case XlaResource::kVariable: + return "variable"; + case XlaResource::kStack: + return "stack"; + case XlaResource::kTensorArray: + return "tensorarray"; + } +} + +/*static*/ std::unique_ptr XlaResource::CreateStack( + string name, DataType type, int64 max_size) { + return absl::make_unique( + XlaResource::kStack, /*arg_num=*/-1, std::move(name), type, TensorShape(), + /*initial_value=*/xla::XlaOp(), + /*max_array_size=*/max_size, + /*tensor_array_gradients=*/std::set{}, + /*tensor_array_multiple_writes_aggregate=*/false); +} + +/*static*/ std::unique_ptr XlaResource::CreateTensorArray( + string name, DataType type, TensorShape shape, xla::XlaOp initial_value, + int64 max_array_size) { + return absl::make_unique( + XlaResource::kTensorArray, /*arg_num=*/-1, std::move(name), type, shape, + initial_value, max_array_size, + /*tensor_array_gradients=*/std::set{}, + /*tensor_array_multiple_writes_aggregate=*/false); +} + XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, - int64 tensor_array_size, - const std::set& tensor_array_gradients) + int64 max_array_size, + const std::set& tensor_array_gradients, + bool tensor_array_multiple_writes_aggregate) : kind_(kind), arg_num_(arg_num), name_(std::move(name)), @@ -37,14 +72,17 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, shape_(std::move(shape)), value_(initial_value), initial_value_(initial_value), - tensor_array_size_(tensor_array_size) { + max_array_size_(max_array_size), + 
tensor_array_multiple_writes_aggregate_( + tensor_array_multiple_writes_aggregate) { CHECK(kind_ != kInvalid); for (const string& gradient : tensor_array_gradients) { tensor_array_gradients_[gradient].reset(new XlaResource( /*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_, - xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{})); + xla::XlaOp(), max_array_size_, /*tensor_array_gradients=*/{}, + /*tensor_array_multiple_writes_aggregate=*/true)); } } @@ -96,7 +134,7 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { } case kTensorArray: { TensorShape ta_shape; - ta_shape.AddDim(tensor_array_size_); + ta_shape.AddDim(max_array_size_); ta_shape.AppendShape(shape_); value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); @@ -104,7 +142,7 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { } case kStack: { TensorShape ta_shape; - ta_shape.AddDim(tensor_array_size_); + ta_shape.AddDim(max_array_size_); ta_shape.AppendShape(shape_); value_ = xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_), @@ -129,15 +167,16 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, std::unique_ptr& gradient = tensor_array_gradients_[source]; if (!gradient) { TensorShape ta_shape; - ta_shape.AddDim(tensor_array_size_); + ta_shape.AddDim(max_array_size_); ta_shape.AppendShape(shape_); xla::XlaOp gradient_value = xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/absl::StrCat("TensorArrayGrad: ", name_), - type_, shape_, gradient_value, tensor_array_size_, - /*tensor_array_gradients=*/{})); + type_, shape_, gradient_value, max_array_size_, + /*tensor_array_gradients=*/{}, + /*tensor_array_multiple_writes_aggregate=*/true)); } *gradient_out = gradient.get(); return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/xla_resource.h 
b/tensorflow/compiler/tf2xla/xla_resource.h index 2438490be13809b9f3571a362900b44cb838e76b..736588bb8b89ba756cdce77eeebff8d1fcf4774c 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -35,11 +36,22 @@ class XlaResource { kTensorArray, kStack, }; + static absl::string_view KindToString(Kind kind); + + // Creates a new Stack resource. + static std::unique_ptr CreateStack(string name, DataType type, + int64 max_size); + + // Creates a new TensorArray resource. + static std::unique_ptr CreateTensorArray( + string name, DataType type, TensorShape shape, xla::XlaOp initial_value, + int64 max_array_size); XlaResource(Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& initial_value, - int64 tensor_array_size, - const std::set& tensor_array_gradients); + int64 max_array_size, + const std::set& tensor_array_gradients, + bool tensor_array_multiple_writes_aggregate); XlaResource(const XlaResource&) = delete; XlaResource(XlaResource&&) = delete; @@ -113,13 +125,19 @@ class XlaResource { const xla::XlaOp& pack, xla::XlaBuilder* builder); // TensorArray and Stack specific fields + // TODO(phawkins): refactor this code to use subclasses, rather than putting + // kind-specific fields in XlaResource. - // 'tensor_array_size' stores the expected size of the TensorArray or Stack. + // 'max_array_size' stores the expected size of the TensorArray or Stack. // We need to store this since sometimes TensorArrays must be initialized // lazily since we do not know the element shape at construction time. // Used by both TensorArrays and Stacks. 
- int64 tensor_array_size() const { return tensor_array_size_; } - void set_tensor_array_size(int64 size) { tensor_array_size_ = size; } + int64 max_array_size() const { return max_array_size_; } + void set_max_array_size(int64 size) { max_array_size_ = size; } + + bool tensor_array_multiple_writes_aggregate() const { + return tensor_array_multiple_writes_aggregate_; + } // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes // to an XlaResource containing the gradient TensorArrays. We store a pointer @@ -142,7 +160,8 @@ class XlaResource { xla::XlaOp value_; xla::XlaOp initial_value_; - int64 tensor_array_size_ = -1; + int64 max_array_size_ = -1; + bool tensor_array_multiple_writes_aggregate_ = false; std::map> tensor_array_gradients_; }; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index cc7390c6e60375b4c31c38f9f7dee25730f8f51e..4360e0857964b0ac63fc887e269b04a4b00d854a 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -7,6 +7,7 @@ package_group( packages = [ "//tensorflow/compiler/...", "//tensorflow/contrib/tpu/...", + "//third_party/py/jax/...", ], ) @@ -67,7 +68,7 @@ cc_library( visibility = [":friends"], deps = [ ":xla_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", ], ) @@ -225,12 +226,14 @@ cc_library( "index_util.cc", "layout_util.cc", "primitive_util.cc", + "shape.cc", "shape_util.cc", ], hdrs = [ "index_util.h", "layout_util.h", "primitive_util.h", + "shape.h", "shape_util.h", ], visibility = ["//visibility:public"], @@ -253,6 +256,23 @@ cc_library( ], ) +tf_cc_test( + name = "shape_test", + srcs = ["shape_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + tf_cc_test( name = "shape_util_test", srcs = 
["shape_util_test.cc"], @@ -308,6 +328,7 @@ cc_library( ":util", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/base", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -330,6 +351,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/base", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -373,6 +395,7 @@ cc_library( ":literal_util", ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/base", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], @@ -731,6 +754,72 @@ tf_cc_test( ], ) +cc_library( + name = "parse_flags_from_env", + srcs = ["parse_flags_from_env.cc"], + hdrs = ["parse_flags_from_env.h"], + deps = + [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "parse_flags_from_env_test", + srcs = ["parse_flags_from_env_test.cc"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "debug_options_flags", + srcs = [ + "debug_options_flags.cc", + "debug_options_parsers.h", + ], + hdrs = ["debug_options_flags.h"], + deps = + [ + ":parse_flags_from_env", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "debug_options_parsers_test", + size = "small", + srcs = [ + "debug_options_parsers.h", + "debug_options_parsers_test.cc", + ], + deps = + [ + 
"//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + # ----------------------------------------------------------------------------- # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. diff --git a/tensorflow/compiler/xla/README.md b/tensorflow/compiler/xla/README.md index 39f8caaa961dc7b57d2b45f974fc6ecf89cf6748..f9c93707f7af30a0fa0c4224240dc40848a24f66 100644 --- a/tensorflow/compiler/xla/README.md +++ b/tensorflow/compiler/xla/README.md @@ -1,7 +1,6 @@

- +

XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear -algebra that optimizes TensorFlow computations. See the -[documentation](https://www.tensorflow.org/performance/xla/) for more details. +algebra that optimizes TensorFlow computations. See the [documentation](./g3doc/overview.md). diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index 782c966b4c57672d137569a318fb20ace14d493b..e4aca98f67d50287a83afc6f41a59458f3df2da2 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -104,7 +104,7 @@ std::unique_ptr> MakeLinspaceArray2D(double from, double to, int64 count = n1 * n2; NativeT step = static_cast((count > 1) ? (to - from) / (count - 1) : 0); - auto set = [&array, n1, n2](int64 index, NativeT value) { + auto set = [&array, n2](int64 index, NativeT value) { (*array)(index / n2, index % n2) = value; }; for (int64 i = 0; i < count - 1; ++i) { diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index dc097f3696e22d75d7dc72ec4877a9c8b5dda059..fe99564d3c671cd7890e1fa26fcd2e3384972983 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -33,6 +33,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:span", ], ) @@ -66,6 +68,7 @@ cc_library( deps = [ ":global_data", ":xla_computation", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:service_interface", @@ -74,11 +77,11 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", 
"@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -88,11 +91,12 @@ cc_library( srcs = ["executable_build_options.cc"], hdrs = ["executable_build_options.h"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/service:device_memory_allocator", - "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", @@ -189,6 +193,7 @@ cc_library( hdrs = ["xla_computation.h"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -234,13 +239,13 @@ tf_cc_test( deps = [ ":xla_builder", ":xla_computation", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 5dde5b432f136c16d4e3795569499ee5de709763..74b76f929949d3300a5d0ff45d5fa4cd9f162642 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -20,9 +20,10 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -42,7 +43,7 @@ StatusOr Client::Transfer(const GlobalData& data, TransferToClientRequest request; *request.mutable_data() = data.handle(); if (shape_with_layout != nullptr) { - *request.mutable_shape_with_layout() = *shape_with_layout; + *request.mutable_shape_with_layout() = shape_with_layout->ToProto(); } TransferToClientResponse response; @@ -123,7 +124,7 @@ StatusOr Client::TransferFromOutfeed( } request.set_replica_id(replica_id); if (shape_with_layout != nullptr) { - *request.mutable_shape_with_layout() = *shape_with_layout; + *request.mutable_shape_with_layout() = shape_with_layout->ToProto(); } TransferFromOutfeedResponse response; @@ -170,11 +171,14 @@ StatusOr Client::ExecuteAndTransfer( std::unique_ptr data, Execute(computation, arguments, execution_options, execution_profile)); - const Shape* shape_with_output_layout = nullptr; + absl::optional shape_with_output_layout; if (execution_options && execution_options->has_shape_with_output_layout()) { - shape_with_output_layout = &execution_options->shape_with_output_layout(); + shape_with_output_layout = + Shape(execution_options->shape_with_output_layout()); } - return Transfer(*data, shape_with_output_layout); + return Transfer(*data, shape_with_output_layout.has_value() + ? 
&(*shape_with_output_layout) + : nullptr); } StatusOr Client::ComputeConstant(const XlaComputation& computation, @@ -210,11 +214,10 @@ StatusOr Client::LoadSnapshot(const HloSnapshot& module) { return XlaComputation(module.hlo().hlo_module()); } -StatusOr> Client::Execute( - const XlaComputation& computation, absl::Span arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - ExecuteGraphRequest request; +StatusOr Client::Compile( + const XlaComputation& computation, absl::Span argument_shapes, + const ExecutionOptions* execution_options) { + CompileRequest request; *request.mutable_computation() = computation.proto(); if (execution_options == nullptr) { @@ -222,6 +225,34 @@ StatusOr> Client::Execute( } else { *request.mutable_execution_options() = *execution_options; } + if (request.execution_options().device_handles_size() > 1) { + return InvalidArgument( + "Compiling with multiple device handles is not supported. Use " + "'Execute' instead."); + } + + // The argument shapes affect how the computation is compiled. 
+ for (const auto& arg_shape : argument_shapes) { + *request.add_input_shape_with_layout() = arg_shape.ToProto(); + } + + CompileResponse response; + VLOG(1) << "making compile request: " << request.ShortDebugString(); + Status s = stub_->Compile(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + TF_RET_CHECK(response.has_handle()); + return response.handle(); +} + +StatusOr> Client::Execute( + const ExecutionHandle& handle, absl::Span arguments, + ExecutionProfile* execution_profile) { + ExecuteRequest request; + *request.mutable_handle() = handle; for (GlobalData* argument : arguments) { CHECK(argument != nullptr) << "Argument pointers must not be null."; *request.add_arguments() = argument->handle(); @@ -229,7 +260,7 @@ StatusOr> Client::Execute( ExecuteResponse response; VLOG(1) << "making execute request: " << request.ShortDebugString(); - Status s = stub_->ExecuteGraph(&request, &response); + Status s = stub_->Execute(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -238,15 +269,62 @@ StatusOr> Client::Execute( if (execution_profile != nullptr) { *execution_profile = response.profile(); + } + + return absl::make_unique(stub_, response.output()); +} + +StatusOr> Client::Execute( + const XlaComputation& computation, absl::Span arguments, + const ExecutionOptions* execution_options, + ExecutionProfile* execution_profile) { + if (execution_options != nullptr && + execution_options->device_handles_size() > 1) { + std::vector computation_instances = { + XlaComputationInstance{ + computation, + std::vector(arguments.begin(), arguments.end()), + *execution_options, execution_profile}}; + TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances)); + // The result selection is a bit hacky, but better than assuming it is + // device 0. + // + // TODO(b/118493728): Allow Execute to return one result per computation. 
+ for (int64 i = 0; i < results.size(); i++) { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i])); + if (!ShapeUtil::IsEmptyTuple(shape)) { + VLOG(3) << "Fetching result from device " << i << ": " + << ShapeUtil::HumanString(shape); + return std::move(results[i]); + } + } + TF_RET_CHECK(!results.empty()); + VLOG(1) << "Defaulting to device 0 result"; + return std::move(results[0]); + } + + // The argument shapes affect how the computation is compiled. + std::vector arg_shapes(arguments.size()); + for (int i = 0; i < arguments.size(); i++) { + TF_ASSIGN_OR_RETURN(arg_shapes[i], GetShape(*arguments[i])); + } + + TF_ASSIGN_OR_RETURN(auto handle, + Compile(computation, arg_shapes, execution_options)); + + TF_ASSIGN_OR_RETURN(auto result, + Execute(handle, arguments, execution_profile)); + + if (execution_profile != nullptr) { if (VLOG_IS_ON(1)) { TF_ASSIGN_OR_RETURN( auto execution_stats, - ExecutionStatsAsString(computation, response.profile())); + ExecutionStatsAsString(computation, *execution_profile)); VLOG(1) << execution_stats; } } - return absl::make_unique(stub_, response.output()); + return std::move(result); } StatusOr>> Client::ExecuteParallel( @@ -274,10 +352,11 @@ StatusOr>> Client::ExecuteParallel( } std::vector> outputs; - for (size_t i = 0; i < computations.size(); ++i) { + for (size_t i = 0; i < response.responses_size(); ++i) { outputs.push_back( absl::make_unique(stub_, response.responses(i).output())); - if (computations[i].execution_profile != nullptr) { + if (i < computations.size() && + computations[i].execution_profile != nullptr) { *computations[i].execution_profile = response.responses(i).profile(); } } @@ -312,7 +391,7 @@ StatusOr> Client::GetDeviceHandles( Status Client::Unregister(const GlobalData& data) { UnregisterRequest request; - *request.mutable_data() = data.handle(); + *request.add_data() = data.handle(); UnregisterResponse response; VLOG(1) << "making unregister request"; @@ -383,15 +462,14 @@ StatusOr 
Client::GetShape(const GlobalData& data) { return s; } - return response.shape(); + return Shape(response.shape()); } StatusOr Client::ExecutionStatsAsString( const XlaComputation& computation, const ExecutionProfile& profile) { TF_ASSIGN_OR_RETURN( auto computation_stats, - GetComputationStats(computation, - legacy_flags::GetDebugOptionsFromFlags())); + GetComputationStats(computation, GetDebugOptionsFromFlags())); int64 total_flops = computation_stats.flop_count() + computation_stats.transcendental_count(); if (profile.compute_time_ns() > 0) { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 6f4d33c469f1f885cfeef546e3981dc3417ef71f..d0ac4703c632e0e01d3c8911594b46fedf28930d 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -40,6 +40,31 @@ class Client { explicit Client(ServiceInterface* stub); virtual ~Client(); + // Compiles the computation with the given argument shapes and returns the + // handle to the compiled executable. The compiled executable is cached on the + // service, and the returned handle can be used for execution without + // re-compile. + // * The shape and layout of the arguments being executed with will affect how + // the computation is compiled. If argument_shapes is empty, the parameters' + // shape and layout will be used in the compilation. + // * If execution_options is not nullptr, these options are passed to the + // service to affect how it compiles our computation. (The pointer does not + // need to live beyond this call.) + // * execution_options.device_handles should be empty. If you need + // non-empty device handles, call 'Execute' instead. 
+ StatusOr Compile( + const XlaComputation& computation, + absl::Span argument_shapes, + const ExecutionOptions* execution_options = nullptr); + + // Executes the compiled executable for the given handle with the given + // arguments and returns the global data that was produced from the execution. + // * If execution_profile is not nullptr then the pointed-to ExecutionProfile + // will be filled with profile data from the execution. + StatusOr> Execute( + const ExecutionHandle& handle, absl::Span arguments, + ExecutionProfile* execution_profile = nullptr); + // Executes the computation with the given arguments and returns the global // data that was produced from the execution. // * If execution_options is not nullptr, these options are passed to the diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 0f1745366b7c33e573aff2e66d85431b01488c49..1f594e551af381d7537e947892cbf7e0b5b3b861 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" namespace xla { @@ -39,6 +40,13 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal( int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; } +DebugOptions* ExecutableBuildOptions::mutable_debug_options() { + if (!has_debug_options()) { + debug_options_ = GetDebugOptionsFromFlags(); + } + return &debug_options_.value(); +} + ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout( const Shape& shape_with_layout) { result_layout_set_ = true; @@ -55,68 +63,10 @@ string ExecutableBuildOptions::ToString() const { if (result_layout_set_) { result_layout = ShapeUtil::HumanStringWithLayout(result_layout_); } - string generate_hlo_graph = "nullopt"; - if (generate_hlo_graph_.has_value()) { - generate_hlo_graph = generate_hlo_graph_.value(); - } return absl::StrFormat( "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " "generate_hlo_graph=%s}", - device_ordinal_, result_layout, generate_hlo_graph); -} - -ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph( - string regex) { - generate_hlo_graph_ = std::move(regex); - return *this; -} - -const absl::optional& ExecutableBuildOptions::generate_hlo_graph() - const { - return generate_hlo_graph_; -} - -ExecutableBuildOptions& ExecutableBuildOptions::set_dump_optimized_hlo_proto_to( - absl::string_view dirpath) { - dump_optimized_hlo_proto_to_ = string(dirpath); - return *this; -} - -const absl::optional& -ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { - return dump_optimized_hlo_proto_to_; -} - -ExecutableBuildOptions& -ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to( - absl::string_view dirpath) { - dump_unoptimized_hlo_proto_to_ = string(dirpath); - return *this; -} - -const absl::optional& 
-ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const { - return dump_unoptimized_hlo_proto_to_; -} - -ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( - absl::string_view dirpath) { - dump_per_pass_hlo_proto_to_ = string(dirpath); - return *this; -} - -const absl::optional& -ExecutableBuildOptions::dump_per_pass_hlo_proto_to() const { - return dump_per_pass_hlo_proto_to_; -} - -ExecutableBuildOptions& ExecutableBuildOptions::set_hlo_profile(bool enabled) { - hlo_profile_ = enabled; - return *this; -} - -absl::optional ExecutableBuildOptions::hlo_profile() const { - return hlo_profile_; + device_ordinal_, result_layout, debug_options().xla_generate_hlo_graph()); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 93334db88bc24f2ffbf3c7a57ee45ef238286739..a58090253bfac7779e4b61bc7231a0f0d945cc00 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -19,7 +19,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -44,6 +46,12 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); const Shape* result_layout() const; + // Expose access to the XLA debug options which will be passed to the + // compilation process. 
+ bool has_debug_options() const { return debug_options_.has_value(); } + const DebugOptions& debug_options() const { return *debug_options_; } + DebugOptions* mutable_debug_options(); + // If set, this specifies an allocator that can be used to allocate temporary // space on the device during compilation. For example, the compiler might // want to run various algorithms on the device and pick the fastest one -- it @@ -55,56 +63,16 @@ class ExecutableBuildOptions { DeviceMemoryAllocator* allocator); DeviceMemoryAllocator* device_allocator() const; - // If set, specifies a regexp of HLO graphs to dump (as in DebugOptions). - ExecutableBuildOptions& set_generate_hlo_graph(string regex); - const absl::optional& generate_hlo_graph() const; - - // If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO - // protobuf to (as in DebugOptions). - ExecutableBuildOptions& set_dump_optimized_hlo_proto_to( - absl::string_view dirpath); - const absl::optional& dump_optimized_hlo_proto_to() const; - - // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO - // protobuf to (as in DebugOptions). - ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to( - absl::string_view dirpath); - const absl::optional& dump_unoptimized_hlo_proto_to() const; - - // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs - // to (as in DebugOptions). - ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( - absl::string_view dirpath); - const absl::optional& dump_per_pass_hlo_proto_to() const; - - // If true, specifies that we should record an HLO profile during execution - // and log it after execution (as in DebugOptions). If nullopt the default is - // used. 
- ExecutableBuildOptions& set_hlo_profile(bool enabled); - absl::optional hlo_profile() const; - - void add_disabled_hlo_pass(absl::string_view pass_name) { - disabled_hlo_passes_.push_back(std::string(pass_name)); - } - const absl::Span disabled_hlo_passes() const { - return disabled_hlo_passes_; - } - // Returns a string representation of the build options, suitable for // debugging. string ToString() const; private: - absl::optional hlo_profile_; int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; - absl::optional generate_hlo_graph_; - absl::optional dump_optimized_hlo_proto_to_; - absl::optional dump_unoptimized_hlo_proto_to_; - absl::optional dump_per_pass_hlo_proto_to_; + absl::optional debug_options_; DeviceMemoryAllocator* device_allocator_ = nullptr; - std::vector disabled_hlo_passes_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/client/global_data.cc b/tensorflow/compiler/xla/client/global_data.cc index 2986d4060013703873b2cffb6aacbb012606d16f..f1fa13d95c035d182746d3ce5400178890aa42b1 100644 --- a/tensorflow/compiler/xla/client/global_data.cc +++ b/tensorflow/compiler/xla/client/global_data.cc @@ -18,25 +18,53 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { + +// Releases a set of global data handles owned by the parent service +// interface. 
+void ReleaseHandles(ServiceInterface* parent, + const absl::Span handles) { + UnregisterRequest request; + for (auto& handle : handles) { + VLOG(1) << "Requesting to unregister " << handle.ShortDebugString(); + *request.add_data() = handle; + } + UnregisterResponse response; + Status status = parent->Unregister(&request, &response); + VLOG(1) << "Done with request"; + if (!status.ok()) { + LOG(WARNING) << "Failed to unregister handles: " << status + << "; continuing anyway..."; + } +} + +} // namespace GlobalData::GlobalData(ServiceInterface* parent, GlobalDataHandle handle) : handle_(std::move(handle)), parent_(parent) {} GlobalData::~GlobalData() { - UnregisterRequest request; - *request.mutable_data() = handle_; - UnregisterResponse response; - VLOG(1) << "requesting to unregister " << handle_.ShortDebugString(); - Status s = parent_->Unregister(&request, &response); - VLOG(1) << "done with request"; + if (parent_ != nullptr) { + ReleaseHandles(parent_, {handle_}); + } +} - if (!s.ok()) { - LOG(WARNING) << "failed to unregister " << handle_.ShortDebugString() - << "; continuing anyway..."; +/* static */ void GlobalData::Release( + std::vector> instances) { + absl::flat_hash_map> + parent_handles_map; + for (auto& instance : instances) { + if (instance->parent_ != nullptr) { + parent_handles_map[instance->parent_].push_back(instance->Release()); + } + } + for (auto& parent_handles : parent_handles_map) { + ReleaseHandles(parent_handles.first, parent_handles.second); } } diff --git a/tensorflow/compiler/xla/client/global_data.h b/tensorflow/compiler/xla/client/global_data.h index b7929357d06032b55c04bf0391f7fa703ee15f17..4d48d2c53fc6171fe1940924598a4d48519c5adf 100644 --- a/tensorflow/compiler/xla/client/global_data.h +++ b/tensorflow/compiler/xla/client/global_data.h @@ -16,6 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_GLOBAL_DATA_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_GLOBAL_DATA_H_ +#include +#include + +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -36,7 +40,18 @@ class GlobalData { const GlobalDataHandle& handle() const { return handle_; } + // Releases a set of GlobalData handles. A single RPC will be issued + // per unique ServiceInterface of the given GlobalData objects. + static void Release(std::vector> instances); + private: + // Detaches the global data handle from the object, such that the destructor + // will not try to release it. + GlobalDataHandle Release() { + parent_ = nullptr; + return handle_; + } + GlobalDataHandle handle_; // Handle being wrapped. ServiceInterface* parent_; // Service used to unregister handle_. diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index a18c94c4e695a6cdcb9dcc60b64b617cecd276d8..41db8de29ff0085a30847ff41db4ffbfc774e2a1 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -104,13 +104,17 @@ xla_test( ) cc_library( - name = "numeric", - srcs = ["numeric.cc"], - hdrs = ["numeric.h"], + name = "matrix", + srcs = ["matrix.cc"], + hdrs = ["matrix.h"], deps = [ ":arithmetic", ":constants", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "@com_google_absl//absl/types:span", @@ -118,11 +122,12 @@ cc_library( ) xla_test( - name = "numeric_test", - srcs = ["numeric_test.cc"], + name = "matrix_test", + srcs = ["matrix_test.cc"], tags = ["enable_for_xla_interpreter"], deps = [ - ":numeric", + ":matrix", + ":slicing", 
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -164,11 +169,43 @@ cc_library( deps = [ ":constants", ":math", - ":numeric", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:lib", + "@com_google_absl//absl/base", + ], +) + +cc_library( + name = "slicing", + srcs = ["slicing.cc"], + hdrs = ["slicing.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/types:span", + ], +) + +xla_test( + name = "slicing_test", + srcs = ["slicing_test.cc"], + tags = ["enable_for_xla_interpreter"], + deps = [ + ":slicing", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -177,8 +214,9 @@ cc_library( srcs = ["sorting.cc"], hdrs = ["sorting.h"], deps = [ - ":numeric", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", ], @@ -187,10 +225,6 @@ cc_library( xla_test( name = "sorting_test", srcs = ["sorting_test.cc"], - blacklisted_backends = [ - "cpu", - "gpu", - ], tags = ["enable_for_xla_interpreter"], deps = [ ":sorting", @@ -224,3 +258,48 @@ cc_library( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "triangular_solve", + srcs = ["triangular_solve.cc"], + hdrs = ["triangular_solve.h"], + deps 
= [ + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "triangular_solve_test", + srcs = ["triangular_solve_test.cc"], + tags = ["noasan"], # sometimes times out, http://b/78650012 + deps = [ + ":triangular_solve", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index d3d7edb42a38595bbf9fdb36e0dd946ae5df51f9..36fdda39b4124b9100c6054160f9c17bdf787d6f 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -265,6 +265,21 @@ XlaOp Digamma(XlaOp input) { return result; } +// Implements Banker's rounding: numbers that are equidistant between two +// integers are rounded towards even. 
+XlaOp RoundToEven(XlaOp x) { + auto half = ScalarLike(x, 0.5); + auto one = ScalarLike(x, 1.0); + auto two = ScalarLike(x, 2.0); + + auto round_val = Floor(x); + auto fraction = x - round_val; + auto nearest_even_int = round_val - two * Floor(half * x); + auto is_odd = Eq(nearest_even_int, one); + return Select(Or(Gt(fraction, half), And(Eq(fraction, half), is_odd)), + round_val + one, round_val); +} + // Trigonometric functions. // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) @@ -304,4 +319,13 @@ XlaOp Cosh(XlaOp x) { return (Exp(x) + Exp(-x)) * ScalarLike(x, 0.5); } XlaOp Sinh(XlaOp x) { return (Exp(x) - Exp(-x)) * ScalarLike(x, 0.5); } +XlaOp MaybeConjugate(XlaOp x, bool conjugate) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + auto perform_conj = shape.element_type() == C64 && conjugate; + return perform_conj ? Conj(x) : x; + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index a6cafd42077367bf23ffa1f45eab31c01dc31b16..17612bf9fdc0f1eabb338671c93c025c5b268872 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -51,6 +51,10 @@ XlaOp Lgamma(XlaOp input); // Computes an approximation of the digamma function. XlaOp Digamma(XlaOp input); +// Rounds the given number to even when the number is equidistant between two +// integers. +XlaOp RoundToEven(XlaOp x); + // Trigonometric functions // Computes the arc cosine of 'x'. @@ -82,6 +86,10 @@ XlaOp Cosh(XlaOp x); // Computes the hyperbolic sine of 'x'. XlaOp Sinh(XlaOp x); +// Applies a complex conjugation operation if `x` is complex and `conjugate` +// is true, otherwise returns its argument. 
+xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_ diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 14c259a7fa2a47642663b65d2785e5bbdc040cfd..ae2ea225d1aadd7b3a794eabeca866c498f34760 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -136,5 +136,17 @@ XLA_TEST_F(MathTest, Digamma) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, RoundToEven) { + XlaBuilder builder(TestName()); + auto x = ConstantR1( + &builder, {-1.4, -1.5, -2.5, -0.5, 0, 0.5, 1.5, 2.5, 3.5, 4.5}); + RoundToEven(x); + + std::vector expected = {-1.0, -2.0, -2.0, -0.0, 0, + 0.0, 2.0, 2.0, 4.0, 4.0}; + + ComputeAndCompareR1(&builder, expected, {}, error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/matrix.cc b/tensorflow/compiler/xla/client/lib/matrix.cc new file mode 100644 index 0000000000000000000000000000000000000000..ffd744d190885b8e3f4149a48a706498b3787618 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/matrix.cc @@ -0,0 +1,185 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/matrix.h" + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, + int64 n) { + auto a = Iota(builder, type, m); + auto b = Iota(builder, type, n); + auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0}); + return ConvertElementType(indicator, type); +} + +XlaOp GetMatrixDiagonal(XlaOp x) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + const int64 m = shape.dimensions(n_dims - 2); + const int64 n = shape.dimensions(n_dims - 1); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, U32, n); + auto b = Iota(builder, U32, m); + auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + auto mask = Broadcast(indicator, major_dims); + + // TPUs don't support S64 add reduction at the moment. But fortunately + // OR-reductions work just as well for integers. + XlaComputation reducer = + primitive_util::IsIntegralType(shape.element_type()) + ? CreateScalarOrComputation(shape.element_type(), builder) + : CreateScalarAddComputation(shape.element_type(), builder); + + return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), + reducer, {m >= n ? 
n_dims - 2 : n_dims - 1}); + }); +} + +XlaOp Triangle(XlaOp x, bool lower) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + const int64 m = shape.dimensions(n_dims - 2); + const int64 n = shape.dimensions(n_dims - 1); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); + auto a = Iota(builder, U32, n); + auto b = Iota(builder, U32, m); + XlaOp indicator; + if (lower) { + indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + } else { + indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); + } + auto mask = Broadcast(indicator, major_dims); + + return Select(mask, x, Zeros(builder, shape)); + }); +} + +XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } + +XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } + +XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y)); + + // Check that both tensors have the same number of dimensions. There must be + // at least two (the batch dimensions can be empty). + if (ShapeUtil::Rank(x_shape) != ShapeUtil::Rank(y_shape)) { + return InvalidArgument( + "Arguments to BatchDot have different ranks: %s vs. %s", + ShapeUtil::HumanString(x_shape), ShapeUtil::HumanString(y_shape)); + } + const int ndims = ShapeUtil::Rank(x_shape); + if (ndims < 2) { + return InvalidArgument( + "Arguments to BatchDot must have rank >= 2: got %d", ndims); + } + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. 
+ std::vector batch_dimension_numbers; + for (int i = 0; i < ndims - 2; ++i) { + if (x_shape.dimensions(i) != y_shape.dimensions(i)) { + return InvalidArgument( + "Dimension %d of inputs to BatchDot must be equal: shapes %s vs %s", + i, ShapeUtil::HumanString(x_shape), + ShapeUtil::HumanString(y_shape)); + } + batch_dimension_numbers.push_back(i); + } + + int x_inner_dim = ndims - 1; + int y_inner_dim = ndims - 2; + if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { + return InvalidArgument( + "Dimensions %d and %d of arguments to BatchDot must be equal: " + "shapes %s vs %s", + x_inner_dim, y_inner_dim, ShapeUtil::HumanString(x_shape), + ShapeUtil::HumanString(y_shape)); + } + + // Check for zero lhs/rhs dim size. + if (ShapeUtil::IsZeroElementArray(x_shape) || + ShapeUtil::IsZeroElementArray(y_shape)) { + std::vector dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = ndims - 2; + int y_outer_dim = ndims - 1; + dimensions.push_back(x_shape.dimensions(x_outer_dim)); + dimensions.push_back(y_shape.dimensions(y_outer_dim)); + return Broadcast( + ConstantLiteral(builder, LiteralUtil::Zero(x_shape.element_type())), + dimensions); + } + + PrecisionConfig precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); + } + + return DotGeneral(x, y, dot_dnums, &precision_proto); + }); +} + +XlaOp TransposeInMinorDims(XlaOp x) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { 
+ TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_dims >= 2); + std::vector permutation(n_dims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); + return Transpose(x, permutation); + }); +} + +XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) { + return transpose ? TransposeInMinorDims(x) : x; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/matrix.h b/tensorflow/compiler/xla/client/lib/matrix.h new file mode 100644 index 0000000000000000000000000000000000000000..8856f99c7a0fee8f315aac11fab392cf5536f57b --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/matrix.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere +// else. +XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); + +// Get the diagonals of the last two dimensions. 
If 'x' has shape +// [..., M, N], then the output has shape [..., min(M, N)], containing the +// diagonal elements (i.e., with indices [..., i, i]). +XlaOp GetMatrixDiagonal(XlaOp x); + +// Get the upper or lower triangle part of the last two dimensions +XlaOp Triangle(XlaOp x, bool lower); + +// Get the upper triangle part of the last two dimensions +XlaOp UpperTriangle(XlaOp x); + +// Get the lower triangle part of the last two dimensions +XlaOp LowerTriangle(XlaOp x); + +// Multiplies slices of two tensors in batches. + +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. +// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = r_x +// c_o = c_y (and c_x must equal r_y) +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +xla::XlaOp BatchDot( + xla::XlaOp x, xla::XlaOp y, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + +// Transposes a stack of matrices `x` by swapping the last two dimensions. +xla::XlaOp TransposeInMinorDims(xla::XlaOp x); + +// Transposes `x` in its minor dimensions if `transpose` is true, otherwise +// returns `x` unchanged. +xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ diff --git a/tensorflow/compiler/xla/client/lib/matrix_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0593a7517ac125ca8dc5395cee76f6bc23232cd3 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc @@ -0,0 +1,105 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/matrix.h" + +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class MatrixTest : public ClientLibraryTestBase { + protected: + template + void TestMatrixDiagonal(); +}; + +XLA_TEST_F(MatrixTest, Triangle) { + XlaBuilder builder(TestName()); + Array3D input(2, 3, 4); + input.FillIota(0); + + XlaOp a; + auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); + LowerTriangle(a); + Array3D expected({{{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}}, + {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}}}); + + ComputeAndCompareR3(&builder, expected, {a_data.get()}); +} + +template +void MatrixTest::TestMatrixDiagonal() { + XlaBuilder builder("GetMatrixDiagonal"); + Array3D input(2, 3, 4); + input.FillIota(0); + + XlaOp a; + auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); + GetMatrixDiagonal(a); + Array2D expected({{0, 5, 10}, {12, 17, 22}}); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}); +} + +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) { 
TestMatrixDiagonal(); } + +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } + +XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } + +Array3D BatchedAValsFull() { + return {{ + {2, 0, 1, 2}, + {3, 6, 0, 1}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }}; +} + +XLA_TEST_F(MatrixTest, RowBatchDot) { + XlaBuilder builder(TestName()); + + int n = 4; + + XlaOp a, row, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, + "row", &builder, &row); + // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). + auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); + + auto l_index = DynamicSliceInMinorDims( + a, {index, ConstantR0(&builder, 0)}, {1, n}); + BatchDot(l_index, TransposeInMinorDims(row)); + + ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, + {a_data.get(), row_data.get(), index_data.get()}); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc deleted file mode 100644 index 377654220b5df4487e9e194361473d54ff46a54e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/lib/numeric.cc +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/client/lib/arithmetic.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" - -namespace xla { - -XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, - int64 n) { - auto a = Iota(builder, type, m); - auto b = Iota(builder, type, n); - auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0}); - return ConvertElementType(indicator, type); -} - -XlaOp GetMatrixDiagonal(XlaOp x) { - XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - const int64 m = shape.dimensions(n_dims - 2); - const int64 n = shape.dimensions(n_dims - 1); - absl::Span major_dims = - AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); - auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - auto mask = Broadcast(indicator, major_dims); - - // TPUs don't support S64 add reduction at the moment. But fortunately - // OR-reductions work just as well for integers. - XlaComputation reducer = - primitive_util::IsIntegralType(shape.element_type()) - ? CreateScalarOrComputation(shape.element_type(), builder) - : CreateScalarAddComputation(shape.element_type(), builder); - - return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0), - reducer, {m >= n ? 
n_dims - 2 : n_dims - 1}); - }); -} - -XlaOp Triangle(XlaOp x, bool lower) { - XlaBuilder* builder = x.builder(); - return builder->ReportErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); - const int64 n_dims = ShapeUtil::Rank(shape); - TF_RET_CHECK(n_dims >= 2); - const int64 m = shape.dimensions(n_dims - 2); - const int64 n = shape.dimensions(n_dims - 1); - absl::Span major_dims = - AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); - auto a = Iota(builder, U32, n); - auto b = Iota(builder, U32, m); - xla::XlaOp indicator; - if (lower) { - indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } else { - indicator = Le(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); - } - auto mask = Broadcast(indicator, major_dims); - - return Select(mask, x, Zeros(builder, shape)); - }); -} - -XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } - -XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h deleted file mode 100644 index efd8cdc25724198633e0bf1c48c4e7d9e4b4c9e1..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/lib/numeric.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -// Returns a rank 1 tensor of `type` containing values [0, 1, 2, ...]. -XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); - -// Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere -// else. -XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); - -// Get the diagonals of the last two dimensions. If 'x' has shape -// [..., M, N], then the output has shape [..., min(M, N)], containing the -// diagonal elements (i.e., with indices [..., i, i]). -XlaOp GetMatrixDiagonal(XlaOp x); - -// Get the upper or lower triangle part of the last two dimensions -XlaOp Triangle(XlaOp x, bool lower); - -// Get the upper triangle part of the last two dimensions -XlaOp UpperTriangle(XlaOp x); - -// Get the lower triangle part of the last two dimensions -XlaOp LowerTriangle(XlaOp x); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc deleted file mode 100644 index 7d6aedd49462bd4f075f90d0b0f85c40f1191aa1..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/client/lib/numeric_test.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { -namespace { - -class NumericTest : public ClientLibraryTestBase { - protected: - template - void TestMatrixDiagonal(); -}; - -XLA_TEST_F(NumericTest, Triangle) { - XlaBuilder builder(TestName()); - Array3D input(2, 3, 4); - input.FillIota(0); - - XlaOp a; - auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); - LowerTriangle(a); - Array3D expected({{{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}}, - {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}}}); - - ComputeAndCompareR3(&builder, expected, {a_data.get()}); -} - -template -void NumericTest::TestMatrixDiagonal() { - XlaBuilder builder("GetMatrixDiagonal"); - Array3D input(2, 3, 4); - input.FillIota(0); - - XlaOp a; - auto a_data = CreateR3Parameter(input, 0, "a", &builder, &a); - GetMatrixDiagonal(a); - Array2D expected({{0, 5, 10}, {12, 17, 22}}); - - ComputeAndCompareR2(&builder, expected, {a_data.get()}); -} - -XLA_TEST_F(NumericTest, GetMatrixDiagonal_S32) { TestMatrixDiagonal(); } - -XLA_TEST_F(NumericTest, GetMatrixDiagonal_S64) { TestMatrixDiagonal(); } - -XLA_TEST_F(NumericTest, GetMatrixDiagonal_F32) { TestMatrixDiagonal(); } - -} // 
namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index 6ef81689489d8117d5951bcb75693c2e3413e4d6..85b9e1827dcef5ed907d893277deb5a52f8f30e9 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -15,20 +15,19 @@ limitations under the License. #include +#include "absl/base/casts.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/casts.h" namespace xla { namespace { // Rotates a 32-bit integer 'v' left by 'distance' bits. -XlaOp RotateLeftS32(XlaOp v, int distance) { - return (v << ConstantR0(v.builder(), distance)) | - ShiftRightLogical(v, ConstantR0(v.builder(), 32 - distance)); +XlaOp RotateLeftU32(XlaOp v, int distance) { + return (v << ConstantR0(v.builder(), distance)) | + ShiftRightLogical(v, ConstantR0(v.builder(), 32 - distance)); } using ThreeFry2x32State = std::array; @@ -38,13 +37,16 @@ using ThreeFry2x32State = std::array; // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { XlaBuilder* builder = input[0].builder(); + key[0] = BitcastConvertType(key[0], U32); + key[1] = BitcastConvertType(key[1], U32); + // Rotation distances specified by the Threefry2x32 algorithm. constexpr std::array rotations = {13, 15, 26, 6, 17, 29, 16, 24}; ThreeFry2x32State x; std::array ks; // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. 
- ks[2] = ConstantR0(builder, 0x1BD11BDA); + ks[2] = ConstantR0(builder, 0x1BD11BDA); for (int i = 0; i < 2; ++i) { ks[i] = key[i]; x[i] = input[i]; @@ -58,7 +60,7 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { // amount 'rotation'. auto round = [](ThreeFry2x32State v, int rotation) { v[0] = v[0] + v[1]; - v[1] = RotateLeftS32(v[1], rotation); + v[1] = RotateLeftU32(v[1], rotation); v[1] = v[0] ^ v[1]; return v; }; @@ -70,74 +72,83 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { x = round(x, rotations[2]); x = round(x, rotations[3]); x[0] = x[0] + ks[1]; - x[1] = x[1] + ks[2] + ConstantR0(builder, 1); + x[1] = x[1] + ks[2] + ConstantR0(builder, 1); x = round(x, rotations[4]); x = round(x, rotations[5]); x = round(x, rotations[6]); x = round(x, rotations[7]); x[0] = x[0] + ks[2]; - x[1] = x[1] + ks[0] + ConstantR0(builder, 2); + x[1] = x[1] + ks[0] + ConstantR0(builder, 2); x = round(x, rotations[0]); x = round(x, rotations[1]); x = round(x, rotations[2]); x = round(x, rotations[3]); x[0] = x[0] + ks[0]; - x[1] = x[1] + ks[1] + ConstantR0(builder, 3); + x[1] = x[1] + ks[1] + ConstantR0(builder, 3); x = round(x, rotations[4]); x = round(x, rotations[5]); x = round(x, rotations[6]); x = round(x, rotations[7]); x[0] = x[0] + ks[1]; - x[1] = x[1] + ks[2] + ConstantR0(builder, 4); + x[1] = x[1] + ks[2] + ConstantR0(builder, 4); x = round(x, rotations[0]); x = round(x, rotations[1]); x = round(x, rotations[2]); x = round(x, rotations[3]); x[0] = x[0] + ks[2]; - x[1] = x[1] + ks[0] + ConstantR0(builder, 5); + x[1] = x[1] + ks[0] + ConstantR0(builder, 5); return x; } -} // namespace +// Returns the inputs with unique counter values for ThreeFry2x32. 
+ThreeFry2x32State GetInputs(const int64 size, XlaBuilder* builder) { + ThreeFry2x32State inputs; + inputs[0] = Iota(builder, U32, size); + inputs[1] = inputs[0] + ConstantR0(builder, size); + return inputs; +} -XlaOp StatelessRngUniform(std::array seeds, const Shape& shape, - XlaOp minval, XlaOp maxval) { - XlaBuilder* builder = seeds[0].builder(); - if (shape.element_type() != F32) { - return builder->ReportError(Unimplemented( - "Types other than F32 are not implemented by StatelessRngUniform.")); - } - ThreeFry2x32State key = seeds; +XlaOp StatelessRngUniformU32(std::array key, const Shape& shape) { + XlaBuilder* builder = key[0].builder(); const int64 size = ShapeUtil::ElementsIn(shape); - const int64 half_size = CeilOfRatio(size, 2); const bool size_is_odd = (half_size * 2 != size); - - // Fill the generator inputs with unique counter values. - ThreeFry2x32State inputs; - inputs[0] = Iota(builder, S32, half_size); - inputs[1] = inputs[0] + ConstantR0(builder, half_size); + ThreeFry2x32State inputs = GetInputs(half_size, builder); ThreeFry2x32State outputs = ThreeFry2x32(inputs, key); - if (size_is_odd) { outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1}); } + auto result = ConcatInDim(builder, outputs, 0); + return Reshape(result, AsInt64Slice(shape.dimensions())); +} - auto bits = Reshape(ConcatInDim(builder, outputs, 0), - AsInt64Slice(shape.dimensions())); +XlaOp StatelessRngUniformU64(std::array key, const Shape& shape) { + XlaBuilder* builder = key[0].builder(); + const int64 size = ShapeUtil::ElementsIn(shape); + ThreeFry2x32State inputs = GetInputs(size, builder); + ThreeFry2x32State outputs = ThreeFry2x32(inputs, key); + // low 32 bit: outputs[0], high 32 bit: outputs[1] + auto result = ConvertElementType(outputs[0], U64) | + ShiftLeft(ConvertElementType(outputs[1], U64), + ConstantR0WithType(builder, U64, 32)); + return Reshape(result, AsInt64Slice(shape.dimensions())); +} + +XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) 
{ + XlaBuilder* builder = bits.builder(); // Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit // forces the random bits into the mantissa. constexpr int kFloatBits = 32; constexpr int kMantissaBits = 23; bits = ShiftRightLogical( - bits, ConstantR0(builder, kFloatBits - kMantissaBits)) | - ConstantR0(builder, tensorflow::bit_cast(1.0f)); + bits, ConstantR0(builder, kFloatBits - kMantissaBits)) | + ConstantR0(builder, absl::bit_cast(1.0f)); auto floats = BitcastConvertType(bits, F32); // We have a floating point number in the range [1.0, 2.0). @@ -147,4 +158,47 @@ XlaOp StatelessRngUniform(std::array seeds, const Shape& shape, return floats * (maxval - minval) + minval; } +XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, + PrimitiveType type, PrimitiveType unsigned_type) { + XlaBuilder* builder = bits.builder(); + // TODO(b/72573764): Generate real uniform integer distribution. + // The following algorithm is the same one that TF uses right now, but it's + // uniform only when maxval - minval is a divisor of the range that bits is + // generated from. 
+ auto range = BitcastConvertType(maxval, unsigned_type) - + BitcastConvertType(minval, unsigned_type); + auto dist = Rem(bits, range); + auto dist_div_2 = + ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1)); + + return minval + BitcastConvertType(dist_div_2, type) + + BitcastConvertType(dist - dist_div_2, type); +} + +} // namespace + +XlaOp StatelessRngUniform(std::array seeds, const Shape& shape, + XlaOp minval, XlaOp maxval) { + XlaBuilder* builder = seeds[0].builder(); + PrimitiveType type = shape.element_type(); + switch (type) { + case F32: { + auto bits = StatelessRngUniformU32(seeds, shape); + return StatelessRngUniformF32(bits, minval, maxval); + } + case S32: { + auto bits = StatelessRngUniformU32(seeds, shape); + return StatelessRngUniformInt(bits, minval, maxval, type, U32); + } + case S64: { + auto bits = StatelessRngUniformU64(seeds, shape); + return StatelessRngUniformInt(bits, minval, maxval, type, U64); + } + default: + return builder->ReportError(Unimplemented( + "Types other than F32, S32 and S64 are not implemented by " + "StatelessRngUniform.")); + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h index ad000b1fa1d0655c8fccc0bb33379f2499b77f26..2603818de26888566a533334e49b039b126db66e 100644 --- a/tensorflow/compiler/xla/client/lib/prng.h +++ b/tensorflow/compiler/xla/client/lib/prng.h @@ -25,7 +25,7 @@ namespace xla { // Returns a tensor containing 'shape' random values uniformly distributed in // the range [minval, maxval). Requires 2 32-bit integer seeds. -// Currently only 'shape's of type F32 are implemented. +// Currently only 'shape's of type F32, S32 and S64 are implemented. 
XlaOp StatelessRngUniform(std::array seeds, const Shape& shape, XlaOp minval, XlaOp maxval); diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc new file mode 100644 index 0000000000000000000000000000000000000000..f8c7df3ff5189c817202eaf39adb572f7e232ec2 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/slicing.h" + +namespace xla { + +XlaOp SliceInMinorDims(XlaOp x, absl::Span start, + absl::Span end) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_RET_CHECK(start.size() == end.size()); + int64 n_minor_dims = start.size(); + + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + + const int64 n_dims = ShapeUtil::Rank(shape); + TF_RET_CHECK(n_minor_dims <= n_dims); + auto major_dims = AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); + + // Prepends 0s in the major dim + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + major_dims.size()); + + // Prepends the shape of the major dims. 
+ std::vector padded_end(n_dims); + std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); + std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); + + std::vector strides(n_dims, 1); + return Slice(x, padded_start, padded_end, strides); + }); +} + +XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. + std::vector start_as_int32(start.begin(), start.end()); + auto start_constant = ConstantR1(builder, start_as_int32); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + TF_ASSIGN_OR_RETURN(Shape start_constant_shape, + builder->GetShape(start_constant)); + const int64 start_length = + ShapeUtil::GetDimension(start_constant_shape, -1); + TF_RET_CHECK(start_length == n_dims); + return DynamicUpdateSlice(x, update, start_constant); + }); +} + +XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span start) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + const int64 n_minor_dims = start.size(); + TF_RET_CHECK(n_minor_dims <= n_dims); + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + (n_dims - n_minor_dims)); + return UpdateSlice(x, update, padded_start); + }); +} + +namespace { + +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { + std::vector output(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), output.begin()); + std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); + return output; +} + +XlaOp PrependZerosInMajorDims(XlaOp x, absl::Span starts) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + 
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + auto zero = Reshape(ConstantR0(builder, 0), {1}); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = Reshape(starts[i], {1}); + } + return ConcatInDim(builder, padded_starts, 0); + }); +} + +} // namespace + +XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, + absl::Span sizes) { + XlaBuilder* builder = x.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); + const int64 n_dims = ShapeUtil::Rank(shape); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + auto major_dims = AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + auto padded_starts = PrependZerosInMajorDims(x, starts); + auto padded_sizes = ConcatVectors(major_dims, sizes); + return DynamicSlice(x, padded_starts, padded_sizes); + }); +} + +XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span starts) { + auto padded_starts = PrependZerosInMajorDims(x, starts); + return DynamicUpdateSlice(x, update, padded_starts); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h new file mode 100644 index 0000000000000000000000000000000000000000..6c482a38b5489c9fb17c3dca9ee3d2a1b8fd1890 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/types.h" + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ + +namespace xla { + +// Updates a slice of 'x', i.e., +// x[start[0], ..., start[n]] = update +XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start); + +// Performs a slice in the minor dimensions of a tensor. +// x[..., start[0]:end[0], ..., start[n]:end[n]] +XlaOp SliceInMinorDims(XlaOp x, absl::Span start, + absl::Span end); + +// Updates a slice of 'x', where 'start' contains a list of minor dimensions: +// x[..., start[0]:..., ..., start[n]:...] = update +XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span start); + +// Performs a dynamic slice in the minor dimensions of a tensor. +XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, + absl::Span sizes); + +XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span starts); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SLICING_H_ diff --git a/tensorflow/compiler/xla/client/lib/slicing_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8d362119e01006555db0f82d02626175936e1d05 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -0,0 +1,106 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/slicing.h" + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace { + +using SlicingTest = xla::ClientLibraryTestBase; + +xla::Array2D BValsRight() { + return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; +} + +xla::Array2D BValsLeft() { + return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; +} + +xla::Array2D AValsFull() { + return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +} + +xla::Array3D BatchedAValsFull() { + return {{ + {2, 0, 1, 2}, + {3, 6, 0, 1}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }}; +} + +XLA_TEST_F(SlicingTest, Simple2dLookup) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp a, x, y; + auto a_data = CreateR2Parameter(BValsRight(), 0, "a", &builder, &a); + auto x_data = CreateR0Parameter(2, 1, "x", &builder, &x); + auto y_data = CreateR0Parameter(1, 2, "y", &builder, &y); + DynamicSliceInMinorDims(a, {x, y}, {1, 1}); + + ComputeAndCompareR2(&builder, {{10}}, + {a_data.get(), x_data.get(), 
y_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(SlicingTest, Simple3dLookup) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp a, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto index_data = CreateR0Parameter(1, 1, "index", &builder, &index); + + DynamicSliceInMinorDims(a, {index, xla::ConstantR0(&builder, 0)}, + {1, 4}); + + ComputeAndCompareR3(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}}, + {a_data.get(), index_data.get()}); +} + +XLA_TEST_F(SlicingTest, SimpleSliceUpdate) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp a, b, x, y; + auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter({{9, 1, -10}}, 1, "b", &builder, &b); + auto x_data = CreateR0Parameter(2, 2, "x", &builder, &x); + auto y_data = CreateR0Parameter(1, 3, "y", &builder, &y); + + DynamicUpdateSliceInMinorDims(a, b, {x, y}); + + xla::Array2D expected( + {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}}); + + ComputeAndCompareR2( + &builder, expected, + {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc index a904be259a3870a679b2c4699ec01e2a11b1ce46..e8553a08bb014e790822a14e128686b60b8d6b7c 100644 --- a/tensorflow/compiler/xla/client/lib/sorting.cc +++ b/tensorflow/compiler/xla/client/lib/sorting.cc @@ -14,7 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/sorting.h" -#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { @@ -23,13 +25,12 @@ XlaOp TopK(XlaOp input, int64 k) { return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); int last_dim = input_shape.dimensions_size() - 1; - int last_dim_size = input_shape.dimensions(last_dim); - XlaOp iota_s32 = Iota(builder, S32, last_dim_size); + Shape iota_shape = + ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions())); + XlaOp iota_s32 = Iota(builder, iota_shape, last_dim); auto input_dims = input_shape.dimensions(); - std::vector broadcast_dims(input_dims.begin(), input_dims.end() - 1); - XlaOp broadcast_s32 = Broadcast(iota_s32, broadcast_dims); - XlaOp sort_result = Sort(Neg(input), broadcast_s32); + XlaOp sort_result = Sort(Neg(input), {iota_s32}); std::vector start_indices(input_shape.dimensions_size(), 0); std::vector limit_indices(input_dims.begin(), input_dims.end()); limit_indices[last_dim] = k; diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc index fef98c9923096e21a755c6d730de2c7c10852b2d..27ff36c7491ab8397d46f3a49493ff2b904deb2d 100644 --- a/tensorflow/compiler/xla/client/lib/sorting_test.cc +++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc @@ -14,6 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/sorting.h" + +#include + #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -41,6 +44,28 @@ XLA_TEST_F(SortingTest, TopK3From8Indices) { ComputeAndCompareR1(&builder, {0, 1, 2}, {}); } +// TODO(b/119930279): enable this test. +XLA_TEST_F(SortingTest, DISABLED_TopKFullSortMinInt) { + XlaBuilder builder(TestName()); + auto x_rev = ConstantR1(&builder, {std::numeric_limits::min(), + std::numeric_limits::min() + 1, + std::numeric_limits::max()}); + xla::GetTupleElement(xla::TopK(x_rev, 3), 1); + ComputeAndCompareR1(&builder, {2, 1, 0}, {}); +} + +XLA_TEST_F(SortingTest, NOT_TopKFullSortMinInt) { + XlaBuilder builder(TestName()); + auto x_rev = ConstantR1(&builder, {std::numeric_limits::min(), + std::numeric_limits::min() + 1, + std::numeric_limits::max()}); + xla::GetTupleElement(xla::TopK(x_rev, 3), 1); + // TopK currently negates the keys, which doesn't work correctly for + // std::numeric_limits::min(). Therefore, it will sort this key to the + // front instead of to the back. 
+ ComputeAndCompareR1(&builder, {0, 2, 1}, {}); +} + XLA_TEST_F(SortingTest, TopKFullSort) { XlaBuilder builder(TestName()); const int kSize = 16; @@ -56,5 +81,13 @@ XLA_TEST_F(SortingTest, TopKFullSort) { ComputeAndCompareR1(&builder, inputs, {}); } +XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates) { + XlaBuilder builder(TestName()); + XlaOp a; + auto a_data = CreateR1Parameter({1, 1, 2, 2, 1}, 0, "a", &builder, &a); + xla::GetTupleElement(xla::TopK(a, 5), 1); + ComputeAndCompareR1(&builder, {2, 3, 0, 1, 4}, {a_data.get()}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index ff0ec76a7f9b62fce0f14beae688cb0dd74847a1..a95bbf2c8c860914877d3195b97342097dafc725 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -66,7 +66,7 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, XlaComputation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); - *execution_options.mutable_shape_with_output_layout() = shape; + *execution_options.mutable_shape_with_output_layout() = shape.ToProto(); return client->Execute(computation, /*arguments=*/{}, &execution_options) .ConsumeValueOrDie(); } @@ -93,13 +93,13 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, std::vector> MakeFakeArgumentsOrDie( const XlaComputation& computation, Client* client) { - CHECK(computation.proto().has_program_shape()) + CHECK(computation.proto().has_host_program_shape()) << "Computation should have progran shape."; - auto program_shape = computation.proto().program_shape(); + auto program_shape = computation.proto().host_program_shape(); std::vector> results; - for (const Shape& shape : program_shape.parameters()) { - results.push_back(MakeFakeDataOrDie(shape, client)); + for (const ShapeProto& shape : program_shape.parameters()) { + 
results.push_back(MakeFakeDataOrDie(Shape(shape), client)); } return results; } diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve.cc b/tensorflow/compiler/xla/client/lib/triangular_solve.cc new file mode 100644 index 0000000000000000000000000000000000000000..c5a1d34cc66e6f8c1a832f8a8437163b846a5431 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.cc @@ -0,0 +1,412 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" + +#include +#include + +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/math/math_util.h" + +namespace xla { + +// Get the diagonal blocks of the coefficient matrix +XlaOp DiagonalBlocks(XlaOp a, int64 block_size) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a)); + int ndims = ShapeUtil::Rank(shape); + int64 n = ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = n / block_size; + + XlaOp diag_blocks; + + // If the coefficient matrix is exactly the block size, we just add a + // singleton dimension i.e. 
[..., n, n] -> [..., 1, n, n] + if (n == block_size) { + std::vector permutation(ndims); + std::iota(permutation.begin(), permutation.end(), 1); + permutation.insert(permutation.end() - 2, 0); + return Transpose(Broadcast(a, /*broadcast_sizes=*/{1}), permutation); + } + + // We can grab entire blocks using gather + if (n > block_size) { + // Construct the starting indices of the diagonal blocks + auto start_indices = + Transpose(Broadcast(Mul(Iota(builder, S32, num_blocks), + ConstantR0(builder, block_size)), + /*broadcast_sizes=*/{2}), + /*permutation=*/{1, 0}); + + // Gather the diagonal blocks + GatherDimensionNumbers dim_numbers; + dim_numbers.add_offset_dims(ndims - 1); + dim_numbers.add_offset_dims(ndims); + dim_numbers.add_start_index_map(ndims - 2); + dim_numbers.add_start_index_map(ndims - 1); + dim_numbers.set_index_vector_dim(1); + diag_blocks = Gather(a, start_indices, dim_numbers, + /*slice_sizes=*/{block_size, block_size}); + } + + // The last block might be smaller than the block size, + // so we will need to pad it + if (n % block_size != 0) { + // Pad with zeros + auto last_blocks = + SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n}); + PaddingConfig config = MakeNoPaddingConfig(ndims); + int64 padding = block_size - n % block_size; + config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding); + config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding); + last_blocks = + Pad(last_blocks, Zero(builder, shape.element_type()), config); + + // Add a singleton dimension + // i.e. 
[..., block_size, block_size] -> [..., 1, block_size, block_size] + TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks)); + auto shape_dims = AsInt64Slice(blocks_shape.dimensions()); + auto last_blocks_dims = std::vector(ndims); + std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin()); + last_blocks_dims.insert(last_blocks_dims.end() - 2, 1); + last_blocks = Reshape(last_blocks, last_blocks_dims); + + // Concatenate with the other blocks if necessary + if (n > block_size) { + diag_blocks = + ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2); + } else { + diag_blocks = last_blocks; + } + } + + return diag_blocks; + }); +} + +XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a, + bool conjugate_a, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = diag_blocks.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + // Input is a batch of square lower triangular square matrices. Its shape is + // (..., size, size). We resize this to (num_blocks, size, size). 
+ TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks)); + int64 block_size = ShapeUtil::GetDimension(shape, -1); + int64 num_blocks = ShapeUtil::ElementsIn(shape) / + tensorflow::MathUtil::IPow(block_size, 2); + diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size}); + + // The input must be triangular because we rely on that when doing + // multiplications later on + diag_blocks = Triangle(diag_blocks, /*lower=*/lower); + + // Rescale blocks to be unit triangular, but avoid dividing by + // zero (which can happen if the last block was padded) otherwise it will + // introduce nans which will propagate + auto diags = GetMatrixDiagonal(diag_blocks); + TF_ASSIGN_OR_RETURN(Shape diags_shape, builder->GetShape(diags)); + auto one = ScalarLike(diags, 1); + auto ones = Broadcast(one, AsInt64Slice(diags_shape.dimensions())); + diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); + auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); + + // We can now use the fact that for an upper triangular matrix + // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have + // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks + // have been rescaled to be unit triangular, so L22 = L22' = 1. + + // Initialize the output matrix with -1s on the diagonal. We use -1 instead + // of 1 because we cannot do matrix-vector multiplies with variable shapes + // inside of a loop, or do irregularly shaped in-place updates. Hence, + // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the + // entire row i.e. we calculate + // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I]) + // which means [L21 L22 0] <- [-L21 * L11', L22, 0]. 
+ auto identity = + IdentityMatrix(builder, shape.element_type(), block_size, block_size); + auto neg_identity = -identity; + + // The first or last diagonal element should be set to 1 instead of -1 + // though, since we never update it + auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1}); + auto start_index = (lower) ? 0 : block_size - 1; + auto output_block = DynamicUpdateSlice( + neg_identity, pos_one, + /*start_indices=*/ConstantR1(builder, 2, start_index)); + + // Broadcast diag([1, -1, -1, ...]) to every block + XlaOp output = Broadcast(output_block, + /*broadcast_sizes=*/{num_blocks}); + + // Now we construct a loop that performs matrix-vector multiplications + // inverting the blocks one row at a time + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + ShapeUtil::MakeShape(S32, {}), + // The output has the shape of A, with one row updated each iteration. + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size}), + // The input is a loop invariant. + ShapeUtil::MakeShape(shape.element_type(), + {num_blocks, block_size, block_size})}; + Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes); + + auto init_i = One(builder, S32); + auto init = Tuple(builder, {init_i, output, scaled_diag_blocks}); + + // Construct the loop condition function. + std::unique_ptr condb = + builder->CreateSubBuilder("InvertDiagCond"); + { + auto i = GetTupleElement( + Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0); + Lt(i, ConstantR0(condb.get(), block_size)); + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function. 
+ std::unique_ptr bodyb = + builder->CreateSubBuilder("InvertDiagBody"); + { + auto input_tuple = + Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple"); + + auto i = GetTupleElement(input_tuple, 0); + auto body_out = GetTupleElement(input_tuple, 1); + auto body_input = GetTupleElement(input_tuple, 2); + + auto zero = ConstantR1(bodyb.get(), 1, 0); + auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i; + auto start_indices = + ConcatInDim(bodyb.get(), {zero, Reshape(j, {1}), zero}, 0); + auto input_row = + DynamicSlice(body_input, start_indices, + /*slice_sizes=*/{num_blocks, 1, block_size}); + + // We want -L21 L11^{-1} + DotDimensionNumbers dnums; + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + PrecisionConfig precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); + + body_out = DynamicUpdateSlice(body_out, update, start_indices); + + auto next_i = i + ScalarLike(i, 1); + Tuple(bodyb.get(), {next_i, body_out, body_input}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto invert_while = While(cond, body, init); + auto inv_diag_blocks = GetTupleElement(invert_while, 1); + + // Undo the scaling + inv_diag_blocks = Div(inv_diag_blocks, diags, + /*broadcast_dimensions=*/{0, 1}); + + // Reshape back to original batch major dimensions + return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions())); + }); +} + +XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, + bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> 
StatusOr { + TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks)); + TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); + int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1); + + TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + int64 ndims = ShapeUtil::Rank(a_shape); + int64 n = ShapeUtil::GetDimension(a_shape, -1); + int64 num_blocks = n / block_size + (n % block_size != 0); + int64 m_dim = (left_side) ? -1 : -2; + int64 m = ShapeUtil::GetDimension(b_shape, m_dim); + + // Initialize the solution + auto x = ZerosLike(b); + + // This loop is unrolled for performance reasons, but it could be expressed + // rolled as well since the matrices are of the same size each iteration + for (int i = 0; i < num_blocks; i++) { + // High-level intuition: We have B[i] = L[i] @ X. Since L is upper + // triangular this means B[i] = L[i, :i + 1] @ X[:i + 1]. We can split + // this into two parts: B[i] = L[i, :i] @ X[:i] + L[i, i] @ X[i] which + // can be solved for X[i] as X[i] = inv(L[i, i]) @ B[i] - L[i, :i] @ X[:i] + + // Decide whether we go from first block to last or vice versa + auto j = (left_side ^ lower ^ transpose_a) ? num_blocks - 1 - i : i; + + // Get the size of the inverse blocks (the last one might be smaller) + int64 block = (n % block_size != 0 && j + 1 == num_blocks) + ? 
n % block_size + : block_size; + auto inv_block = + MaybeConjugate(Collapse(SliceInMinorDims(inv_diag_blocks, {j, 0, 0}, + {j + 1, block, block}), + /*dimensions=*/{ndims - 2, ndims - 1}), + conjugate_a); + + // Get the corresponding row of B + int64 k = std::min((j + 1) * block_size, n); + std::vector start = {j * block_size, 0}; + std::vector end = {k, m}; + if (!left_side) { + std::swap(start[0], start[1]); + std::swap(end[0], end[1]); + } + auto b_row = SliceInMinorDims(b, start, end); + + XlaOp remainder; + if (i == 0) { + remainder = b_row; + } else { + // This matrix multiply involves a lot of multiplying with zero (namely, + // X[i * block_size:] = 0), but this is faster than slicing... + end = {k, n}; + if (!left_side) { + std::swap(end[0], end[1]); + } + if (transpose_a) { + std::swap(start[0], start[1]); + std::swap(end[0], end[1]); + } + auto a_row = + MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); + if (left_side) { + remainder = + b_row - BatchDot(MaybeTransposeInMinorDims(a_row, transpose_a), x, + precision); + } else { + remainder = + b_row - BatchDot(x, MaybeTransposeInMinorDims(a_row, transpose_a), + precision); + } + } + + XlaOp x_update; + auto zero = Zero(builder, S32); + auto start_index = ConstantR0WithType(builder, S32, j * block_size); + std::vector update_starts = {start_index, zero}; + if (left_side) { + x_update = BatchDot(MaybeTransposeInMinorDims(inv_block, transpose_a), + remainder, precision); + } else { + x_update = BatchDot(remainder, + MaybeTransposeInMinorDims(inv_block, transpose_a), + precision); + std::swap(update_starts[0], update_starts[1]); + } + x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); + } + + return x; + }); +} + +XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool transpose_a, bool conjugate_a, int64 block_size, + PrecisionConfig::Precision precision) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + 
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b)); + if (ShapeUtil::Rank(a_shape) != ShapeUtil::Rank(b_shape)) { + return InvalidArgument( + "Arguments to TriangularSolve have shapes with different ranks: " + "%s vs. %s", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); + } + const int64 ndims = ShapeUtil::Rank(a_shape); + if (ndims < 2) { + return InvalidArgument( + "Arguments to TriangularSolve was rank %d but must have rank >= 2.", + ndims); + } + // The batch dimensions must be equal. + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + int64 b_size = b_shape.dimensions(i); + if (a_size != b_size) { + return InvalidArgument( + "Batch dimensions of arguments to TriangularSolve must be equal; " + "shapes were %s and %s.", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); + } + batch_dimensions.push_back(a_size); + } + + if (ShapeUtil::GetDimension(a_shape, -1) != + ShapeUtil::GetDimension(a_shape, -2)) { + return InvalidArgument( + "The 'a' argument to TriangularSolve must be a batched square matrix;" + " shape was: %s", + ShapeUtil::HumanString(a_shape)); + } + const int64 m = ShapeUtil::GetDimension(b_shape, -2); + const int64 n = ShapeUtil::GetDimension(b_shape, -1); + if ((left_side ? 
m : n) != ShapeUtil::GetDimension(a_shape, -1)) { + return InvalidArgument( + "Arguments to TriangularSolve have incompatible matrix shapes %s and " + "%s", + ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape)); + } + + if (block_size < 1) { + return InvalidArgument( + "block_size argument to TriangularSolve must be >= 1; got %d", + block_size); + } + + // We find the diagonal blocks of the coefficient matrix + auto diag_blocks = DiagonalBlocks(a, block_size); + + // We invert these blocks in parallel using batched matrix-vector products + auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a, + conjugate_a, precision); + + // We now find the solution using GEMMs + auto x = + SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower, + transpose_a, conjugate_a, precision); + + return x; + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/xla/client/lib/triangular_solve.h similarity index 88% rename from tensorflow/compiler/tf2xla/lib/triangular_solve.h rename to tensorflow/compiler/xla/client/lib/triangular_solve.h index 2303234f361e54cd2a0ad495cb03b371bed76877..50a3b30ebd1c15eb6d2ace4e351cb41f21db7093 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/xla/client/lib/triangular_solve.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ -#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -namespace tensorflow { +namespace xla { // Solves systems of linear equations with lower or upper triangular coefficient // matrices by forward- or back-substitution. Broadcasting along leading @@ -57,11 +57,11 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -xla::XlaOp TriangularSolve( - xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, +XlaOp TriangularSolve( + XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, int64 block_size = 128, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); + PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); -} // namespace tensorflow +} // namespace xla -#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TRIANGULAR_SOLVE_H_ diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc similarity index 99% rename from tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc rename to tensorflow/compiler/xla/client/lib/triangular_solve_test.cc index aeebf16028d40189203cdfd815f06a339ee72902..f6a70d64a788d95a456774ccbbcf67f2e5cac98b 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/xla/client/lib/triangular_solve.h" #include #include @@ -30,7 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" -namespace tensorflow { +namespace xla { namespace { using TriangularSolveTest = xla::ClientLibraryTestBase; @@ -330,4 +330,4 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { } } // namespace -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index f96b6c9c261a9686fb647e3da0dcc933cd1f70df..049cd15738a619294b19d5cf74ca514d7b4a00ad 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -71,9 +71,9 @@ Status LocalExecutable::ValidateExecutionOptions( "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString( + ShapeUtil::HumanStringWithLayout( computation_layout.parameter_layout(i).shape()), - ShapeUtil::HumanString(arguments[i]->on_host_shape())); + ShapeUtil::HumanStringWithLayout(arguments[i]->on_host_shape())); } } @@ -310,4 +310,28 @@ StatusOr LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) { return local_service_->ReplicaNumberToDeviceOrdinal(replica_number); } +StatusOr LocalClient::TransferToLocalServer( + const ::xla::BorrowingLiteral& literal, int device_oridinal) { + const ::xla::Shape& shape = literal.shape(); + + TF_ASSIGN_OR_RETURN( + ::xla::ScopedShapedBuffer shaped_buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + shape, backend().memory_allocator(), device_oridinal)); + TF_ASSIGN_OR_RETURN(auto stream, + mutable_backend()->BorrowStream(device_oridinal)); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + stream.get(), literal, 
shaped_buffer)); + std::vector<::xla::ScopedShapedBuffer> replicated_buffer; + replicated_buffer.emplace_back(std::move(shaped_buffer)); + ::xla::TransferToServerResponse result; + TF_ASSIGN_OR_RETURN(*result.mutable_data(), + local_service_->RegisterReplicatedBuffers( + std::move(replicated_buffer), + absl::StrCat("TransferToServer literal of shape ", + ::xla::ShapeUtil::HumanString(shape)))); + + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index feb2f8ec9dab5bf13afdc866d10ccbe74f8edcb9..ddb36680e8b185b053368baffa6f1d5cac50dc07 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -60,8 +60,8 @@ class LocalExecutable { // Validates that the given arguments and options satisfy various constraints // of the computation. // - // The given ExecutableRunOptions override any values from legacy_flags - // (TF_XLA_FLAGS environment variable). + // The given ExecutableRunOptions override any values from TF_XLA_FLAGS + // environment variable. Status ValidateExecutionOptions( const absl::Span arguments, const ExecutableRunOptions& run_options, const Backend& backend); @@ -69,8 +69,8 @@ class LocalExecutable { // Records the computation in a SessionModule proto with the arguments used to // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. // - // The given ServiceExecutableRunOptions override any values from legacy_flags - // (TF_XLA_FLAGS environment variable). + // The given ServiceExecutableRunOptions override any values from TF_XLA_FLAGS + // environment variable. StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const absl::Span arguments); @@ -114,8 +114,8 @@ class LocalClient : public Client { // Build and return a LocalExecutable object. The executable is compiled using // the given XlaComputation, argument layouts and options. 
// - // The given ExecutableBuildOptions override any values from legacy_flags - // (TF_XLA_FLAGS environment variable). + // The given ExecutableBuildOptions override any values from TF_XLA_FLAGS + // environment variable. StatusOr> Compile( const XlaComputation& computation, const absl::Span argument_layouts, @@ -129,6 +129,10 @@ class LocalClient : public Client { const Literal& literal, int device_ordinal, DeviceMemoryAllocator* allocator = nullptr); + // Transfer the BorrowingLiteral to the device with the given ordinal. + StatusOr TransferToLocalServer( + const ::xla::BorrowingLiteral& literal, int device_oridinal); + // Copy the data from the device contained in the given ShapedBuffer and // return as a Literal. StatusOr ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); diff --git a/tensorflow/compiler/xla/client/sharding_builder.cc b/tensorflow/compiler/xla/client/sharding_builder.cc index 176802b33ef824a1f898255a19e44def3c1fc982..fb9ea6ec3fc41d5e04ca125798a8199350470a44 100644 --- a/tensorflow/compiler/xla/client/sharding_builder.cc +++ b/tensorflow/compiler/xla/client/sharding_builder.cc @@ -36,7 +36,7 @@ OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment) { OpSharding result; result.set_type(OpSharding::Type::OpSharding_Type_OTHER); - *result.mutable_tile_shape() = tile_shape; + *result.mutable_tile_shape() = tile_shape.ToProto(); for (int64 dim : tile_assignment.dimensions()) { result.add_tile_assignment_dimensions(dim); } @@ -52,7 +52,7 @@ OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) { CHECK_EQ(ShapeUtil::Rank(tile_shape), 1); std::vector dimensions(1, num_tiles); - *result.mutable_tile_shape() = tile_shape; + *result.mutable_tile_shape() = tile_shape.ToProto(); auto& tile_dimension = (*result.mutable_tile_shape()->mutable_dimensions())[0]; tile_dimension = CeilOfRatio(static_cast(tile_dimension), num_tiles); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc 
b/tensorflow/compiler/xla/client/xla_builder.cc index d196252db16fe84d44824856a2202c1a5d3fce95..60df2ec3959216b0564846ad47c21c5bcc01ea57 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include "absl/algorithm/container.h" -#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -34,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/platform/mutex.h" namespace xla { @@ -42,12 +40,30 @@ using absl::StrCat; namespace { -int64 GetUniqueId() { - static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); - static int64 built_counter = 0; - tensorflow::mutex_lock loc(mu); - const int64 id = built_counter++; - return id; +static const char kNameSeparator = '.'; + +// Retrieves the base name of an instruction or computation fully qualified +// name, using separator as boundary between the initial base name part, and +// the numeric identification. +string GetBaseName(const string& name, char separator) { + auto pos = name.rfind(separator); + CHECK_NE(pos, string::npos) << name; + return name.substr(0, pos); +} + +// Generates a fully qualified computation/instruction name. +string GetFullName(const string& base_name, char separator, int64 id) { + const char separator_str[] = {separator, '\0'}; + return StrCat(base_name, separator_str, id); +} + +// Common function to standardize setting name and IDs on computation and +// instruction proto entities. 
+template +void SetProtoIdAndName(T* entry, const string& base_name, char separator, + int64 id) { + entry->set_id(id); + entry->set_name(GetFullName(base_name, separator, id)); } } // namespace @@ -86,7 +102,7 @@ StatusOr XlaBuilder::GetShape(const XlaOp& op) const { TF_RETURN_IF_ERROR(first_error_); TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op)); - return instr->shape(); + return Shape(instr->shape()); } StatusOr> XlaBuilder::GetOperandShapes( @@ -139,7 +155,7 @@ StatusOr XlaBuilder::GetProgramShape(int64 root_id) const { ProgramShape program_shape; - *program_shape.mutable_result() = root_proto->shape(); + *program_shape.mutable_result() = Shape(root_proto->shape()); // Check that the parameter numbers are continuous from 0, and add parameter // shapes and names to the program shape. @@ -156,7 +172,7 @@ StatusOr XlaBuilder::GetProgramShape(int64 root_id) const { const int64 index = instr.parameter_number(); TF_RET_CHECK(index >= 0 && index < param_count) << "invalid parameter number: " << index; - *program_shape.mutable_parameters(index) = instr.shape(); + *program_shape.mutable_parameters(index) = Shape(instr.shape()); *program_shape.mutable_parameter_names(index) = instr.name(); } } @@ -223,6 +239,19 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, visited->insert(op_handle); } +Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num, + ShapeIndex dynamic_size_param_index, + int64 target_param_num, + ShapeIndex target_param_index, + int64 target_dim_num) { + TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind( + DynamicParameterBinding::DynamicParameter{dynamic_size_param_num, + dynamic_size_param_index}, + DynamicParameterBinding::DynamicDimension{ + target_param_num, target_param_index, target_dim_num})); + return Status::OK(); +} + XlaComputation XlaBuilder::BuildAndNoteError() { DCHECK(parent_builder_ != nullptr); auto build_status = Build(); @@ -258,17 +287,15 @@ StatusOr XlaBuilder::Build(int64 root_id) { } 
HloComputationProto entry; - entry.set_id(GetUniqueId()); // Give the computation a global unique id. - entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique. - - TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id)); + SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId()); + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id)); + *entry.mutable_program_shape() = program_shape.ToProto(); entry.set_root_id(root_id); for (auto& instruction : instructions_) { // Ensures that the instruction names are unique among the whole graph. - const string& new_name = - StrCat(instruction.name(), ".", entry.id(), ".", instruction.id()); - instruction.set_name(new_name); + instruction.set_name( + GetFullName(instruction.name(), kNameSeparator, instruction.id())); entry.add_instructions()->Swap(&instruction); } @@ -278,12 +305,15 @@ StatusOr XlaBuilder::Build(int64 root_id) { module->set_id(entry.id()); module->set_entry_computation_name(entry.name()); module->set_entry_computation_id(entry.id()); - *module->mutable_program_shape() = entry.program_shape(); + *module->mutable_host_program_shape() = entry.program_shape(); for (auto& e : embedded_) { module->add_computations()->Swap(&e.second); } module->add_computations()->Swap(&entry); + *(module->mutable_dynamic_parameter_binding()) = + dynamic_parameter_binding_.ToProto(); + // Clear data held by this builder. 
this->instructions_.clear(); this->handle_to_index_.clear(); @@ -299,7 +329,7 @@ StatusOr XlaBuilder::InDimBroadcast( TF_RETURN_IF_ERROR(first_error_); HloInstructionProto instr; - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); for (int64 dim : broadcast_dimensions) { instr.add_dimensions(dim); } @@ -350,8 +380,9 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferUnaryOpShape(unop, operand_shape)); + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), unop, {operand}); }); } @@ -362,9 +393,10 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferBinaryOpShape( binop, lhs_shape, rhs_shape, broadcast_dimensions)); + *instr.mutable_shape() = shape.ToProto(); const int64 lhs_rank = ShapeUtil::Rank(lhs_shape); const int64 rhs_rank = ShapeUtil::Rank(rhs_shape); @@ -378,7 +410,7 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, const Shape& from_shape = should_broadcast_lhs ? 
lhs_shape : rhs_shape; std::vector to_size; - for (int64 size : instr.shape().dimensions()) { + for (int64 size : shape.dimensions()) { to_size.push_back(size); } for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape); @@ -398,14 +430,14 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, } TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, GetShape(updated_lhs)); - if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) { + if (!ShapeUtil::SameDimensions(shape, updated_lhs_shape)) { TF_ASSIGN_OR_RETURN(updated_lhs, - AddBroadcastSequence(instr.shape(), updated_lhs)); + AddBroadcastSequence(shape, updated_lhs)); } TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, GetShape(updated_rhs)); - if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) { + if (!ShapeUtil::SameDimensions(shape, updated_rhs_shape)) { TF_ASSIGN_OR_RETURN(updated_rhs, - AddBroadcastSequence(instr.shape(), updated_rhs)); + AddBroadcastSequence(shape, updated_rhs)); } return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs}); @@ -419,30 +451,28 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), - ShapeInference::InferTernaryOpShape( - triop, lhs_shape, rhs_shape, ehs_shape)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferTernaryOpShape(triop, lhs_shape, + rhs_shape, ehs_shape)); + *instr.mutable_shape() = shape.ToProto(); XlaOp updated_lhs = lhs; XlaOp updated_rhs = rhs; XlaOp updated_ehs = ehs; - if (!ShapeUtil::IsTuple(instr.shape())) { + if (!ShapeUtil::IsTuple(shape)) { if (!ShapeUtil::IsTuple(lhs_shape) && - !ShapeUtil::SameDimensions(instr.shape(), lhs_shape)) { + !ShapeUtil::SameDimensions(shape, lhs_shape)) { // lhs is being implicitly broadcasted. 
Change to explicit. - TF_ASSIGN_OR_RETURN(updated_lhs, - AddBroadcastSequence(instr.shape(), lhs)); + TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(shape, lhs)); } if (!ShapeUtil::IsTuple(rhs_shape) && - !ShapeUtil::SameDimensions(instr.shape(), rhs_shape)) { + !ShapeUtil::SameDimensions(shape, rhs_shape)) { // rhs is being implicitly broadcasted. Change to explicit. - TF_ASSIGN_OR_RETURN(updated_rhs, - AddBroadcastSequence(instr.shape(), rhs)); + TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(shape, rhs)); } if (!ShapeUtil::IsTuple(ehs_shape) && - !ShapeUtil::SameDimensions(instr.shape(), ehs_shape)) { + !ShapeUtil::SameDimensions(shape, ehs_shape)) { // ehs is being implicitly broadcasted. Change to explicit. - TF_ASSIGN_OR_RETURN(updated_ehs, - AddBroadcastSequence(instr.shape(), ehs)); + TF_ASSIGN_OR_RETURN(updated_ehs, AddBroadcastSequence(shape, ehs)); } } return AddInstruction(std::move(instr), triop, @@ -463,7 +493,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = literal.shape(); + *instr.mutable_shape() = literal.shape().ToProto(); *instr.mutable_literal() = literal.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kConstant); }); @@ -472,7 +502,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); instr.add_dimensions(iota_dimension); return AddInstruction(std::move(instr), HloOpcode::kIota); }); @@ -492,10 +522,10 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); - 
TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferCallShape(operand_shape_ptrs, - /*to_apply=*/called_program_shape)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape( + operand_shape_ptrs, + /*to_apply=*/called_program_shape)); + *instr.mutable_shape() = shape.ToProto(); AddCalledComputation(computation, &instr); @@ -513,7 +543,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, } instr.set_parameter_number(parameter_number); instr.set_name(name); - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kParameter); }); } @@ -543,10 +573,35 @@ XlaOp XlaBuilder::Broadcast(const XlaOp& operand, } XlaOp XlaBuilder::BroadcastInDim( - const XlaOp& operand, const Shape& shape, + const XlaOp& operand, const absl::Span out_dim_size, const absl::Span broadcast_dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { - return InDimBroadcast(shape, operand, broadcast_dimensions); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + // Output shape, in the case of degenerate broadcast, the out_dim_size is + // not necessarily the same as the dimension sizes of the output shape. + const auto& output_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), out_dim_size); + + TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape( + operand_shape, output_shape, broadcast_dimensions) + .status()); + std::vector in_dim_size(out_dim_size.begin(), out_dim_size.end()); + for (int i = 0; i < broadcast_dimensions.size(); i++) { + in_dim_size[broadcast_dimensions[i]] = operand_shape.dimensions(i); + } + const auto& in_dim_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), in_dim_size); + TF_ASSIGN_OR_RETURN( + XlaOp in_dim_broadcast, + InDimBroadcast(in_dim_shape, operand, broadcast_dimensions)); + + // If broadcast is not degenerate, return broadcasted result. 
+ if (ShapeUtil::Equal(in_dim_shape, output_shape)) { + return in_dim_broadcast; + } + + // Otherwise handle degenerate broadcast case. + return AddBroadcastSequence(output_shape, in_dim_broadcast); }); } @@ -554,7 +609,7 @@ StatusOr XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { TF_RETURN_IF_ERROR(first_error_); HloInstructionProto instr; - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand}); } @@ -566,9 +621,9 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferSliceShape(operand_shape, start_indices, - limit_indices, strides)); + Shape shape, ShapeInference::InferSliceShape( + operand_shape, start_indices, limit_indices, strides)); + *instr.mutable_shape() = shape.ToProto(); for (int i = 0; i < start_indices.size(); i++) { auto* slice_config = instr.add_slice_dimensions(); slice_config->set_start(start_indices[i]); @@ -603,9 +658,10 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, GetShape(start_indices)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDynamicSliceShape( operand_shape, start_indices_shape, slice_sizes)); + *instr.mutable_shape() = shape.ToProto(); for (int64 size : slice_sizes) { instr.add_dynamic_slice_sizes(size); @@ -625,9 +681,10 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, GetShape(start_indices)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, 
ShapeInference::InferDynamicUpdateSliceShape( operand_shape, update_shape, start_indices_shape)); + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, {operand, update, start_indices}); @@ -643,9 +700,9 @@ XlaOp XlaBuilder::ConcatInDim(absl::Span operands, TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape( + operand_shape_ptrs, dimension)); + *instr.mutable_shape() = shape.ToProto(); instr.add_dimensions(dimension); @@ -662,10 +719,9 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, TF_ASSIGN_OR_RETURN(const Shape& padding_value_shape, GetShape(padding_value)); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferPadShape(operand_shape, padding_value_shape, - padding_config)); - + Shape shape, ShapeInference::InferPadShape( + operand_shape, padding_value_shape, padding_config)); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_padding_config() = padding_config; return AddInstruction(std::move(instr), HloOpcode::kPad, @@ -678,7 +734,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, absl::Span new_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(const Shape& shape, + TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferReshapeShape( operand_shape, dimensions, new_sizes)); XlaOp transposed = IsIdentityPermutation(dimensions) @@ -691,7 +747,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Reshape(const XlaOp& operand, absl::Span new_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { - 
TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(Shape shape, GetShape(operand)); std::vector dimensions(shape.dimensions_size()); std::iota(dimensions.begin(), dimensions.end(), 0); return Reshape(operand, dimensions, new_sizes); @@ -741,7 +797,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeNil(); + *instr.mutable_shape() = ShapeUtil::MakeNil().ToProto(); *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); }); @@ -767,9 +823,10 @@ XlaOp XlaBuilder::Tuple(absl::Span elements) { TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferVariadicOpShape( HloOpcode::kTuple, operand_shape_ptrs)); + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTuple, elements); }); } @@ -784,7 +841,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { ShapeUtil::HumanString(tuple_shape)); } *instr.mutable_shape() = - ShapeUtil::GetTupleElementShape(tuple_shape, index); + ShapeUtil::GetTupleElementShape(tuple_shape, index).ToProto(); instr.set_tuple_index(index); @@ -843,9 +900,10 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); + *instr.mutable_shape() = shape.ToProto(); 
*instr.mutable_dot_dimension_numbers() = dimension_numbers; if (precision_config != nullptr) { *instr.mutable_precision_config() = *precision_config; @@ -987,10 +1045,11 @@ XlaOp XlaBuilder::ConvGeneralDilated( MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvolveShape( lhs_shape, rhs_shape, feature_group_count, instr.window(), dimension_numbers)); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_convolution_dimension_numbers() = dimension_numbers; instr.set_feature_group_count(feature_group_count); @@ -1063,10 +1122,9 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferFftShape(operand_shape, fft_type, fft_length)); - + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape( + operand_shape, fft_type, fft_length)); + *instr.mutable_shape() = shape.ToProto(); instr.set_fft_type(fft_type); for (int64 i : fft_length) { instr.add_fft_length(i); @@ -1084,7 +1142,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { } const Shape infeed_instruction_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); - *instr.mutable_shape() = infeed_instruction_shape; + *instr.mutable_shape() = infeed_instruction_shape.ToProto(); instr.set_infeed_config(config); if (ShapeUtil::IsArray(shape) && sharding() && @@ -1105,7 +1163,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { XlaOp token; auto make_token = [&]() { HloInstructionProto token_instr; - *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); return AddInstruction(std::move(token_instr), 
HloOpcode::kAfterAll, {}); }; if (sharding()) { @@ -1144,7 +1202,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { // TODO(b/80000000): Remove this when clients have been updated to handle // tokens. HloInstructionProto infeed_data; - *infeed_data.mutable_shape() = shape; + *infeed_data.mutable_shape() = shape.ToProto(); infeed_data.set_tuple_index(0); return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement, {infeed}); @@ -1160,7 +1218,7 @@ XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape, } const Shape infeed_instruction_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); - *instr.mutable_shape() = infeed_instruction_shape; + *instr.mutable_shape() = infeed_instruction_shape.ToProto(); instr.set_infeed_config(config); if (ShapeUtil::IsArray(shape) && sharding() && @@ -1185,7 +1243,7 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); // Check and set outfeed shape. if (!LayoutUtil::HasLayout(shape_with_layout)) { @@ -1198,14 +1256,14 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, ShapeUtil::HumanStringWithLayout(shape_with_layout), ShapeUtil::HumanStringWithLayout(operand_shape)); } - *instr.mutable_outfeed_shape() = shape_with_layout; + *instr.mutable_outfeed_shape() = shape_with_layout.ToProto(); instr.set_outfeed_config(outfeed_config); // Outfeed takes a token as its second operand. Generate the token to pass // to the outfeed. 
HloInstructionProto token_instr; - *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {})); @@ -1219,7 +1277,7 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, // TODO(b/80000000): Remove this when clients have been updated to handle // tokens. HloInstructionProto tuple_instr; - *tuple_instr.mutable_shape() = ShapeUtil::MakeNil(); + *tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto(); // The dummy tuple should have no sharding. { @@ -1238,7 +1296,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); // Check and set outfeed shape. if (!LayoutUtil::HasLayout(shape_with_layout)) { @@ -1251,7 +1309,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, ShapeUtil::HumanStringWithLayout(shape_with_layout), ShapeUtil::HumanStringWithLayout(operand_shape)); } - *instr.mutable_outfeed_shape() = shape_with_layout; + *instr.mutable_outfeed_shape() = shape_with_layout.ToProto(); instr.set_outfeed_config(outfeed_config); @@ -1263,7 +1321,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, XlaOp XlaBuilder::CreateToken() { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); return AddInstruction(std::move(instr), HloOpcode::kAfterAll); }); } @@ -1273,15 +1331,25 @@ XlaOp XlaBuilder::AfterAll(absl::Span tokens) { if (tokens.empty()) { return InvalidArgument("AfterAll requires at least one operand"); } + for (int i = 0; i < tokens.size(); ++i) { + 
const XlaOp& operand = tokens[i]; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + if (!ShapeUtil::IsToken(operand_shape)) { + return InvalidArgument( + "All operands to AfterAll must be tokens; operand %d has shape %s", + i, ShapeUtil::HumanString(operand_shape)); + } + } HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens); }); } -XlaOp XlaBuilder::CustomCall(const string& call_target_name, - absl::Span operands, - const Shape& shape, const string& opaque) { +XlaOp XlaBuilder::CustomCall( + const string& call_target_name, absl::Span operands, + const Shape& shape, const string& opaque, + absl::optional> operand_shapes_with_layout) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (absl::StartsWith(call_target_name, "$")) { @@ -1290,9 +1358,34 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, "are reserved for internal use.", call_target_name); } - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); instr.set_custom_call_target(call_target_name); instr.set_custom_call_opaque(opaque); + if (operand_shapes_with_layout.has_value()) { + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument( + "Result shape must have layout for custom call with constrained " + "layout."); + } + if (operands.size() != operand_shapes_with_layout->size()) { + return InvalidArgument( + "Must specify a shape with layout for each operand for custom call " + "with constrained layout; given %d shapes, expected %d", + operand_shapes_with_layout->size(), operands.size()); + } + instr.set_constrain_layout(true); + int64 operand_num = 0; + for (const Shape& operand_shape : *operand_shapes_with_layout) { + if (!LayoutUtil::HasLayout(operand_shape)) { + return InvalidArgument( + "No layout specified for operand %d for custom call with " + 
"constrained layout.", + operand_num); + } + *instr.add_operand_shapes_with_layout() = operand_shape.ToProto(); + ++operand_num; + } + } return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); }); } @@ -1443,9 +1536,9 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferTransposeShape(operand_shape, permutation)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape( + operand_shape, permutation)); + *instr.mutable_shape() = shape.ToProto(); for (int64 dim : permutation) { instr.add_dimensions(dim); } @@ -1458,9 +1551,9 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferReverseShape(operand_shape, dimensions)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape( + operand_shape, dimensions)); + *instr.mutable_shape() = shape.ToProto(); for (int64 dim : dimensions) { instr.add_dimensions(dim); } @@ -1468,30 +1561,28 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } -XlaOp XlaBuilder::Sort(XlaOp keys, absl::optional values, +XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span values, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); operand_shape_ptrs.push_back(&keys_shape); - Shape values_shape; - if (values.has_value()) { - TF_ASSIGN_OR_RETURN(values_shape, GetShape(*values)); - operand_shape_ptrs.push_back(&values_shape); - } - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), - ShapeInference::InferVariadicOpShape( - HloOpcode::kSort, operand_shape_ptrs)); + 
TF_ASSIGN_OR_RETURN(std::vector values_shapes, + GetOperandShapes(values)); + absl::c_transform(values_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, operand_shape_ptrs)); + *instr.mutable_shape() = shape.ToProto(); if (dimension == -1) { TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys)); dimension = ShapeUtil::Rank(keys_shape) - 1; } instr.add_dimensions(dimension); - return values.has_value() - ? AddInstruction(std::move(instr), HloOpcode::kSort, - {keys, *values}) - : AddInstruction(std::move(instr), HloOpcode::kSort, {keys}); + std::vector operands{keys}; + operands.insert(operands.end(), values.begin(), values.end()); + return AddInstruction(std::move(instr), HloOpcode::kSort, operands); }); } @@ -1505,9 +1596,9 @@ XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferConvertShape(operand_shape, new_element_type)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape( + operand_shape, new_element_type)); + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand}); }); } @@ -1517,9 +1608,9 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferConvertShape(operand_shape, new_element_type)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape( + operand_shape, new_element_type)); + *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert, 
{operand}); }); @@ -1551,11 +1642,11 @@ XlaOp XlaBuilder::Map(absl::Span operands, TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferMapShape(operand_shape_ptrs, called_program_shape, - dimensions)); + Shape shape, ShapeInference::InferMapShape( + operand_shape_ptrs, called_program_shape, dimensions)); + *instr.mutable_shape() = shape.ToProto(); - const Shape& output_shape = instr.shape(); + Shape output_shape(instr.shape()); const int64 output_rank = ShapeUtil::Rank(output_shape); AddCalledComputation(computation, &instr); std::vector new_operands(operands.begin(), operands.end()); @@ -1598,7 +1689,7 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution, } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); - *instr.mutable_shape() = shape; + *instr.mutable_shape() = shape.ToProto(); instr.set_distribution(distribution); @@ -1626,10 +1717,10 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, TF_ASSIGN_OR_RETURN(const auto& condition_program_shape, condition.GetProgramShape()); TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferWhileShape(condition_program_shape, - body_program_shape, init_shape)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape( + condition_program_shape, + body_program_shape, init_shape)); + *instr.mutable_shape() = shape.ToProto(); // Body comes before condition computation in the vector. 
AddCalledComputation(body, &instr); AddCalledComputation(condition, &instr); @@ -1646,10 +1737,10 @@ XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices, TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, GetShape(start_indices)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferGatherShape(input_shape, start_indices_shape, + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape( + input_shape, start_indices_shape, dimension_numbers, slice_sizes)); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_gather_dimension_numbers() = dimension_numbers; for (int64 bound : slice_sizes) { @@ -1674,10 +1765,11 @@ XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices, TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates)); TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape, update_computation.GetProgramShape()); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferScatterShape( input_shape, scatter_indices_shape, updates_shape, to_apply_shape, dimension_numbers)); + *instr.mutable_shape() = shape.ToProto(); *instr.mutable_scatter_dimension_numbers() = dimension_numbers; @@ -1704,10 +1796,11 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, TF_ASSIGN_OR_RETURN(const ProgramShape& false_computation_shape, false_computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), + Shape shape, ShapeInference::InferConditionalShape( predicate_shape, true_operand_shape, false_operand_shape, true_computation_shape, false_computation_shape)); + *instr.mutable_shape() = shape.ToProto(); // The index of true_computation must be 0 and that of false computation // must be 1. 
@@ -1749,9 +1842,10 @@ XlaOp XlaBuilder::Reduce(absl::Span operands, [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), + Shape shape, ShapeInference::InferReduceShape( operand_shape_ptrs, dimensions_to_reduce, called_program_shape)); + *instr.mutable_shape() = shape.ToProto(); for (int64 dim : dimensions_to_reduce) { instr.add_dimensions(dim); @@ -1789,9 +1883,9 @@ XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value, std::vector> padding_values = MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions, window_strides, padding); - return ReduceWindowWithGeneralPadding(operand, init_value, computation, - window_dimensions, window_strides, - padding_values); + return ReduceWindowWithGeneralPadding( + operand, init_value, computation, window_dimensions, window_strides, + /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values); }); } @@ -1800,6 +1894,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1810,11 +1906,12 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( computation.GetProgramShape()); TF_ASSIGN_OR_RETURN(*instr.mutable_window(), MakeWindow(window_dimensions, window_strides, padding, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferReduceWindowShape(operand_shape, init_shape, - instr.window(), to_apply_shape)); + /*lhs_dilation=*/base_dilations, + /*rhs_dilation=*/window_dilations)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape( + operand_shape, init_shape, + instr.window(), to_apply_shape)); + *instr.mutable_shape() = shape.ToProto(); AddCalledComputation(computation, &instr); return AddInstruction(std::move(instr), 
HloOpcode::kReduceWindow, @@ -1832,9 +1929,10 @@ XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale, TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale)); TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset)); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), + Shape shape, ShapeInference::InferBatchNormTrainingShape( operand_shape, scale_shape, offset_shape, feature_index)); + *instr.mutable_shape() = shape.ToProto(); instr.set_epsilon(epsilon); instr.set_feature_index(feature_index); @@ -1856,10 +1954,11 @@ XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale, TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset)); TF_ASSIGN_OR_RETURN(const Shape& mean_shape, GetShape(mean)); TF_ASSIGN_OR_RETURN(const Shape& variance_shape, GetShape(variance)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), - ShapeInference::InferBatchNormInferenceShape( - operand_shape, scale_shape, offset_shape, - mean_shape, variance_shape, feature_index)); + TF_ASSIGN_OR_RETURN( + Shape shape, ShapeInference::InferBatchNormInferenceShape( + operand_shape, scale_shape, offset_shape, mean_shape, + variance_shape, feature_index)); + *instr.mutable_shape() = shape.ToProto(); instr.set_epsilon(epsilon); instr.set_feature_index(feature_index); @@ -1881,10 +1980,11 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, TF_ASSIGN_OR_RETURN(const Shape& batch_mean_shape, GetShape(batch_mean)); TF_ASSIGN_OR_RETURN(const Shape& batch_var_shape, GetShape(batch_var)); TF_ASSIGN_OR_RETURN(const Shape& grad_output_shape, GetShape(grad_output)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferBatchNormGradShape( operand_shape, scale_shape, batch_mean_shape, batch_var_shape, grad_output_shape, feature_index)); + *instr.mutable_shape() = shape.ToProto(); instr.set_epsilon(epsilon); instr.set_feature_index(feature_index); @@ -1915,9 +2015,9 @@ XlaOp 
XlaBuilder::CrossReplicaSum( return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferCrossReplicaSumShape({&operand_shape})); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( + {&operand_shape})); + *instr.mutable_shape() = shape.ToProto(); for (const ReplicaGroup& group : replica_groups) { *instr.add_replica_groups() = group; @@ -1970,8 +2070,8 @@ XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs), [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); + Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); + *instr.mutable_shape() = shape.ToProto(); for (const ReplicaGroup& group : replica_groups) { *instr.add_replica_groups() = group; } @@ -1996,8 +2096,9 @@ XlaOp XlaBuilder::CollectivePermute( TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); HloInstructionProto instr; TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), + Shape shape, ShapeInference::InferCollectivePermuteShape(operand_shape)); + *instr.mutable_shape() = shape.ToProto(); for (const auto& pair : source_target_pairs) { auto* proto_pair = instr.add_source_target_pairs(); @@ -2046,10 +2147,11 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( TF_ASSIGN_OR_RETURN(*instr.mutable_window(), MakeWindow(window_dimensions, window_strides, padding, /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSelectAndScatterShape( operand_shape, select_shape, instr.window(), source_shape, init_shape, scatter_shape)); + *instr.mutable_shape() = shape.ToProto(); AddCalledComputation(select, &instr); 
AddCalledComputation(scatter, &instr); @@ -2064,9 +2166,10 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReducePrecisionShape( operand_shape, exponent_bits, mantissa_bits)); + *instr.mutable_shape() = shape.ToProto(); instr.set_exponent_bits(exponent_bits); instr.set_mantissa_bits(mantissa_bits); return AddInstruction(std::move(instr), HloOpcode::kReducePrecision, @@ -2081,7 +2184,7 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { // TODO(b/80000000): Remove this when clients have been updated to handle // tokens. HloInstructionProto token_instr; - *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {})); @@ -2100,15 +2203,17 @@ XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token, // token}. 
HloInstructionProto send_instr; TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); - *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( - {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + *send_instr.mutable_shape() = + ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}) + .ToProto(); send_instr.set_channel_id(handle.handle()); TF_ASSIGN_OR_RETURN(XlaOp send, AddInstruction(std::move(send_instr), HloOpcode::kSend, {operand, token})); HloInstructionProto send_done_instr; - *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); send_done_instr.set_channel_id(handle.handle()); return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, {send}); @@ -2122,7 +2227,7 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { // TODO(b/80000000): Remove this when clients have been updated to handle // tokens. HloInstructionProto token_instr; - *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {})); @@ -2133,7 +2238,7 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { // TODO(b/80000000): Remove this when clients have been updated to handle // tokens. HloInstructionProto recv_data; - *recv_data.mutable_shape() = shape; + *recv_data.mutable_shape() = shape.ToProto(); recv_data.set_tuple_index(0); return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement, {recv}); @@ -2150,15 +2255,18 @@ XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape, // Recv instruction produces a tuple of {receive buffer, U32 context, // token}. 
HloInstructionProto recv_instr; - *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( - {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + *recv_instr.mutable_shape() = + ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}) + .ToProto(); recv_instr.set_channel_id(handle.handle()); TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), HloOpcode::kRecv, {token})); HloInstructionProto recv_done_instr; *recv_done_instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}) + .ToProto(); recv_done_instr.set_channel_id(handle.handle()); return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, {recv}); @@ -2192,9 +2300,11 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, // Send instruction produces a tuple of {aliased operand, U32 context, // token}. HloInstructionProto send_instr; - *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape( - {shape_with_layout, ShapeUtil::MakeShape(U32, {}), - ShapeUtil::MakeTokenShape()}); + *send_instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape_with_layout, + ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}) + .ToProto(); send_instr.set_channel_id(handle.handle()); send_instr.set_is_host_transfer(true); TF_ASSIGN_OR_RETURN(XlaOp send, @@ -2202,7 +2312,7 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, {operand, token})); HloInstructionProto send_done_instr; - *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); send_done_instr.set_channel_id(handle.handle()); send_done_instr.set_is_host_transfer(true); return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, @@ -2231,8 +2341,10 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, // Recv instruction 
produces a tuple of {receive buffer, U32 context, // token}. HloInstructionProto recv_instr; - *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape( - {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}); + *recv_instr.mutable_shape() = + ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}) + .ToProto(); recv_instr.set_channel_id(handle.handle()); recv_instr.set_is_host_transfer(true); TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), @@ -2240,7 +2352,8 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, HloInstructionProto recv_done_instr; *recv_done_instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}) + .ToProto(); recv_done_instr.set_channel_id(handle.handle()); recv_done_instr.set_is_host_transfer(true); return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, @@ -2248,6 +2361,19 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, }); } +XlaOp XlaBuilder::GetDimensionSize(const XlaOp& operand, int64 dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const auto& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape( + operand_shape, dimension)); + *instr.mutable_shape() = shape.ToProto(); + instr.add_dimensions(dimension); + return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize, + {operand}); + }); +} + StatusOr XlaBuilder::IsConstant(const XlaOp& operand) const { TF_RETURN_IF_ERROR(first_error_); @@ -2261,7 +2387,7 @@ StatusOr XlaBuilder::IsConstant(const XlaOp& operand) const { } StatusOr XlaBuilder::BuildConstantSubGraph( - const XlaOp& root_op) const { + const XlaOp& root_op) { TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op)); if (!is_constant) { auto op_status = 
LookUpInstruction(root_op); @@ -2283,10 +2409,10 @@ StatusOr XlaBuilder::BuildConstantSubGraph( LookUpInstruction(root_op)); HloComputationProto entry; - entry.set_id(GetUniqueId()); // Give the computation a global unique id. - entry.set_name(StrCat(name_, entry.id(), "_compute_constant")); + SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator, + GetNextId()); entry.set_root_id(root->id()); - ProgramShape* program_shape = entry.mutable_program_shape(); + ProgramShapeProto* program_shape = entry.mutable_program_shape(); *program_shape->mutable_result() = root->shape(); // We use std::set to keep the instruction ids in ascending order (which is @@ -2330,7 +2456,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( module->set_id(entry.id()); module->set_entry_computation_name(entry.name()); module->set_entry_computation_id(entry.id()); - *module->mutable_program_shape() = *program_shape; + *module->mutable_host_program_shape() = *program_shape; for (auto& e : embedded_) { if (related_calls.find(e.second.id()) != related_calls.end()) { *module->add_computations() = e.second; @@ -2424,7 +2550,7 @@ StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, absl::Span operands) { TF_RETURN_IF_ERROR(first_error_); - const int64 handle = GetUniqueId(); + const int64 handle = GetNextId(); instr.set_id(handle); instr.set_opcode(HloOpcodeString(opcode)); if (instr.name().empty()) { @@ -2455,9 +2581,50 @@ StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, void XlaBuilder::AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr) { - instr->add_called_computation_ids(computation.proto().entry_computation_id()); + absl::flat_hash_map remapped_ids; + std::vector imported_computations; + imported_computations.reserve(computation.proto().computations_size()); + // Before we import the computations by remapping IDs, and capturing the + // old->new mappings in remapped_ids. 
for (const HloComputationProto& e : computation.proto().computations()) { - embedded_.insert({e.id(), e}); + HloComputationProto new_computation(e); + int64 computation_id = GetNextId(); + remapped_ids[new_computation.id()] = computation_id; + SetProtoIdAndName(&new_computation, + GetBaseName(new_computation.name(), kNameSeparator), + kNameSeparator, computation_id); + for (auto& instruction : *new_computation.mutable_instructions()) { + int64 instruction_id = GetNextId(); + remapped_ids[instruction.id()] = instruction_id; + SetProtoIdAndName(&instruction, + GetBaseName(instruction.name(), kNameSeparator), + kNameSeparator, instruction_id); + } + new_computation.set_root_id(remapped_ids.at(new_computation.root_id())); + + imported_computations.push_back(std::move(new_computation)); + } + // Once we have imported all the computations, and captured all the ID + // mappings, we go back and fixup the IDs in the imported computations. + instr->add_called_computation_ids( + remapped_ids.at(computation.proto().entry_computation_id())); + for (auto& imported_computation : imported_computations) { + for (auto& instruction : *imported_computation.mutable_instructions()) { + for (auto& operand_id : *instruction.mutable_operand_ids()) { + operand_id = remapped_ids.at(operand_id); + } + for (auto& control_predecessor_id : + *instruction.mutable_control_predecessor_ids()) { + control_predecessor_id = remapped_ids.at(control_predecessor_id); + } + for (auto& called_computation_id : + *instruction.mutable_called_computation_ids()) { + called_computation_id = remapped_ids.at(called_computation_id); + } + } + + int64 computation_id = imported_computation.id(); + embedded_.insert({computation_id, std::move(imported_computation)}); } } @@ -2506,9 +2673,10 @@ XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes) { return operand.builder()->Broadcast(operand, broadcast_sizes); } -XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, +XlaOp BroadcastInDim(const 
XlaOp& operand, + const absl::Span out_dim_size, const absl::Span broadcast_dimensions) { - return operand.builder()->BroadcastInDim(operand, shape, + return operand.builder()->BroadcastInDim(operand, out_dim_size, broadcast_dimensions); } @@ -2687,7 +2855,16 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque) { - return builder->CustomCall(call_target_name, operands, shape, opaque); + return builder->CustomCall(call_target_name, operands, shape, opaque, + /*operand_shapes_with_layout=*/absl::nullopt); +} + +XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape, + absl::Span operand_shapes_with_layout, + const string& opaque) { + return builder->CustomCall(call_target_name, operands, shape, opaque, + operand_shapes_with_layout); } XlaOp Complex(const XlaOp& real, const XlaOp& imag, @@ -2800,10 +2977,12 @@ XlaOp ReduceWindowWithGeneralPadding( const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding) { return operand.builder()->ReduceWindowWithGeneralPadding( operand, init_value, computation, window_dimensions, window_strides, - padding); + base_dilations, window_dilations, padding); } XlaOp CrossReplicaSum(const XlaOp& operand, @@ -2914,8 +3093,8 @@ XlaOp Rev(const XlaOp& operand, absl::Span dimensions) { return operand.builder()->Rev(operand, dimensions); } -XlaOp Sort(XlaOp keys, absl::optional values, int64 dimension) { - return keys.builder()->Sort(keys, std::move(values), dimension); +XlaOp Sort(const XlaOp& keys, absl::Span values, int64 dimension) { + return keys.builder()->Sort(keys, values, dimension); } XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { @@ -3049,4 +3228,8 @@ XlaOp Iota(XlaBuilder* builder, 
const Shape& shape, int64 iota_dimension) { return builder->Iota(shape, iota_dimension); } +XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension) { + return operand.builder()->GetDimensionSize(operand, dimension); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index cd0d5ca5d3043ca13bbfda40eacc04b86659a85c..098efb60f9bdca8306ff771a505f4a225dea9f7d 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -216,7 +217,7 @@ class XlaBuilder { // compile-time constant (see `IsConstant`), returns an error. // // This will copy the needed ops/computations to the subgraph. - StatusOr BuildConstantSubGraph(const XlaOp& root_op) const; + StatusOr BuildConstantSubGraph(const XlaOp& root_op); // Returns the first error that was encountered while building the // computation. When an error is encountered, by default we return a vacuous @@ -263,35 +264,30 @@ class XlaBuilder { // evaluating the computation. StatusOr IsConstant(const XlaOp& operand) const; + // Sets up binding which indicates that the `target_dim_num` in the subshape + // `target_param_index` of parameter `target_param_num` is a dynamic dimension + // and its real dynamic size is represented by `dynamic_param_index` in + // parameter `dynamic_param_num`. + // + // TODO(b/119520625): Remove this API once we have more dynamic shape infra + // ready. 
+ Status SetDynamicBinding(int64 dynamic_size_param_num, + ShapeIndex dynamic_size_param_index, + int64 target_param_num, + ShapeIndex target_param_index, int64 target_dim_num); + private: // Build helper which takes the id of the root operation.. StatusOr Build(int64 root_id); - // Enqueues a "retrieve parameter value" instruction for a parameter that was - // passed to the computation. + // Description for the methods below can be found in the corresponding public + // functions section in this file. + XlaOp Parameter(int64 parameter_number, const Shape& shape, const string& name); - // Enqueues a constant with the value of the given literal onto the - // computation. XlaOp ConstantLiteral(const LiteralSlice& literal); - // Enqueues a constant onto the computation. Methods are templated on the - // native host type (NativeT) which corresponds to a specific XLA - // PrimitiveType as given in the following table: - // - // Native Type PrimitiveType - // ----------------------------- - // bool PRED - // int32 S32 - // int64 S64 - // uint32 U32 - // uint64 U64 - // float F32 - // double F64 - // - // Note: not all primitive types defined in xla_data.proto have a - // corresponding native type yet. template XlaOp ConstantR0(NativeT value); template @@ -321,198 +317,79 @@ class XlaBuilder { template XlaOp ConstantR4FromArray4D(const Array4D& values); - // Enqueues a rank one constant (vector) onto the computation. The vector has - // size 'length' and every element has the value 'value'. template XlaOp ConstantR1(int64 length, NativeT value); - // Adds dimensions to an array by duplicating the data in the array. - // - // The new dimensions are inserted on the left, i.e. if - // broadcast_sizes has values {a0, ..., aN} and the operand shape - // has dimensions {b0, ..., bM} then the shape of the output has - // dimensions {a0, ..., aN, b0, ..., bM}. - // - // The new dimensions index into copies of the operand, i.e. 
- // - // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); - // Performs in-dimension-style broadcast. - // - // Operand specifies the input to be broadcast. "shape" is expected output - // shape. "broadcast_dimensions" are the dimensions to be broadcasting into. - // Dimension numbers in broadcast_dimensions map to individual dimensions - // of the operand, and specify what dimension of the output shape they - // should be broadcast. - // e.g. - // Say operand = [1, 2], i.e., a 1D tensor with 2 elements. - // and dimension of shape is [2,2]. - // Specifying {1} as brodcast_dimension will generate output - // [1 , 2] - // [1 , 2] - // On the other hand, specifying {0} as broadcast_dimension - // will generate output - // [1 , 1] - // [2 , 2] - XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + XlaOp BroadcastInDim(const XlaOp& operand, + const absl::Span out_dim_size, const absl::Span broadcast_dimensions); - // Enqueues a pad operation onto the computation that pads the given value on - // the edges as well as between the elements of the input. padding_config - // specifies the padding amount for each dimension. XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, const PaddingConfig& padding_config); - // Enqueues an operation onto the computation that flattens the operand based - // on the dimension order (major/slowest-varying to minor/fastest-varying) - // given, followed by reshaping it into the shape with the given dimension - // sizes (also major to minor). Conceptually, this is a limited form of - // "shape casting". XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, absl::Span new_sizes); - // Enqueues an operation onto the computation that collapses the operand, from - // first to last dimension (C order), then reshapes it to the given dimension - // sizes. Conceptually, this is a limited form of "shape casting". 
XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); - // Wrapper for Reshape. - // Enqueues an operation to collapse the provided dimensions; e.g. an - // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to - // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must - // be a consecutive, in-order subsequence of the operand dimensions. - // - // Note that collapsing a single dimension does nothing: - // - // {256} collapsing {0} => {256} - // {1} collapsing {0} => {1} - // - // Collapsing multiple dimensions produces a single result dimension: - // - // {256, 2} collapsing {0,1} => {512} - // {256, 2, 3} collapsing {0,1} => {512, 3} - // - // This could potentially cause data to be moved -- it provides a more - // structured form of reshaping than an arbitrary Reshape operation. XlaOp Collapse(const XlaOp& operand, absl::Span dimensions); - // Enqueues a slice operation onto the computation that slices the operand - // from the start indices to the limit indices; e.g. - // - // x - // [ 0 1 2 3 ] - // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] - // [ 8 9 a b ] - // - // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D - // range notation. - // The strides parameter determines the stride over the slice XlaOp Slice(const XlaOp& operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides); - // Enqueues a slice operation in a given dimension, taking all other - // dimensions as they are; e.g. if dimno is 1 from start_index 2 to - // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand - // for: - // - // array[:, 2:4:1, :] XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); - // Enqueues a slice operation onto the computation that slices the 'operand' - // from dynamic start indices which are passed in 'start_indices'. 
- // The size of the slice in each dimension is passed in 'slice_sizes', - // which specify the end point of exclusive slice intervals in each - // dimension [start, start + size). - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo input dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, absl::Span slice_sizes); - // Enqueues a dynamic update slice operation onto the computation, which - // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. - // The shape of 'update' determines the shape of the slice of 'operand' - // which is updated. - // The indices specified in 'start_indices' specify the offset of the slice - // of 'operand' which is updated. - // - // update = {10, 11} // calculated at runtime. - // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] - // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] - // [7 8 9] [7 8 9 ] - // - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo update dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); - // Enqueues a concatenate instruction onto the computation. 'operands' must - // have >= 1 entry. XlaOp ConcatInDim(absl::Span operands, int64 dimension); - // Enqueue a tracing operation onto the computation; the computation will emit - // a logging message with the operand. void Trace(const string& tag, const XlaOp& operand); - // Enqueues a conditional-move-like select operation onto the computation; - // predicated on pred, selects between on_true and on_false. 
XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); - // Enqueues a tuple-creation instruction onto the computation. XlaOp Tuple(absl::Span elements); - // Enqueues a tuple-element-get instruction onto the computation. XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); - // Enqueues an equal-to comparison instruction onto the computation. XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a not-equal comparison instruction onto the computation. XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a greater-or-equal comparison instruction onto the computation. XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a greater-than comparison instruction onto the computation. XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a less-than comparison instruction onto the computation. XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a less-or-equal comparison instruction onto the computation. XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, const PrecisionConfig* precision_config = nullptr); - // Enqueues a general dot instruction onto the computation. XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, which uses the - // default convolution dimension numbers. 
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration in the format returned by MakePadding(). XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, @@ -520,8 +397,6 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, @@ -529,8 +404,6 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration as well as the dimension numbers. XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, @@ -538,8 +411,6 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration, dilation factors and dimension numbers. XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, @@ -549,79 +420,53 @@ class XlaBuilder { int64 feature_group_count = 1, const PrecisionConfig* precision_config = nullptr); - // Enqueues an FFT instruction onto the computation, of the given type and - // with the given FFT length. XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span fft_length); - // Enqueues an infeed instruction onto the computation, which writes data of - // the given shape to the infeed buffer of the device. 
XlaOp Infeed(const Shape& shape, const string& config = ""); XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape, const string& config = ""); - // Enqueues an outfeed instruction onto the computation. This instruction - // generates outgoing data transfers for the given data. - // - // shape_with_layout communicates the laid out shape that we want to outfeed - // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error - // will occur. void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config); XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, const Shape& shape_with_layout, const string& outfeed_config); - // Enqueues a call instruction onto the computation. XlaOp Call(const XlaComputation& computation, absl::Span operands); - // Enqueues a custom call instruction onto the computation. - XlaOp CustomCall(const string& call_target_name, - absl::Span operands, const Shape& shape, - const string& opaque); + XlaOp CustomCall( + const string& call_target_name, absl::Span operands, + const Shape& shape_with_layout, const string& opaque, + absl::optional> operand_shapes_with_layout); - // The following methods enqueue element-wise binary arithmetic operations - // onto the computation. The shapes of the operands have to match unless one - // of the operands is a scalar, or an explicit broadcast dimension is given - // (see g3doc for more details). - - // Enqueues a complex compose instruction onto the computation. XlaOp Complex(const XlaOp& real, const XlaOp& imag, absl::Span broadcast_dimensions = {}); - // Enqueues a complex conjugate instruction onto the computation. XlaOp Conj(const XlaOp& operand); - // Enqueues an add instruction onto the computation. XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a subtract instruction onto the computation. 
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a multiply instruction onto the computation. XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a divide instruction onto the computation. XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a remainder instruction onto the computation. XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a max instruction onto the computation. XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues a min instruction onto the computation. XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Element-wise logical operators XlaOp And(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); @@ -640,81 +485,48 @@ class XlaBuilder { XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Reduces an array among the provided dimensions, given "computation" as a - // reduction operator. XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, absl::Span dimensions_to_reduce); - // Reduces several arrays simultaneously among the provided dimensions, given - // "computation" as a reduction operator. XlaOp Reduce(absl::Span operands, absl::Span init_values, const XlaComputation& computation, absl::Span dimensions_to_reduce); - // Convenience wrapper around the above that reduces all the dimensions in the - // operand shape. XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation); - // Enqueues a windowed reduce instruction onto the computation. 
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, Padding padding); - // As ReduceWindow(), but the padding is given in the format - // returned by MakePadding(). XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding); - // Returns the sum of the operand value within each subgroup of replicas. All - // replicas supply one input to the sum and all replicas receive the resulting - // sum for each subgroup. XlaOp CrossReplicaSum(const XlaOp& operand, absl::Span replica_groups = {}); - // Enqueues an operation that do an AllReduce of the operand cross cores. Here - // AllReduce means doing a reduction on the input operand cross cores and then - // broadcasting the reduction result to those cores. The reduction function is - // defined by `computation`, which should be a commutative computation on - // scalars, e.g., add, min, or max. The way that AllReduce is applied is - // configured by: - // - // - `replica_groups`: each ReplicaGroup contains a list of replica id. If - // empty, all replicas belong to one group. Allreduce will be applied within - // subgroups. For example, we have 4 replicas, then - // replica_groups={{0,2},{1,3}} means, replica 0 and 2 are in subgroup 0, - // replica 1 and 3 are in subgroup 1. - // - // - `channel_id`: for Allreduce nodes from different modules, if they have - // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will - // not be applied cross modules. - // - // TODO(b/79737069): Rename this to AllReduce when it's ready to use. 
XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, absl::Span replica_groups = {}, const absl::optional& channel_id = absl::nullopt); - // Enqueues an operation that do an Alltoall of the operand cross cores. XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups); - // Enqueues an operation that do an CollectivePermute of the operand cross - // cores. XlaOp CollectivePermute( const XlaOp& operand, const std::vector>& source_target_pairs); - // Enqueues an operation that scatters the `source` array to the selected - // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, absl::Span window_strides, @@ -722,8 +534,6 @@ class XlaBuilder { const XlaOp& init_value, const XlaComputation& scatter); - // As SelectAndScatter(), but the padding is given in the format - // returned by MakePadding(). XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, absl::Span window_dimensions, @@ -731,222 +541,126 @@ class XlaBuilder { absl::Span> padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter); - // Enqueues an abs instruction onto the computation. XlaOp Abs(const XlaOp& operand); - // Enqueues a atan2 instruction onto the computation. XlaOp Atan2(const XlaOp& y, const XlaOp& x, absl::Span broadcast_dimensions = {}); - // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); - // Enqueues an expm1 instruction onto the computation. XlaOp Expm1(const XlaOp& operand); - // Enqueues a floor instruction onto the computation. XlaOp Floor(const XlaOp& operand); - // Enqueues a ceil instruction onto the computation. XlaOp Ceil(const XlaOp& operand); - // Enqueues a round instruction onto the computation, rounding to nearest even - // with half-way cases rounding away from zero. 
XlaOp Round(const XlaOp& operand); - // Enqueues an log instruction (natural logarithm) onto the computation. XlaOp Log(const XlaOp& operand); - // Enqueues an log1p instruction (log(x+1)) onto the computation. XlaOp Log1p(const XlaOp& operand); - // Enqueues a sign instruction onto the computation. XlaOp Sign(const XlaOp& operand); - // Enqueues a count leading zeros instruction onto the computation. XlaOp Clz(const XlaOp& operand); - // Enqueues a cosine instruction onto the computation. XlaOp Cos(const XlaOp& operand); - // Enqueues a sine instruction onto the computation. XlaOp Sin(const XlaOp& operand); - // Enqueues a tanh instruction onto the computation. XlaOp Tanh(const XlaOp& operand); - // Enqueues a real-part instruction onto the computation. XlaOp Real(const XlaOp& operand); - // Enqueues an imaginary-part instruction onto the computation. XlaOp Imag(const XlaOp& operand); - // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions = {}); - // Enqueues an operator that tests if the operand's values are finite, i.e., - // not Inf or NaN. Defined only for floating-point types. Returns an array of - // booleans with the same shape where entries are true iff the corresponding - // entry was NaN. XlaOp IsFinite(const XlaOp& operand); - // Enqueues an iota operation onto the computation. XlaOp Iota(const Shape& shape, int64 iota_dimension); - // Enqueues a rank-1 iota operation onto the computation. XlaOp Iota(PrimitiveType type, int64 size); - // Enqueues a convert instruction onto the computation that changes the - // element type of the operand array to primitive_type. XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); - // Enqueues a no-op instruction onto the computation that changes - // the element type of the operand array to primitive_type. The - // bit-widths of the source and destination element types must be - // identical. 
XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); - // Enqueues a negate instruction onto the computation. XlaOp Neg(const XlaOp& operand); - // Enqueues a transpose instruction onto the computation. XlaOp Transpose(const XlaOp& operand, absl::Span permutation); - // Enqueues a reverse instruction onto the computation. The order of the - // elements in the given dimensions is reversed (i.e., the element at index i - // is moved to index dimension_size - 1 - i). XlaOp Rev(const XlaOp& operand, absl::Span dimensions); - // Enqueues a sort (as increasing order) instruction onto the computation. - // If only keys are provided: - // * If the keys are an rank-1 tensor (an array), the result is a sorted array - // of keys, in ascending order. - // * If the keys have higher rank, the keys are sorted along the provided - // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension - // value of 0 will indepenently sort every column, and a dimension value of 1 - // will independently sort each row. If no dimension number is provided, then - // the last dimension is chosen by default. - // - // If both keys and values are provided: - // * The keys and the values must tensors with the same dimensions. The - // element types of the tensors may be different. - // * The result is a tuple that consists of a sorted tensor of keys (along the - // provided dimension, as above) as the first element, and a tensor with their - // corresponding values as the second element. - XlaOp Sort(XlaOp keys, absl::optional values = absl::nullopt, + XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); - // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); - // Enqueues a map instruction onto the computation. 
XlaOp Map(absl::Span operands, const XlaComputation& computation, absl::Span dimensions, absl::Span static_operands = {}); - // Enqueues a N(mu, sigma) random number generation instruction onto the - // computation. XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); - // Enqueues a U(a, b) random number generation instruction onto the - // computation. Returns values in the semi-open interval [a, b). XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); - // Enqueues a while node onto the computation. XlaOp While(const XlaComputation& condition, const XlaComputation& body, const XlaOp& init); - // Enqueues a conditional node onto the computation. XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaComputation& true_computation, const XlaOp& false_operand, const XlaComputation& false_computation); - // Enqueues a ReducePrecision node onto the computation. XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); - // Enqueues a Gather node onto the computation. XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes); - // Enqueues a Scatter node onto the computation. XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, const XlaOp& updates, const XlaComputation& update_computation, const ScatterDimensionNumbers& dimension_numbers); - // Enqueues a Send node onto the computation for device-to-device - // communication, to send the given operand to a Recv instruction that shares - // the same channel handle. void Send(const XlaOp& operand, const ChannelHandle& handle); XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, const ChannelHandle& handle); - // Enqueues a Send node which sends data to the host. 
XlaOp SendToHost(const XlaOp& operand, const XlaOp& token, const Shape& shape_with_layout, const ChannelHandle& handle); - // Enqueues a Recv node which receives data from the host. XlaOp RecvFromHost(const XlaOp& token, const Shape& shape, const ChannelHandle& handle); - // Enqueues an AfterAll operation with no operands producing a token-shaped - // value. XlaOp CreateToken(); - // Enqueues an AfterAll operation with no operands producing a token-shaped - // value. XlaOp AfterAll(absl::Span tokens); - // Enqueues a Recv node onto the computation. The data comes from a Send - // instruction that shares the same channel handle and its shape must - // be the same as the given shape. XlaOp Recv(const Shape& shape, const ChannelHandle& handle); XlaOp RecvWithToken(const XlaOp& token, const Shape& shape, const ChannelHandle& handle); - // Normalizes operand across spatial and batch dimensions for each feature. - // - // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` - // is the normalized result and batch_mean and batch_var are the mean and - // variance, respectively, across batch for the operand. XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, float epsilon, int64 feature_index); - // Normalizes operand across spatial and batch dimensions for each feature. - // - // `BatchNormInference` is equivalent to calling `BatchNormTraining` without - // computing `mean` and `variance` for each batch inside the operation. It - // uses the input `mean` and `variance` instead as estimated values. The - // purpose of this op is to reduce latency in inference, hence the name - // `BatchNormInference`. - // - // The output has the same shape as `operand`, and contains the normalized - // values for each batch. 
XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, const XlaOp& offset, const XlaOp& mean, const XlaOp& variance, float epsilon, int64 feature_index); - // Calculates the gradients of a batch norm op. - // - // The inputs `batch_mean` and `batch_var` represent the mean and variance - // across the batch. - // - // Returns a tuple of three elements: - // - grad_operand: Gradient with respect to input `operand` - // - grad_offset: Gradient with respect to input `offset` - // - grad_scale: Gradient with respect to input `scale` XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, const XlaOp& batch_mean, const XlaOp& batch_var, const XlaOp& grad_output, float epsilon, int64 feature_index); + XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); + StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, absl::Span operands = {}); @@ -1013,8 +727,14 @@ class XlaBuilder { absl::Span lhs_dilation, absl::Span rhs_dilation) const; + int64 GetNextId() { return ++next_id_; } + string name_; // Name to use for the built computation. + // The next sequential ID for every instruction/computation contained within + // this computation. + int64 next_id_ = 0; + // The first error encountered while building the computation. // This is OK until the first error is encountered. Status first_error_; @@ -1025,6 +745,9 @@ class XlaBuilder { // The instructions of this computation. std::vector instructions_; + // Dynamic parameter configuration of this computation. + DynamicParameterBinding dynamic_parameter_binding_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the // instruction is held. 
absl::flat_hash_map handle_to_index_; @@ -1102,7 +825,7 @@ class XlaBuilder { absl::Span broadcast_sizes); friend XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, + const XlaOp& operand, const absl::Span out_dim_size, const absl::Span broadcast_dimensions); friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, @@ -1195,6 +918,10 @@ class XlaBuilder { friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque); + friend XlaOp CustomCallWithLayout( + XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, const string& opaque); friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, absl::Span broadcast_dimensions); friend XlaOp Conj(const XlaOp& operand); @@ -1245,6 +972,8 @@ class XlaBuilder { const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding); friend XlaOp CrossReplicaSum(const XlaOp& operand, absl::Span replica_groups); @@ -1302,7 +1031,8 @@ class XlaBuilder { friend XlaOp Transpose(const XlaOp& operand, absl::Span permutation); friend XlaOp Rev(const XlaOp& operand, absl::Span dimensions); - friend XlaOp Sort(XlaOp keys, absl::optional values, int64 dimension); + friend XlaOp Sort(const XlaOp& keys, absl::Span values, + int64 dimension); friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); friend XlaOp Map(XlaBuilder* builder, absl::Span operands, const XlaComputation& computation, @@ -1356,6 +1086,8 @@ class XlaBuilder { const string& outfeed_config); friend XlaOp CreateToken(XlaBuilder* builder); friend XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); + + friend XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); }; // RAII-style object: sets the current sharding assignment in builder on @@ 
-1390,6 +1122,7 @@ class XlaScopedShardingAssignment { // Free functions for building XlaOps. The intention is that these will // become the public API for building XlaOps rather than calling methods on // XlaBuilder directly. +// // Enqueues a "retrieve parameter value" instruction for a parameter that was // passed to the computation. @@ -1470,24 +1203,23 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); -// Performs in-dimension-style broadcast. +// This op broadcasts the `operand` to an output with the given `shape`. +// `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the +// i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th +// dimension of the output. This also requires that the i'th input dimension is +// either 1 or is the same as the output dimension it's broadcasting into. // -// Operand specifies the input to be broadcast. "shape" is expected output -// shape. "broadcast_dimensions" are the dimensions to be broadcasting into. -// Dimension numbers in broadcast_dimensions map to individual dimensions -// of the operand, and specify what dimension of the output shape they -// should be broadcast. -// e.g. -// Say operand = [1, 2], i.e., a 1D tensor with 2 elements. -// and dimension of shape is [2,2]. 
-// Specifying {1} as brodcast_dimension will generate output -// [1 , 2] -// [1 , 2] -// On the other hand, specifying {0} as broadcast_dimension -// will generate output -// [1 , 1] -// [2 , 2] -XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, +// For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the +// output shape is s32[2,2]: +// - Specifying {1} as brodcast_dimension will generate output +// {{1, 2}, +// {1, 2}} +// - On the other hand, specifying {0} as broadcast_dimension +// will generate output +// {{1 , 1}, +// {2 , 2}} +XlaOp BroadcastInDim(const XlaOp& operand, + const absl::Span out_dim_size, const absl::Span broadcast_dimensions); // Enqueues a pad operation onto the computation that pads the given value on @@ -1728,6 +1460,17 @@ XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque = ""); +// Overload which constructs a custom call with fixed layouts. The operands will +// have the layouts specified by |operand_shapes_with_layout| when provided to +// external code, and the external code is expected to produce a result with the +// layout specified by |shape_with_layout|. All shapes in |shape_with_layout| +// and |operand_shapes_with_layout| must have layouts. +XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, + absl::Span operands, + const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, + const string& opaque = ""); + // The following methods enqueue element-wise binary arithmetic operations // onto the computation. 
The shapes of the operands have to match unless one // of the operands is a scalar, or an explicit broadcast dimension is given @@ -1818,6 +1561,8 @@ XlaOp ReduceWindowWithGeneralPadding( const XlaComputation& computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding); // Returns the sum of the operand value within each subgroup of replicas. All @@ -1842,7 +1587,7 @@ XlaOp CrossReplicaSum(const XlaOp& operand, // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be // applied cross modules. // -// TODO(b/79737069): Rename this to AllReduce when it's ready to use. +// TODO(b/117564385): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, absl::Span replica_groups = {}, @@ -1980,12 +1725,12 @@ XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // the last dimension is chosen by default. // // If both keys and values are provided: -// * The keys and the values must tensors with the same dimensions. The +// * The keys and all values must be tensors with the same dimensions. The // element types of the tensors may be different. // * The result is a tuple that consists of a sorted tensor of keys (along the -// provided dimension, as above) as the first element, and a tensor with their -// corresponding values as the second element. -XlaOp Sort(XlaOp keys, absl::optional values = absl::nullopt, +// provided dimension, as above) as the first element, and tensors with their +// corresponding values as the other elements. +XlaOp Sort(const XlaOp& keys, absl::Span values = {}, int64 dimension = -1); // Enqueues a clamp instruction onto the computation. @@ -2119,7 +1864,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, const XlaOp& grad_output, float epsilon, int64 feature_index); +// Returns the size of the given dimension of the operand. 
The operand must be +// array shaped. +XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension); + // Implementation details below this point. +// template XlaOp XlaBuilder::ConstantR0(NativeT value) { diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 7c37ed00cd3dcc214fb0b36c0161d3c39a5bf8c8..b3f5be300d3f15397ad33858a6a9cab5f6029688 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -43,7 +43,7 @@ class XlaBuilderTest : public ::testing::Test { const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( - proto, legacy_flags::GetDebugOptionsFromFlags())); + proto, GetDebugOptionsFromFlags())); return HloModule::CreateFromProto(proto, config); } @@ -54,7 +54,7 @@ class XlaBuilderTest : public ::testing::Test { const HloModuleProto& proto = computation.proto(); TF_ASSIGN_OR_RETURN(const auto& config, HloModule::CreateModuleConfigFromProto( - proto, legacy_flags::GetDebugOptionsFromFlags())); + proto, GetDebugOptionsFromFlags())); return HloModule::CreateFromProto(proto, config); } @@ -264,6 +264,26 @@ TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { op::Broadcast(op::Reshape(op::Parameter(1))))); } +TEST_F(XlaBuilderTest, BroadcastInDim) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); + BroadcastInDim(x, {2, 4, 3}, + /*broadcast_dimensions=*/{0, 2}); + TF_ASSERT_OK_AND_ASSIGN(auto 
module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Broadcast()); +} + +TEST_F(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x"); + BroadcastInDim(x, {2, 3, 4}, + /*broadcast_dimensions=*/{0, 1, 2}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Broadcast(op::Reshape(op::Broadcast()))); +} + TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { XlaBuilder b1("b1"); auto p0 = Parameter(&b1, 0, ShapeUtil::MakeShape(F32, {}), "p0"); @@ -329,6 +349,15 @@ TEST_F(XlaBuilderTest, CollectivePermute) { EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute); } +TEST_F(XlaBuilderTest, GetDimensionSize) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + GetDimensionSize(x, 1); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kGetDimensionSize); +} + TEST_F(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); @@ -396,5 +425,35 @@ TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) { ::testing::HasSubstr("root operation is not in this computation")); } +TEST_F(XlaBuilderTest, ProtoMatches) { + std::vector computations; + for (int i = 0; i < 2; ++i) { + XlaBuilder b_call("the_only_to_apply"); + auto p0 = Parameter(&b_call, 0, ShapeUtil::MakeShape(F32, {}), "p0"); + auto p1 = Parameter(&b_call, 1, ShapeUtil::MakeShape(F32, {}), "p1"); + Add(p0, Add(p1, p0)); + TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build()); + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y"); + auto one = ConstantR0(&b, 1); + 
auto two = ConstantR0(&b, 2); + Add(Call(&b, call, {x, y}), Call(&b, call, {one, two})); + computations.push_back(b.Build().ValueOrDie()); + } + auto c0_string = computations[0].proto().SerializeAsString(); + auto c1_string = computations[1].proto().SerializeAsString(); + EXPECT_EQ(c0_string, c1_string); +} + +TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) { + XlaBuilder b(TestName()); + AfterAll(&b, {CreateToken(&b), ConstantR0(&b, 1.0)}); + Status status = b.Build().status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("All operands to AfterAll must be tokens")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc index 22c9e83bb2ae9e3e205bdd480b64c703e31c6ffd..f317892c12529b2ee8a81788f6bbcae3b3d6489d 100644 --- a/tensorflow/compiler/xla/client/xla_computation.cc +++ b/tensorflow/compiler/xla/client/xla_computation.cc @@ -24,8 +24,8 @@ limitations under the License. namespace xla { StatusOr XlaComputation::GetProgramShape() const { - TF_RET_CHECK(proto_.has_program_shape()); - return proto_.program_shape(); + TF_RET_CHECK(proto_.has_host_program_shape()); + return ProgramShape(proto_.host_program_shape()); } StatusOr> XlaComputation::Snapshot() const { diff --git a/tensorflow/compiler/xla/client/xla_computation.h b/tensorflow/compiler/xla/client/xla_computation.h index 71598ef8b296a760b0ee818fce0a59aed5cfc6b4..3ccbfb28bd0c5939ee40878e9cc298688882ac62 100644 --- a/tensorflow/compiler/xla/client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_computation.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc similarity index 93% rename from tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc rename to tensorflow/compiler/xla/debug_options_flags.cc index 3ed3afcfcede20fbf5c7d4f004378817febeb4c7..20609cad58d920c0c272899c41efeb99d23cd490 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -13,17 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include // NOLINT(build/c++11): only using std::call_once, not mutex. #include #include "absl/strings/str_split.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/debug_options_parsers.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" namespace xla { -namespace legacy_flags { - namespace { DebugOptions* flag_values; @@ -56,7 +54,7 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { // TODO(jlebar): Disable fastmath once doing so is not a performance // regression. 
flags->set_xla_cpu_enable_fast_math(true); - flags->set_xla_gpu_enable_fast_math(true); + flags->set_xla_gpu_enable_fast_min_max(true); flags->set_xla_force_host_platform_device_count(1); } @@ -101,8 +99,8 @@ void AllocateFlags() { [](string comma_separated_values) { auto* extra_options_map = flag_values->mutable_xla_backend_extra_options(); - impl::parse_xla_backend_extra_options(extra_options_map, - comma_separated_values); + parse_xla_backend_extra_options(extra_options_map, + comma_separated_values); return true; }; @@ -111,8 +109,8 @@ void AllocateFlags() { [](string reduce_precision_option_value) { HloReducePrecisionOptions* option_proto = flag_values->add_hlo_reduce_precision_options(); - return impl::parse_xla_reduce_precision_option( - option_proto, reduce_precision_option_value); + return parse_xla_reduce_precision_option(option_proto, + reduce_precision_option_value); }; flag_objects = new std::vector({ @@ -162,11 +160,11 @@ void AllocateFlags() { "Enable unsafe fast-math optimizations in the CPU compiler; " "this may produce faster code at the expense of some accuracy."), tensorflow::Flag( - "xla_gpu_enable_fast_math", - bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), - flag_values->xla_cpu_enable_fast_math(), - "Enable unsafe fast-math optimizations in the GPU compiler; " - "this may produce faster code at the expense of some accuracy."), + "xla_gpu_enable_fast_min_max", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), + flag_values->xla_gpu_enable_fast_min_max(), + "Enable fast floating point min/max lowering that does not propagate " + "NaNs."), tensorflow::Flag( "xla_llvm_enable_alias_scope_metadata", bool_setter_for( @@ -336,8 +334,14 @@ void AllocateFlags() { "overhead from context switching but we let the user override this " "behavior to help run tests on the host that run models in parallel " "across multiple devices."), + tensorflow::Flag( + "xla_gpu_disable_ptxas_optimizations", + bool_setter_for( + 
&DebugOptions::set_xla_gpu_disable_ptxas_optimizations), + flag_values->xla_gpu_disable_ptxas_optimizations(), + "In XLA:GPU run ptxas in -O0 (default is -O3)."), }); - ParseFlagsFromEnv(*flag_objects); + ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } } // namespace @@ -353,5 +357,4 @@ xla::DebugOptions GetDebugOptionsFromFlags() { return *flag_values; } -} // namespace legacy_flags } // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h b/tensorflow/compiler/xla/debug_options_flags.h similarity index 81% rename from tensorflow/compiler/xla/legacy_flags/debug_options_flags.h rename to tensorflow/compiler/xla/debug_options_flags.h index b53157f59c61cf4e0850e006ad3656f4be63a936..60e59abc2a2e0f1cce3de1afc928f9fe36f75b33 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h +++ b/tensorflow/compiler/xla/debug_options_flags.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_ #include @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" namespace xla { -namespace legacy_flags { // Appends flag definitions for debug options to flag_list. void AppendDebugOptionsFlags(std::vector* flag_list); @@ -32,7 +31,6 @@ void AppendDebugOptionsFlags(std::vector* flag_list); // first. 
xla::DebugOptions GetDebugOptionsFromFlags(); -} // namespace legacy_flags } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#endif // TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/debug_options_parsers.h similarity index 94% rename from tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h rename to tensorflow/compiler/xla/debug_options_parsers.h index ee7eb019c07cf898e48886955b18710146644cac..80aadfd5ece0e768afaf1842d2b6c5b11c288b55 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h +++ b/tensorflow/compiler/xla/debug_options_parsers.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_ +#define TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_ #include #include "absl/strings/numbers.h" @@ -23,8 +23,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla.pb.h" namespace xla { -namespace legacy_flags { -namespace impl { template void parse_xla_backend_extra_options(T* extra_options_map, @@ -140,8 +138,6 @@ inline bool parse_xla_reduce_precision_option( return true; } -} // namespace impl -} // namespace legacy_flags } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#endif // TENSORFLOW_COMPILER_XLA_DEBUG_OPTIONS_PARSERS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc b/tensorflow/compiler/xla/debug_options_parsers_test.cc similarity index 88% rename from tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc rename to tensorflow/compiler/xla/debug_options_parsers_test.cc index 6f197aec53c7596e84437a03affa9118f22f5a1d..8003c3496d5df9be2ff8a99bc171972c8e090c43 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc +++ b/tensorflow/compiler/xla/debug_options_parsers_test.cc @@ -15,7 +15,7 @@ limitations under the License. // Test for parse_flags_from_env.cc -#include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h" +#include "tensorflow/compiler/xla/debug_options_parsers.h" #include #include @@ -23,13 +23,12 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace xla { -namespace legacy_flags { // Test that the xla_backend_extra_options flag is parsed correctly. 
TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) { std::unordered_map test_map; string test_string = "aa=bb,cc,dd=,ee=ff=gg"; - impl::parse_xla_backend_extra_options(&test_map, test_string); + parse_xla_backend_extra_options(&test_map, test_string); EXPECT_EQ(test_map.size(), 4); EXPECT_EQ(test_map.at("aa"), "bb"); EXPECT_EQ(test_map.at("cc"), ""); @@ -41,7 +40,7 @@ TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStrings) { HloReducePrecisionOptions proto; string test_string = "OP_OUTPUTS=5,10:add,dot"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -56,7 +55,7 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStrings) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStringsSemicolon) { HloReducePrecisionOptions proto; string test_string = "OP_OUTPUTS=5,10:add,dot;"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -71,7 +70,7 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStringsSemicolon) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoOpcodes) { HloReducePrecisionOptions proto; string test_string = "UNFUSED_OP_OUTPUTS=5,10:;foo,bar/baz"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -84,7 +83,7 @@ TEST(DebugOptionsFlags, 
ParseXlaReducePrecisionOptionNoOpcodes) { TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionBoth) { HloReducePrecisionOptions proto; string test_string = "UNFUSED_OP_OUTPUTS=5,10:subtract;foo,bar/baz"; - EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_TRUE(parse_xla_reduce_precision_option(&proto, test_string)); EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS); EXPECT_EQ(proto.exponent_bits(), 5); EXPECT_EQ(proto.mantissa_bits(), 10); @@ -96,7 +95,6 @@ TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionBoth) { EXPECT_EQ(proto.opname_substrings_to_suffix(1), "bar/baz"); } -} // namespace legacy_flags } // namespace xla int main(int argc, char* argv[]) { diff --git a/tensorflow/compiler/xla/execution_options_util.cc b/tensorflow/compiler/xla/execution_options_util.cc index e83ff7cddd675197c7f6d7018257edb4c25b6228..cf569863bbe1c92bdcafb133d49dcf5ae8890ffe 100644 --- a/tensorflow/compiler/xla/execution_options_util.cc +++ b/tensorflow/compiler/xla/execution_options_util.cc @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" namespace xla { ExecutionOptions CreateDefaultExecutionOptions() { ExecutionOptions execution_options; - *(execution_options.mutable_debug_options()) = - legacy_flags::GetDebugOptionsFromFlags(); + *(execution_options.mutable_debug_options()) = GetDebugOptionsFromFlags(); return execution_options; } diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index fb135f5ceda67ce6c001de15b8f3f084ca164826..1fea816a803bfb75b9721393cef8c4dfc249268d 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -18,12 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - import numpy as _np # Avoids becoming a part of public Tensorflow API. from tensorflow.compiler.xla import xla_data_pb2 -from tensorflow.compiler.xla.python_api import xla_shape from tensorflow.core.framework import attr_value_pb2 @@ -64,22 +61,18 @@ class Sharding(object): tile_assignment_devices=[core])) @classmethod - def tile(cls, tile_shape, tile_assignment): + def tile(cls, tile_assignment): """Returns a Tiled sharding attribute. This causes an op to be partially computed on multiple cores in the XLA device. Args: - tile_shape: A xla_shape.Shape describing the tile shape that each core - will compute. - The tile shape does not need to be divisible by the tile assignment. tile_assignment: An np.ndarray describing the topology of the tiling and which device will compute which part of the topology. 
Raises: - TypeError: tile_assignment was not of np.array type or tile_shape was - not of xla_shape.Shape type. + TypeError: tile_assignment was not of np.array type. TODO(jmolloy): This concept is nefarious and is not something we really want to expose to users (especially as the @@ -87,14 +80,11 @@ class Sharding(object): """ if not isinstance(tile_assignment, _np.ndarray): raise TypeError('Tile assignment must be of type np.ndarray') - if not isinstance(tile_shape, xla_shape.Shape): - raise TypeError('Tile shape must be of type xla_shape.Shape') dims = list(tile_assignment.shape) flattened_devices = tile_assignment.reshape(-1, order='C') return Sharding( proto=xla_data_pb2.OpSharding( type=xla_data_pb2.OpSharding.OTHER, - tile_shape=tile_shape.message, tile_assignment_dimensions=dims, tile_assignment_devices=list(flattened_devices))) @@ -118,14 +108,8 @@ class Sharding(object): shape = tensor.shape.as_list() if shape[split_dimension] < num_devices: raise ValueError('Split dimension was smaller than the required number ' - 'of splits: shape=%r, dimension=%r, num_devices=%r', - shape, split_dimension, num_devices) - - tile_shape = shape - tile_shape[split_dimension] = int( - math.ceil(tile_shape[split_dimension] / num_devices)) - tile_shape_proto = xla_data_pb2.Shape( - element_type=xla_data_pb2.F32, dimensions=tile_shape) + 'of splits: shape=%r, dimension=%r, num_devices=%r' % + (shape, split_dimension, num_devices)) tile_assignment_dims = [1] * len(shape) tile_assignment_dims[split_dimension] = num_devices @@ -133,7 +117,6 @@ class Sharding(object): return Sharding( proto=xla_data_pb2.OpSharding( type=xla_data_pb2.OpSharding.OTHER, - tile_shape=tile_shape_proto, tile_assignment_dimensions=tile_assignment_dims, tile_assignment_devices=range(num_devices))) @@ -149,7 +132,6 @@ class Sharding(object): type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings) else: proto = self._proto - attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString()) # 
TODO(jmolloy): This need to be seriously revisited before declaring this # API available for public use. @@ -194,8 +176,8 @@ def assign_device(tensor, device): return tensor -def tile(tensor, tile_shape, tile_assignment): - Sharding.tile(tile_shape, tile_assignment).apply_to_tensor(tensor) +def tile(tensor, tile_assignment): + Sharding.tile(tile_assignment).apply_to_tensor(tensor) return tensor diff --git a/tensorflow/compiler/xla/g3doc/README.md b/tensorflow/compiler/xla/g3doc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6643bf0aab3078ff24c86b81de69216355da69a1 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/README.md @@ -0,0 +1,3 @@ +# XLA: Accelerated Linear Algebra + +These are the docs for: https://www.tensorflow.org/xla diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml new file mode 100644 index 0000000000000000000000000000000000000000..267701e9c0e42a21d2cda6238520f6a9692e7e76 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -0,0 +1,35 @@ +upper_tabs: +# Tabs left of dropdown menu +- include: /_upper_tabs_left.yaml +- include: /api_docs/_upper_tabs_api.yaml +# Dropdown menu +- name: Resources + path: /resources + is_default: true + menu: + - include: /resources/_menu_toc.yaml + lower_tabs: + # Subsite tabs + other: + - name: Guide & Tutorials + contents: + - title: XLA overview + path: /xla/overview + - title: Broadcasting semantics + path: /xla/broadcasting + - title: Developing a new backend for XLA + path: /xla/developing_new_backend + - title: Using JIT compilation + path: /xla/jit + - title: Operation semantics + path: /xla/operation_semantics + - title: Shapes and layout + path: /xla/shapes + - title: Using AOT compilation + path: /xla/tfcompile + - heading: Tutorials + - title: XLA compile API + path: /xla/tutorials/xla_compile + status: experimental + +- include: /_upper_tabs_right.yaml diff --git a/tensorflow/compiler/xla/g3doc/_index.yaml 
b/tensorflow/compiler/xla/g3doc/_index.yaml new file mode 100644 index 0000000000000000000000000000000000000000..858de427119bfcfa82d0b1158776bf269129fd92 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/_index.yaml @@ -0,0 +1,35 @@ +book_path: /xla/_book.yaml +project_path: /xla/_project.yaml +description: +landing_page: + custom_css_path: /site-assets/css/style.css + rows: + - heading: XLA is a compiler that optimizes TensorFlow computations. + items: + - classname: devsite-landing-row-50 + description: > + XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear + algebra that optimizes TensorFlow computations. The results are + improvements in speed, memory usage, and portability on server and mobile + platforms. The XLA framework is experimental and in active development. + For details, read the XLA guide. + + - classname: devsite-landing-row-cards + items: + - heading: XLA - TensorFlow, compiled + image_path: /resources/images/tf-logo-card-16x9.png + path: https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html + buttons: + - label: Read on Google Developers blog + path: https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html + - heading: XLA at the Dev Summit + youtube_id: kAOanJczHA0 + buttons: + - label: Watch the video + path: https://www.youtube.com/watch?v=kAOanJczHA0 + - heading: XLA on GitHub + image_path: /resources/images/github-card-16x9.png + path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla + buttons: + - label: View on GitHub + path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla diff --git a/tensorflow/compiler/xla/g3doc/_project.yaml b/tensorflow/compiler/xla/g3doc/_project.yaml new file mode 100644 index 0000000000000000000000000000000000000000..33d8bdb27a664d9e282d1d65c007ebf5838b196a --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/_project.yaml @@ -0,0 +1,10 @@ +name: XLA +breadcrumb_name: XLA +home_url: /xla/ 
+parent_project_metadata_path: /_project.yaml +description: > + XLA is a compiler-based linear algebra execution engine. +use_site_branding: true +hide_from_products_list: true +content_license: cc3-apache2 +buganizer_id: 171704 diff --git a/tensorflow/compiler/xla/g3doc/broadcasting.md b/tensorflow/compiler/xla/g3doc/broadcasting.md new file mode 100644 index 0000000000000000000000000000000000000000..2870869a2cef13a9105b9dc9fa4d657834288f86 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/broadcasting.md @@ -0,0 +1,204 @@ +# Broadcasting semantics + +This document describes how the broadcasting semantics in XLA work. + +## What is broadcasting? + +Broadcasting is the process of making arrays with different shapes have +compatible shapes for arithmetic operations. The terminology is borrowed from +Numpy +[broadcasting](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). + +Broadcasting may be required for operations between multi-dimensional arrays of +different ranks, or between multi-dimensional arrays with different but +compatible shapes. Consider the addition `X+v` where `X` is a matrix (an array +of rank 2) and `v` is a vector (an array of rank 1). To perform element-wise +addition, XLA needs to "broadcast" the vector `v` to the same rank as the +matrix `X`, by replicating `v` a certain number of times. The vector's length +has to match at least one of the dimensions of the matrix. + +For example: + + |1 2 3| + |7 8 9| + |4 5 6| + +The matrix's dimensions are (2,3), the vector's are (3). The vector is broadcast +by replicating it over rows to get: + + |1 2 3| + |7 8 9| = |8 10 12| + |4 5 6| |7 8 9| |11 13 15| + +In Numpy, this is called +[broadcasting](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). + +## Principles + +The XLA language is as strict and explicit as possible, avoiding implicit and +"magical" features. 
Such features may make some computations slightly easier to +define, at the cost of more assumptions baked into user code that will be +difficult to change in the long term. If necessary, implicit and magical +features can be added in client-level wrappers. + +With regard to broadcasting, explicit broadcasting specifications on operations +between arrays of different ranks are required. This is different from Numpy, +which infers the specification when possible. + +## Broadcasting a lower-rank array onto a higher-rank array + +*Scalars* can always be broadcast over arrays without an explicit specification +of broadcasting dimensions. An element-wise binary operation between a scalar +and an array means applying the operation with the scalar for each element in +the array. For example, adding a scalar to a matrix means producing a matrix +each element of which is a sum of the scalar with the corresponding input +matrix's element. + + |1 2 3| + 7 = |8 9 10| + |4 5 6| |11 12 13| + +Most broadcasting needs can be captured by using a tuple of dimensions on a +binary operation. When the inputs to the operation have different ranks, this +broadcasting tuple specifies which dimension(s) in the **higher-rank** array to +match with the **lower-rank** array. + +Consider the previous example; instead of adding a scalar to a (2,3) matrix, add +a vector of dimension (3) to a matrix of dimensions (2,3). *Without specifying +broadcasting, this operation is invalid.* To correctly request matrix-vector +addition, specify the broadcasting dimension to be (1), meaning the vector's +dimension is matched to dimension 1 of the matrix. In 2D, if dimension 0 is +considered as rows and dimension 1 as columns, this means that each element of +the vector becomes a column of a size matching the number of rows in the matrix: + + |7 8 9| ==> |7 8 9| + |7 8 9| + +As a more complex example, consider adding a 3-element vector (dimension (3)) to +a 3x3 matrix (dimensions (3,3)).
There are two ways broadcasting can happen for +this example: + +(1) A broadcasting dimension of 1 can be used. Each vector element becomes a +column and the vector is duplicated for each row in the matrix. + + |7 8 9| ==> |7 8 9| + |7 8 9| + |7 8 9| + +(2) A broadcasting dimension of 0 can be used. Each vector element becomes a row +and the vector is duplicated for each column in the matrix. + + |7| ==> |7 7 7| + |8| |8 8 8| + |9| |9 9 9| + +> Note: when adding a 2x3 matrix to a 3-element vector, a broadcasting dimension +> of 0 is invalid. + +The broadcasting dimensions can be a tuple that describes how a smaller rank +shape is broadcast into a larger rank shape. For example, given a 2x3x4 cuboid +and a 3x4 matrix, a broadcasting tuple (1,2) means matching the matrix to +dimensions 1 and 2 of the cuboid. + +This type of broadcast is used in the binary ops in `XlaBuilder`, if the +`broadcast_dimensions` argument is given. For example, see +[XlaBuilder::Add](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.cc). +In the XLA source code, this type of broadcasting is sometimes called "InDim" +broadcasting. + +### Formal definition + +The broadcasting attribute allows matching a lower-rank array to a higher-rank +array, by specifying which dimensions of the higher-rank array to match. For +example, for an array with dimensions MxNxPxQ, a vector with dimension T can be +matched as follows: + + MxNxPxQ + + dim 3: T + dim 2: T + dim 1: T + dim 0: T + +In each case, T has to be equal to the matching dimension of the higher-rank +array. The vector's values are then broadcast from the matched dimension to all +the other dimensions. + +To match a TxV matrix onto the MxNxPxQ array, a pair of broadcasting dimensions +are used: + + MxNxPxQ + dim 2,3: T V + dim 1,2: T V + dim 0,3: T V + etc... 
+ +The order of dimensions in the broadcasting tuple has to be the order in which +the lower-rank array's dimensions are expected to match the higher-rank array's +dimensions. The first element in the tuple says which dimension in the +higher-rank array has to match dimension 0 in the lower-rank array. The second +element for dimension 1, and so on. The order of broadcast dimensions has to be +strictly increasing. For example, in the previous example it is illegal to match +V to N and T to P; it is also illegal to match V to both P and N. + +## Broadcasting similar-rank arrays with degenerate dimensions + +A related broadcasting problem is broadcasting two arrays that have the same +rank but different dimension sizes. Similarly to Numpy's rules, this is only +possible when the arrays are *compatible*. Two arrays are compatible when all +their dimensions are compatible. Two dimensions are compatible if: + +* They are equal, or +* One of them is 1 (a "degenerate" dimension) + +When two compatible arrays are encountered, the result shape has the maximum +among the two inputs at every dimension index. + +Examples: + +1. (2,1) and (2,3) broadcast to (2,3). +2. (1,2,5) and (7,2,5) broadcast to (7,2,5) +3. (7,2,5) and (7,1,5) broadcast to (7,2,5) +4. (7,2,5) and (7,2,6) are incompatible and cannot be broadcast. + +A special case arises, and is also supported, where each of the input arrays has +a degenerate dimension at a different index. In this case, the result is an +"outer operation": (2,1) and (1,3) broadcast to (2,3). For more examples, +consult the +[Numpy documentation on broadcasting](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). + +## Broadcast composition + +Broadcasting of a lower-rank array to a higher-rank array **and** broadcasting +using degenerate dimensions can both be performed in the same binary operation. 
+For example, a vector of size 4 and a matrix of size 1x2 can be added together +using a broadcast dimensions value of (0): + + |1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector. + +First the vector is broadcast up to rank 2 (matrix) using the broadcast +dimensions. The single value (0) in the broadcast dimensions indicates that +dimension zero of the vector matches to dimension zero of the matrix. This +produces a matrix of size 4xM where the value M is chosen to match the +corresponding dimension size in the 1x2 array. Therefore, a 4x2 matrix is +produced: + + |1 1| + [5 6] + |2 2| + |3 3| + |4 4| + +Then "degenerate dimension broadcasting" broadcasts dimension zero of the 1x2 +matrix to match the corresponding dimension size of the right hand side: + + |1 1| + |5 6| |6 7| + |2 2| + |5 6| = |7 8| + |3 3| + |5 6| |8 9| + |4 4| + |5 6| |9 10| + +A more complicated example is a matrix of size 1x2 added to an array of size +4x3x1 using broadcast dimensions of (1, 2). First the 1x2 matrix is broadcast up +to rank 3 using the broadcast dimensions to produce an intermediate Mx1x2 array +where the dimension size M is determined by the size of the larger operand (the +4x3x1 array) producing a 4x1x2 intermediate array. The M is at dimension 0 +(left-most dimension) because the dimensions 1 and 2 are mapped to the +dimensions of the original 1x2 matrix as the broadcast dimensions are (1, 2). +This intermediate array can be added to the 4x3x1 matrix using broadcasting of +degenerate dimensions to produce a 4x3x2 array result.
diff --git a/tensorflow/compiler/xla/g3doc/developing_new_backend.md b/tensorflow/compiler/xla/g3doc/developing_new_backend.md new file mode 100644 index 0000000000000000000000000000000000000000..5ede7f523131cf715575074b8e27487be5ea77c6 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/developing_new_backend.md @@ -0,0 +1,76 @@ +# Developing a new backend for XLA + +This preliminary guide is for early adopters that want to easily retarget +TensorFlow to their hardware in an efficient manner. The guide is not +step-by-step and assumes knowledge of [LLVM](http://llvm.org), +[Bazel](https://bazel.build/), and TensorFlow. + +XLA provides an abstract interface that a new architecture or accelerator can +implement to create a backend to run TensorFlow graphs. Retargeting XLA should +be significantly simpler and more scalable than implementing every existing +TensorFlow Op for new hardware. + +Most implementations will fall into one of the following scenarios: + +1. Existing CPU architecture not yet officially supported by XLA, with or + without an existing [LLVM](http://llvm.org) backend. +2. Non-CPU-like hardware with an existing LLVM backend. +3. Non-CPU-like hardware without an existing LLVM backend. + +> Note: An LLVM backend can mean either one of the officially released LLVM +> backends or a custom LLVM backend developed in-house. + +## Scenario 1: Existing CPU architecture not yet officially supported by XLA + +In this scenario, start by looking at the existing +[XLA CPU backend](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/cpu/). +XLA makes it easy to retarget TensorFlow to different CPUs by using LLVM, since +the main difference between XLA backends for CPUs is the code generated by LLVM. +Google tests XLA for x64 and ARM64 architectures. + +If the hardware vendor has an LLVM backend for their hardware, it is simple to +link the backend with the LLVM built with XLA. In JIT mode, the XLA CPU backend +emits code for the host CPU.
For ahead-of-time compilation, +[`xla::AotCompilationOptions`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/compiler.h) +can provide an LLVM triple to configure the target architecture. + +If there is no existing LLVM backend but another kind of code generator exists, +it should be possible to reuse most of the existing CPU backend. + +## Scenario 2: Non-CPU-like hardware with an existing LLVM backend + +It is possible to model a new +[`xla::Compiler`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/compiler.h) +implementation on the existing +[`xla::CPUCompiler`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc) +and [`xla::GPUCompiler`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc) +classes, since these already emit LLVM IR. Depending on the nature of the +hardware, it is possible that many of the LLVM IR generation aspects will have +to be changed, but a lot of code can be shared with the existing backends. + +A good example to follow is the +[GPU backend](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/gpu/) +of XLA. The GPU backend targets a non-CPU-like ISA, and therefore some aspects +of its code generation are unique to the GPU domain. Other kinds of hardware, +e.g. DSPs like Hexagon (which has an upstream LLVM backend), can reuse parts of +the LLVM IR emission logic, but other parts will be unique. + +## Scenario 3: Non-CPU-like hardware without an existing LLVM backend + +If it is not possible to utilize LLVM, then the best option is to implement a +new backend for XLA for the desired hardware. This option requires the most +effort. The classes that need to be implemented are as follows: + +* [`StreamExecutor`](https://www.tensorflow.org/code/tensorflow/stream_executor/stream_executor.h): + For many devices not all methods of `StreamExecutor` are needed. See + existing `StreamExecutor` implementations for details. 
+* [`xla::Compiler`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/compiler.h): + This class encapsulates the compilation of an HLO computation into an + `xla::Executable`. +* [`xla::Executable`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/executable.h): + This class is used to launch a compiled computation on the platform. +* [`xla::TransferManager`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/transfer_manager.h): + This class enables backends to provide platform-specific mechanisms for + constructing XLA literal data from given device memory handles. In other + words, it helps encapsulate the transfer of data from the host to the device + and back. diff --git a/tensorflow/compiler/xla/g3doc/images/how-does-xla-work.png b/tensorflow/compiler/xla/g3doc/images/how-does-xla-work.png new file mode 100644 index 0000000000000000000000000000000000000000..15f86c3221d3637f2087a2db9f4cb008fe2690fa Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/how-does-xla-work.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/jit_cpu_xla_graph.png b/tensorflow/compiler/xla/g3doc/images/jit_cpu_xla_graph.png new file mode 100644 index 0000000000000000000000000000000000000000..4e2dc091fee1d13ae659988b1a68505e9ff77b27 Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/jit_cpu_xla_graph.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/jit_gpu_xla_graph.png b/tensorflow/compiler/xla/g3doc/images/jit_gpu_xla_graph.png new file mode 100644 index 0000000000000000000000000000000000000000..39d7c90c4fc3d707df062562fcf9ebdc37344af0 Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/jit_gpu_xla_graph.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/jit_timeline_cpu.png b/tensorflow/compiler/xla/g3doc/images/jit_timeline_cpu.png new file mode 100644 index 0000000000000000000000000000000000000000..a38f636983b527b678f17d3b0c92646ac1485f86 Binary files /dev/null and 
b/tensorflow/compiler/xla/g3doc/images/jit_timeline_cpu.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/jit_timeline_cpu_xla.png b/tensorflow/compiler/xla/g3doc/images/jit_timeline_cpu_xla.png new file mode 100644 index 0000000000000000000000000000000000000000..285c3a96d5aa33605cab2486522a5e815901a2fc Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/jit_timeline_cpu_xla.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/jit_timeline_gpu.png b/tensorflow/compiler/xla/g3doc/images/jit_timeline_gpu.png new file mode 100644 index 0000000000000000000000000000000000000000..488fc2c2f1009706b7e2c5ded154f47e2b7f4bcb Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/jit_timeline_gpu.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/jit_timeline_gpu_xla.png b/tensorflow/compiler/xla/g3doc/images/jit_timeline_gpu_xla.png new file mode 100644 index 0000000000000000000000000000000000000000..d0df38cf18197f89224cc0f5ff643dd537d03fcc Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/jit_timeline_gpu_xla.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/ops_2d_matrix.png b/tensorflow/compiler/xla/g3doc/images/ops_2d_matrix.png new file mode 100644 index 0000000000000000000000000000000000000000..4846d1700607ced60dd3b8038996894d4dd0f8af Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/ops_2d_matrix.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/ops_alltoall.png b/tensorflow/compiler/xla/g3doc/images/ops_alltoall.png new file mode 100644 index 0000000000000000000000000000000000000000..c8150bda5bd6fb5723832a5e42e71c12cee3d399 Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/ops_alltoall.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/ops_concatenate.png b/tensorflow/compiler/xla/g3doc/images/ops_concatenate.png new file mode 100644 index 0000000000000000000000000000000000000000..26ded3d88c07205dd6eceef2d2ee151b4e390977 Binary files 
/dev/null and b/tensorflow/compiler/xla/g3doc/images/ops_concatenate.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/ops_pad.png b/tensorflow/compiler/xla/g3doc/images/ops_pad.png new file mode 100644 index 0000000000000000000000000000000000000000..dc1948a627a88721d44bd22027ab75540f61feda Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/ops_pad.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/ops_reduce_from_2d_matrix.png b/tensorflow/compiler/xla/g3doc/images/ops_reduce_from_2d_matrix.png new file mode 100644 index 0000000000000000000000000000000000000000..c2ff037ab5c6ad7b2b2157339f189cff3b16df09 Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/ops_reduce_from_2d_matrix.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/ops_reduce_from_3d_matrix.png b/tensorflow/compiler/xla/g3doc/images/ops_reduce_from_3d_matrix.png new file mode 100644 index 0000000000000000000000000000000000000000..ebeeca093b2dda7fc5871e53302bce0e73e670be Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/ops_reduce_from_3d_matrix.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/ops_reduce_window.png b/tensorflow/compiler/xla/g3doc/images/ops_reduce_window.png new file mode 100644 index 0000000000000000000000000000000000000000..e9cdc3d148ab4ebb46bef8af84724134eae75d55 Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/ops_reduce_window.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/ops_reduce_window_stride.png b/tensorflow/compiler/xla/g3doc/images/ops_reduce_window_stride.png new file mode 100644 index 0000000000000000000000000000000000000000..f1ef5270dbac9f4ca1eb884e5cb27fd57a02ba8e Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/ops_reduce_window_stride.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/ops_scatter_to_selected_window_element.png b/tensorflow/compiler/xla/g3doc/images/ops_scatter_to_selected_window_element.png new file mode 
100644 index 0000000000000000000000000000000000000000..4a82afaefab42d837a46178ba9aef3a1b6ddc434 Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/ops_scatter_to_selected_window_element.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/ops_while.png b/tensorflow/compiler/xla/g3doc/images/ops_while.png new file mode 100644 index 0000000000000000000000000000000000000000..da32b553eb0226bfb1122c236dfefe151758b9fa Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/ops_while.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/ops_xla_gather_0.svg b/tensorflow/compiler/xla/g3doc/images/ops_xla_gather_0.svg new file mode 100644 index 0000000000000000000000000000000000000000..7d324aa35bd92aeef7bc2987eaf346f1c3aa0966 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/images/ops_xla_gather_0.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/tensorflow/compiler/xla/g3doc/images/ops_xla_gather_1.svg b/tensorflow/compiler/xla/g3doc/images/ops_xla_gather_1.svg new file mode 100644 index 0000000000000000000000000000000000000000..f460b923f0efa5594a25251ba308f0fe4b9bf786 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/images/ops_xla_gather_1.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/tensorflow/compiler/xla/g3doc/images/ops_xla_gather_2.svg b/tensorflow/compiler/xla/g3doc/images/ops_xla_gather_2.svg new file mode 100644 index 0000000000000000000000000000000000000000..d9c35e972d152c63df44d3c9be65ec3a840d5544 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/images/ops_xla_gather_2.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/tensorflow/compiler/xla/g3doc/images/send_recv_order.png b/tensorflow/compiler/xla/g3doc/images/send_recv_order.png new file mode 100644 index 0000000000000000000000000000000000000000..721200e3cb0af984f58cb8594607e5c0a39ddd18 Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/send_recv_order.png differ diff --git 
a/tensorflow/compiler/xla/g3doc/images/send_recv_schedule.png b/tensorflow/compiler/xla/g3doc/images/send_recv_schedule.png new file mode 100644 index 0000000000000000000000000000000000000000..c830f987ab9b7e53730555d5734ce37bd1854211 Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/send_recv_schedule.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure1.png b/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure1.png new file mode 100644 index 0000000000000000000000000000000000000000..00cefe4c7806c1c09dd51499375e720bfb0baac6 Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure1.png differ diff --git a/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure2.png b/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure2.png new file mode 100644 index 0000000000000000000000000000000000000000..6439c6e40272ae6b2954e9d7f3de2df470a2b36d Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/xla_array_layout_figure2.png differ diff --git a/tensorflow/compiler/xla/xlalogo.png b/tensorflow/compiler/xla/g3doc/images/xlalogo.png similarity index 100% rename from tensorflow/compiler/xla/xlalogo.png rename to tensorflow/compiler/xla/g3doc/images/xlalogo.png diff --git a/tensorflow/compiler/xla/g3doc/jit.md b/tensorflow/compiler/xla/g3doc/jit.md new file mode 100644 index 0000000000000000000000000000000000000000..85fa16ccc7f48a3dce840564e79097c9e136767f --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/jit.md @@ -0,0 +1,180 @@ +# Using JIT Compilation + +> Note: TensorFlow must be compiled from source to include XLA. + +## Why use just-in-time (JIT) compilation? + +The TensorFlow/XLA JIT compiler compiles and runs parts of TensorFlow graphs via +XLA. The benefit of this over the standard TensorFlow implementation is that XLA +can fuse multiple operators (kernel fusion) into a small number of compiled +kernels. 
Fusing operators can reduce memory bandwidth requirements and improve +performance compared to executing operators one-at-a-time, as the TensorFlow +executor does. + +## Running TensorFlow graphs via XLA + +There are two ways to run TensorFlow computations via XLA, either by +JIT-compiling operators placed on a CPU or GPU device, or by placing operators +on the `XLA_CPU` or `XLA_GPU` TensorFlow devices. Placing operators directly on +a TensorFlow XLA device forces the operator to run on that device and is mainly +used for testing. + +> Note: The XLA CPU backend supports intra-op parallelism (i.e. it can shard a +> single operation across multiple cores) but it does not support inter-op +> parallelism (i.e. it cannot execute independent operations concurrently across +> multiple cores). The XLA GPU backend is competitive with the standard +> TensorFlow implementation, sometimes faster, sometimes slower. + +### Turning on JIT compilation + +JIT compilation can be turned on at the session level or manually for select +operations. Both of these approaches are zero-copy --- data does not need to be +copied when passing data between a compiled XLA kernel and a TensorFlow operator +placed on the same device. + +#### Session + +Turning on JIT compilation at the session level will result in all possible +operators being greedily compiled into XLA computations. Each XLA computation +will be compiled into one or more kernels for the underlying device. + +Subject to a few constraints, if there are two adjacent operators in the graph +that both have XLA implementations, then they will be compiled into a single XLA +computation. + +JIT compilation is turned on at the session level by setting the +`global_jit_level` config to `tf.OptimizerOptions.ON_1` and passing the config +during session initialization. 
+ +```python +# Config to turn on JIT compilation +config = tf.ConfigProto() +config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 + +sess = tf.Session(config=config) +``` + +> Note: Turning on JIT at the session level will not result in operations being +> compiled for the CPU. JIT compilation for CPU operations must be done via +> the manual method documented below. + +#### Manual with experimental_jit_scope() + +JIT compilation can also be turned on manually for one or more operators. This +is done by tagging the operators to compile with the attribute +`_XlaCompile=true`. The simplest way to do this is via the +`tf.contrib.compiler.jit.experimental_jit_scope()` scope defined in +[`tensorflow/contrib/compiler/jit.py`](https://www.tensorflow.org/code/tensorflow/contrib/compiler/jit.py). +Example usage: + +```python + jit_scope = tf.contrib.compiler.jit.experimental_jit_scope + + x = tf.placeholder(np.float32) + with jit_scope(): + y = tf.add(x, x) # The "add" will be compiled with XLA. +``` + +The `_XlaCompile` attribute is currently supported on a best-effort basis. If an +operator cannot be compiled, TensorFlow will silently fall back to the normal +implementation. + +#### Manual with xla.compile() + +Unlike `experimental_jit_scope()`, which silently falls back to normal TensorFlow +when an operator cannot be compiled, `xla.compile()` returns an explicit error. This is +useful if you want more predictable behavior from XLA compilation. + +Please see +[xla.compile() tutorial Colab](./tutorials/xla_compile.ipynb) +for how to use it. + +### Placing operators on XLA devices + +Another way to run computations via XLA is to place an operator on a specific +XLA device. This method is normally only used for testing. Valid targets are +`XLA_CPU` or `XLA_GPU`. 
+ +```python +with tf.device("/job:localhost/replica:0/task:0/device:XLA_GPU:0"): + output = tf.add(input1, input2) +``` + +Unlike JIT compilation on the standard CPU and GPU devices, these devices make a +copy of data when it is transferred on and off the device. The extra copy makes +it expensive to mix XLA and TensorFlow operators in the same graph. + +## Tutorial + +This tutorial covers training a simple version of MNIST softmax with JIT turned +on. Currently JIT at the session level, which is what is used for the tutorial, +only supports GPU. + +Before starting the tutorial verify that the `LD_LIBRARY_PATH` environment variable or +ldconfig contains `$CUDA_ROOT/extras/CUPTI/lib64`, which contains libraries for +the CUDA Profiling Tools Interface +[(CUPTI)](http://docs.nvidia.com/cuda/cupti/index.html). TensorFlow uses CUPTI +to pull tracing information from the GPU. + +### Step #1: Prepare sample script + +Download or move +[mnist_softmax_xla.py](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py) +into a folder outside of the TensorFlow source tree. + +### Step #2: Run without XLA + +Execute the python script to train the model without XLA. + +```shell +python mnist_softmax_xla.py --xla='' +``` + +Using the Chrome Trace Event Profiler (browse to chrome://tracing), +open the timeline file created when the script finishes: `timeline.ctf.json`. +The rendered timeline should look similar to the picture below with multiple +green boxes labeled `MatMul`, possibly across multiple CPUs. +
+ +
+ +### Step #3: Run with XLA + +Execute the Python script to train the model with XLA and turn on a debugging +feature of XLA via an environment variable that outputs the XLA graph. + +```shell +XLA_FLAGS="--xla_hlo_graph_path=/tmp --xla_generate_hlo_graph=.*" python mnist_softmax_xla.py +``` + +Open the timeline file created (`timeline.ctf.json`). The rendered timeline +should look similar to the picture below with one long bar labeled `XlaLaunch`. +
+ +
+ +To understand what is happening in `XlaLaunch`, look at the console output for +statements similar to the following: + +```shell +computation cluster_0[_XlaCompiledKernel=true,_XlaNumConstantArgs=1].v82 [CPU: +pipeline start, before inline]: /tmp/hlo_graph_0.dot + +``` + +The console statements point to the location of `hlo_graph_xx.dot` files that +contain information about the graph created by XLA. The process that XLA takes +to fuse Ops is visible by starting at `hlo_graph_0.dot` and viewing each diagram +in succession. + +To render the .dot file into a PNG, install +[GraphViz](https://www.graphviz.org/download/) and run: + +```shell +dot -Tpng hlo_graph_80.dot -o hlo_graph_80.png +``` + +The result will look like the following: +
+ +
diff --git a/tensorflow/compiler/xla/g3doc/layout_with_tiling.md b/tensorflow/compiler/xla/g3doc/layout_with_tiling.md new file mode 100644 index 0000000000000000000000000000000000000000..5e990851af7495ebd4417e44f1d955fcc14dadf1 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/layout_with_tiling.md @@ -0,0 +1,159 @@ +# Tiled layout + +*Note: This doc describes how tiled layout is intended to work. Tiling is being +implemented, but this is an early effort and it is currently not even guaranteed +to get an Unimplemented error if one tries to use tiling - it may be just +silently ignored.* + +
![](images/xla_array_layout_figure1.png) + +Figure 1
+ +Figure 1 shows how an array F32[3,5] is laid out in memory with 2x2 tiling. A +shape with this layout is written as F32[3,5]{1,0:(2,2)}, where 1,0 relates to +the physical order of dimensions (minor_to_major field in Layout) while (2,2) +after the colon indicates tiling of the physical dimensions by a 2x2 tile. + +Intuitively tiles are laid out to cover the shape and then within each tile, +elements are then laid out without tiling, as in the example above, where the +right part of the example shows the layout in memory, including the white +padding elements that are added in order to have complete 2x2 tiles even though +the original array bounds are not even. + +The extra elements in the padding are not required to contain any particular +value. + +## Linear index formulas for tiling given a shape and a tile + +Without tiling, an element e=(e<sub>n</sub>, e<sub>n-1</sub>, ... , +e<sub>1</sub>) in an array with array bounds d=(d<sub>n</sub>, d<sub>n-1</sub>, +... , d<sub>1</sub>) (d<sub>1</sub> is the most minor dimension) is laid out by major to +minor order at position: + +   linear_index(e, d) \ += linear_index((e<sub>n</sub>, e<sub>n-1</sub>, ... , e<sub>1</sub>), +(d<sub>n</sub>, d<sub>n-1</sub>, ... , d<sub>1</sub>)) \ += e<sub>n</sub>d<sub>n-1</sub>...d<sub>1</sub> + +e<sub>n-1</sub>d<sub>n-2</sub>...d<sub>1</sub> + ... + e<sub>1</sub> + +For simplicity of notation in this document we assume a tile has the same number +of dimensions as the array. In XLA's implementation of tiling, this is +generalized to tilings with fewer dimensions by leaving the initial most-major +dimensions unchanged and applying the tiling only to the most minor dimensions, +so that the tiling that is specified mentions a suffix of the physical +dimensions of the shape being tiled. + +When tiling of size (t<sub>n</sub>, t<sub>n-1</sub>, ... , t<sub>1</sub>) is +used, an element in the array with indices (e<sub>n</sub>, e<sub>n-1</sub>, ... +, e<sub>1</sub>) is mapped to this position in the final layout: + +   linear_index_with_tile(e, d, t) \ += linear_index((⌊e/t⌋, e mod t), (⌈d/t⌉, t))     (arithmetic is +elementwise, (a,b) is concatenation) \ += linear_index((⌊e<sub>n</sub>/t<sub>n</sub>⌋, ... , +⌊e<sub>1</sub>/t<sub>1</sub>⌋, e<sub>n</sub> mod t<sub>n</sub>, ... , +e<sub>1</sub> mod t<sub>1</sub>), (⌈d<sub>n</sub>/t<sub>n</sub>⌉, ... , +⌈d<sub>1</sub>/t<sub>1</sub>⌉, t<sub>n</sub>, t<sub>n-1</sub>, ... 
, +t<sub>1</sub>)) \ += linear_index((⌊e<sub>n</sub>/t<sub>n</sub>⌋, ... , +⌊e<sub>1</sub>/t<sub>1</sub>⌋), (⌈d<sub>n</sub>/t<sub>n</sub>⌉, ... , +⌈d<sub>1</sub>/t<sub>1</sub>⌉))∙t<sub>n</sub>t<sub>n-1</sub>...t<sub>1</sub> + +linear_index((e<sub>n</sub> mod t<sub>n</sub>, ... , e<sub>1</sub> mod +t<sub>1</sub>), (t<sub>n</sub>, t<sub>n-1</sub>, ... , t<sub>1</sub>)) + +The layout can be thought of as having two parts: +(⌊e<sub>n</sub>/t<sub>n</sub>⌋, ... , ⌊e<sub>1</sub>/t<sub>1</sub>⌋), which +corresponds to a tile index in an array of tiles of size +(⌈d<sub>n</sub>/t<sub>n</sub>⌉, ... , ⌈d<sub>1</sub>/t<sub>1</sub>⌉), and +(e<sub>n</sub> mod t<sub>n</sub>, ... , e<sub>1</sub> mod t<sub>1</sub>), which +corresponds to a within-tile index. The ceil function appears in +⌈d<sub>i</sub>/t<sub>i</sub>⌉ because if tiles overrun the bounds of the larger +array, padding is inserted as in Figure 1. Both the tiles and elements within +tiles are laid out recursively without tiling. + +For the example in Figure 1, element (2,3) has tile index (1,1), and within-tile +index (0,1), for a combined coordinate vector of (1, 1, 0, 1). The tile indices +have bounds (2, 3) and the tile itself is (2, 2) for a combined vector of (2, 3, +2, 2). The linear index with tile for the element with index (2, 3) in the +logical shape is then + +   linear_index_with_tile((2,3), (3,5), (2,2)) \ += linear_index((1,1,0,1), (2,3,2,2)) \ += linear_index((1,1), (2,3)) ∙ 2 ∙ 2 + linear_index((0,1), (2,2)) \ += (1 ∙ 3 + 1) ∙ 2 ∙ 2 + (0 ∙ 2 + 1) \ += 17. + +# Tiling as pad-reshape-transpose + +Tiling-based layout operates as follows: \ +Consider an array of dimensions (d<sub>n</sub>, d<sub>n-1</sub>, ... , d<sub>1</sub>) (d<sub>1</sub> +is the most minor dimension). When it’s laid out with tiling of size +(t<sub>n</sub>, t<sub>n-1</sub>, ... , t<sub>1</sub>) (t<sub>1</sub> is the most +minor dimension), that tiling can be described in terms of pad-reshape-transpose +in the following way. + +1. The array is padded to (⌈d<sub>n</sub>/t<sub>n</sub>⌉∙t<sub>n</sub>, ... , + ⌈d<sub>1</sub>/t<sub>1</sub>⌉∙t<sub>1</sub>). +2. Each dimension i is broken into (⌈d<sub>i</sub>/t<sub>i</sub>⌉, + t<sub>i</sub>), i.e. the array is reshaped to \ +     (⌈d<sub>n</sub>/t<sub>n</sub>⌉, t<sub>n</sub>, ... , + ⌈d<sub>1</sub>/t<sub>1</sub>⌉, t<sub>1</sub>). \ + There is no physical layout change in this reshape by itself, so this + reshape is a bitcast. 
If one is not explicitly thinking of a tiling, this + reshape could express any shape with the same number of elements as the + padded shape - the example here is of how to express a tile in this way. +3. A transpose happens by moving t<sub>n</sub>, ... , t<sub>1</sub> to the most + minor dimensions while keeping their relative order, so that the order of + dimensions from most major to most minor becomes \ +     (⌈d<sub>n</sub>/t<sub>n</sub>⌉, ... , + ⌈d<sub>1</sub>/t<sub>1</sub>⌉, t<sub>n</sub>, ... , t<sub>1</sub>). + +The final shape has the prefix \ +    (⌈d<sub>n</sub>/t<sub>n</sub>⌉, ... , +⌈d<sub>1</sub>/t<sub>1</sub>⌉), which describes the number of tiles in each +dimension. An element in the array (e<sub>n</sub>, ... , e<sub>1</sub>) is +mapped to this element in the final shape: \ +    (⌊e<sub>n</sub>/t<sub>n</sub>⌋, ... , +⌊e<sub>1</sub>/t<sub>1</sub>⌋, e<sub>n</sub> mod t<sub>n</sub>, ... , +e<sub>1</sub> mod t<sub>1</sub>). It is easy to see that the linear index of the +element follows the formula above as expected. + +# Repeated tiling + +XLA's tiling becomes even more flexible by applying it repeatedly. + +
![](images/xla_array_layout_figure2.png) + +Figure 2
+ +Figure 2 shows how an array of size 4x8 is tiled by two levels of tiling (first +2x4 then 2x1). We represent this repeated tiling as (2,4)(2,1). Each color +indicates a 2x4 tile and each red border box is a 2x1 tile. The numbers +indicate the linear index in memory of that element in the tiled format. This +format matches the format used for BF16 on TPU, except that the initial tile is +bigger, namely the tiling is (8,128)(2,1), where the purpose of the second +tiling by 2x1 is to collect together two 16 bit values to form one 32 bit value +in a way that aligns with the architecture of a TPU. + +Note that a second or later tile can refer to both the minor within-tile +dimensions, which just rearranges data within the tile, as in this example with +(8,128)(2,1), but can also refer to the major cross-tile dimensions from the +prior tiling. + +# Combining dimensions using tiles + +XLA's tiling also supports combining dimensions. For example, it can combine +dimensions in F32[2,7,8,11,10]{4,3,2,1,0} into F32[112,110]{1,0} first before +tiling it with (2,3). The tile used is (∗,∗,2,∗,3). Here an +asterisk in a tile implies taking that dimension and combining it with the next +more minor dimension. Multiple adjacent dimensions can be subsumed together into +one dimension. A subsumed dimension is represented by a tile value of -1 in that +dimension of the tile, which is not otherwise valid in a tile as a dimension +size. + +More precisely, if dimension i of the shape is eliminated via an asterisk in the +tile, then before the prior definition of tiling is applied, that dimension is +removed from both the shape being tiled and the tile vector, and what was +dimension i-1 of the shape has its array bound increased from d<sub>i-1</sub> to +d<sub>i</sub>d<sub>i-1</sub>. This step is repeated for each asterisk in the +tile vector. 
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md similarity index 88% rename from tensorflow/docs_src/performance/xla/operation_semantics.md rename to tensorflow/compiler/xla/g3doc/operation_semantics.md index 96d269bec4d59bd7eb23e1964bf7d208996aabde..d888b1f23f36f33ef94ef0e22374e0c796e47a89 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -13,6 +13,22 @@ arbitrary-dimensional array. For convenience, special cases have more specific and familiar names; for example a *vector* is a 1-dimensional array and a *matrix* is a 2-dimensional array. +## AfterAll + +See also +[`XlaBuilder::AfterAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +AfterAll takes a variadic number of tokens and produces a single token. Tokens +are primitive types which can be threaded between side-effecting operations to +enforce ordering. `AfterAll` can be used as a join of tokens for ordering an +operation after a set of operations. + + `AfterAll(operands)` + +Arguments | Type | Semantics +---------- | ------- | ------------------------- +`operands` | `XlaOp` | variadic number of tokens + ## AllToAll See also @@ -77,7 +93,7 @@ AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4); ```
- +
In this example, there are 4 cores participating the Alltoall. On each core, the @@ -119,8 +135,8 @@ respect to `operand`, `offset` and `scale` across all the other dimensions. The `feature_index` must be a valid index for the feature dimension in `operand`. The three gradients are defined by the following formulas (assuming a -4-dimensional tensor as `operand` and with feature dimension index \\(l\\), -batch size `m` and spatial sizes `w` and `h`): +4-dimensional array as `operand` and with feature dimension index $$l$$, batch +size `m` and spatial sizes `w` and `h`): \\[ \begin{split} c_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h @@ -402,6 +418,33 @@ then v12 == f32[8x3] {{10, 11, 12}, ``` +## CollectivePermute + +See also +[`XlaBuilder::CollectivePermute`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +CollectivePermute is a collective operation that sends and receives data cross +replicas. + + `CollectivePermute(operand, source_target_pairs)` + +| Arguments | Type | Semantics | +| --------------------- | ----------------------- | -------------------------- | +| `operand` | `XlaOp` | n dimensional input array | +| `source_target_pairs` | `` vector | A list of | +: : : (source_replica_id, : +: : : target_replica_id) pairs. : +: : : For each pair, the operand : +: : : is sent from source : +: : : replica to target replica. : + +Note that there are the following restrictions on the `source_target_pair`: + +- Any two pairs should not have the same target replica id, and they should + not have the same source replica id. +- If a replica id is not a target in any pair, then the output on that replica + is a tensor consists of 0(s) with the same shape as the input. + ## Concatenate See also @@ -455,7 +498,7 @@ Concat({a, b}, 0) Diagram:
- +
## Conditional @@ -1028,7 +1071,7 @@ Arguments | Type | Semantics `rhs` | `XlaOp` | right-hand-side operand: array of type T The arguments' shapes have to be either similar or compatible. See the -[broadcasting](../../performance/xla/broadcasting.md) documentation about what it means for shapes to +[broadcasting](broadcasting.md) documentation about what it means for shapes to be compatible. The result of an operation has a shape which is the result of broadcasting the two input arrays. In this variant, operations between arrays of different ranks are *not* supported, unless one of the operands is a scalar. @@ -1056,7 +1099,7 @@ the dimensions of the higher-rank shape. The unmapped dimensions of the expanded shape are filled with dimensions of size one. Degenerate-dimension broadcasting then broadcasts the shapes along these degenerate dimensions to equalize the shapes of both operands. The semantics are described in detail on the -[broadcasting page](../../performance/xla/broadcasting.md). +[broadcasting page](broadcasting.md). ## Element-wise comparison operations @@ -1079,7 +1122,7 @@ Arguments | Type | Semantics `rhs` | `XlaOp` | right-hand-side operand: array of type T The arguments' shapes have to be either similar or compatible. See the -[broadcasting](../../performance/xla/broadcasting.md) documentation about what it means for shapes to +[broadcasting](broadcasting.md) documentation about what it means for shapes to be compatible. The result of an operation has a shape which is the result of broadcasting the two input arrays with the element type `PRED`. In this variant, operations between arrays of different ranks are *not* supported, unless one of @@ -1096,7 +1139,7 @@ matrix to a vector). The additional `broadcast_dimensions` operand is a slice of integers specifying the dimensions to use for broadcasting the operands. The semantics are described -in detail on the [broadcasting page](../../performance/xla/broadcasting.md). 
+in detail on the [broadcasting page](broadcasting.md). ## Element-wise unary functions @@ -1152,29 +1195,32 @@ For a more intuitive description, see the "Informal Description" section below. `gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)` -|Arguments | Type | Semantics | -|----------------- | ----------------------- | --------------------------------| -|`operand` | `XlaOp` | The array we’re gathering | -: : : from. : -|`start_indices` | `XlaOp` | Array containing the starting | -: : : indices of the slices we gather.: -|`index_vector_dim` | `int64` | The dimension in | -: : : `start_indices` that "contains" : -: : : the starting indices. See : -: : : below for a detailed : -: : : description. : -|`offset_dims` | `ArraySlice` | The set of dimensions in the : -: : : output shape that offset into a : -: : : array sliced from operand. : -|`slice_sizes` | `ArraySlice` | `slice_sizes[i]` is the bounds | -: : : for the slice on dimension `i`.: -|`collapsed_slice_dims` | `ArraySlice` | The set of dimensions in each : -| : | slice that are collapsed away. : -| : | These dimensions must have size: -| : | 1. | -|`start_index_map` | `ArraySlice` | A map that describes how to map| -: : : indices in `start_indices` to : -: : : to legal indices into operand. : +| Arguments | Type | Semantics | +| ---------------------- | ------------------- | ----------------------------- | +| `operand` | `XlaOp` | The array we’re gathering | +: : : from. : +| `start_indices` | `XlaOp` | Array containing the starting | +: : : indices of the slices we : +: : : gather. : +| `index_vector_dim` | `int64` | The dimension in | +: : : `start_indices` that : +: : : "contains" the starting : +: : : indices. See below for a : +: : : detailed description. : +| `offset_dims` | `ArraySlice` | The set of dimensions in the | +: : : output shape that offset into : +: : : a array sliced from operand. 
: +| `slice_sizes` | `ArraySlice` | `slice_sizes[i]` is the | +: : : bounds for the slice on : +: : : dimension `i`. : +| `collapsed_slice_dims` | `ArraySlice` | The set of dimensions in each | +: : : \: slice that are collapsed : +: : : away. These dimensions must : +: : : have size 1. : +| `start_index_map` | `ArraySlice` | A map that describes how to | +: : : map indices in : +: : : `start_indices` to legal : +: : : indices into operand. : For convenience, we label dimensions in the output array not in `offset_dims` as `batch_dims`. @@ -1269,7 +1315,7 @@ the output shape, and maps it to an element in the input array in the following way:
- +
We first select an (`X`,`Y`) vector from the gather indices array using `G`. @@ -1288,7 +1334,7 @@ version of the example above using a "gather indices" array of shape `[4,5,2]` would translate indices like this:
- +
Again, this acts as a batch dynamic slice `G``0` and @@ -1317,7 +1363,7 @@ the following ways: As a final example, we use (2) and (3) to implement `tf.gather_nd`:
- +
`G``0` and `G``1` are used to slice out a starting index @@ -1326,7 +1372,7 @@ element, `X`. Similarly, there is only one output offset index with the value `O``0`. However, before being used as indices into the input array, these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in the formal description) and "Offset Mapping" (`expand_offset_dims` in the formal -description) into [`0`,`O``0`] and [`X`,`0`] respectively, adding up +description) into [`X`,`0`] and [`0`,`O``0`] respectively, adding up to [`X`,`O``0`]. In other words, the output index [`G``0`,`G``1`,`O``0`] maps to the input index [`GatherIndices`[`G``0`,`G``1`,`0`],`X`] which gives us @@ -1336,6 +1382,22 @@ the semantics for `tf.gather_nd`. index `X` in the gather indices array picks an entire row and the result is the concatenation of all these rows. +## GetDimensionSize + +See also +[`XlaBuilder::GetDimensionSize`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Returns the size of the given dimension of the operand. The operand must be +array shaped. + + `GetDimensionSize(operand, dimension)` + +| Arguments | Type | Semantics | +| ----------- | ------- | --------------------------------------------------- | +| `operand` | `XlaOp` | n dimensional input array | +| `dimension` | `int64` | A value in the interval `[0, n)` that specifies the | +: : : dimension : + ## GetTupleElement See also @@ -1401,13 +1463,14 @@ Infeed of the device. `Iota()` Builds a constant literal on device rather than a potentially large host -transfer. Creates a rank 1 tensor of values starting at zero and incrementing -by one. +transfer. Creates a rank 1 array of values starting at zero and incrementing by +one. -Arguments | Type | Semantics ------------------- | --------------- | --------------------------- -`type` | `PrimitiveType` | type U -`size` | `int64` | The number of elements in the tensor. 
+Arguments | Type | Semantics +---------------- | --------------- | ------------------------------------ +`type` | `PrimitiveType` | type U +`size` | `int64` | The number of elements in the array. +`iota_dimension` | `int64` | The dimension to increment along. ## Map @@ -1461,20 +1524,25 @@ dimension. `PaddingConfig` is a repeated field of `PaddingConfigDimension`, which contains three fields for each dimension: `edge_padding_low`, `edge_padding_high`, and -`interior_padding`. `edge_padding_low` and `edge_padding_high` specify the -amount of padding added at the low-end (next to index 0) and the high-end (next -to the highest index) of each dimension respectively. The amount of edge padding -can be negative -- the absolute value of negative padding indicates the number -of elements to remove from the specified dimension. `interior_padding` specifies -the amount of padding added between any two elements in each dimension. Interior -padding occurs logically before edge padding, so in the case of negative edge -padding elements are removed from the interior-padded operand. This operation is -a no-op if the edge padding pairs are all (0, 0) and the interior padding values -are all 0. The figure below shows examples of different `edge_padding` and -`interior_padding` values for a two-dimensional array. +`interior_padding`. + +`edge_padding_low` and `edge_padding_high` specify the amount of padding added +at the low-end (next to index 0) and the high-end (next to the highest index) of +each dimension respectively. The amount of edge padding can be negative -- the +absolute value of negative padding indicates the number of elements to remove +from the specified dimension. + +`interior_padding` specifies the amount of padding added between any two +elements in each dimension; it may not be negative. Interior padding occurs +logically before edge padding, so in the case of negative edge padding, elements +are removed from the interior-padded operand. 
+ +This operation is a no-op if the edge padding pairs are all (0, 0) and the +interior padding values are all 0. The figure below shows examples of different +`edge_padding` and `interior_padding` values for a two-dimensional array.
- +
## Recv @@ -1590,13 +1658,13 @@ Here's an example of reducing a 2D array (matrix). The shape has rank 2, dimension 0 of size 2 and dimension 1 of size 3:
- +
Results of reducing dimensions 0 or 1 with an "add" function:
- +
Note that both reduction results are 1D arrays. The diagram shows one as column @@ -1607,7 +1675,7 @@ size 4, dimension 1 of size 2 and dimension 2 of size 3. For simplicity, the values 1 to 6 are replicated across dimension 0.
- +
Similarly to the 2D example, we can reduce just one dimension. If we reduce @@ -1640,8 +1708,8 @@ Reducing the 3D array over all its dimensions produces the scalar `84`. When `N > 1`, reduce function application is slightly more complex, as it is applied simultaneously to all inputs. For example, consider the following -reduction function, which can be used to compute the max and the argmax of a -a 1-D tensor in parallel: +reduction function, which can be used to compute the max and the argmax of a +1-D array in parallel: ``` f: (Float, Int, Float, Int) -> Float, Int @@ -1728,6 +1796,10 @@ window_strides, padding)` : : : dimension values : | `window_strides` | `ArraySlice` | array of integers for window | : : : stride values : +| `base_dilations` | `ArraySlice` | array of integers for base | +: : : dilation values : +| `window_dilations` | `ArraySlice` | array of integers for window | +: : : dilation values : | `padding` | `Padding` | padding type for window | : : : (Padding\:\:kSame or : : : : Padding\:\:kValid) : @@ -1752,15 +1824,16 @@ XlaBuilder builder(client_, "reduce_window_2x3"); auto shape = ShapeUtil::MakeShape(F32, {4, 6}); auto input = builder.Parameter(0, shape, "input"); builder.ReduceWindow( - input, *max, + input, /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)), + *max, /*window_dimensions=*/{2, 3}, /*window_stride_dimensions=*/{2, 3}, Padding::kValid); ```
- +
Stride of 1 in a dimension specifies that the position of a window in the @@ -1772,7 +1845,7 @@ are the same as though the input came in with the dimensions it has after padding.
- +
The evaluation order of the reduction function is arbitrary and may be @@ -1929,44 +2002,24 @@ implementation-defined. ## Scatter The XLA scatter operation generates a result which is the value of the input -tensor `operand`, with several slices (at indices specified by -`scatter_indices`) updated with the values in `updates` using -`update_computation`. +array `operand`, with several slices (at indices specified by `scatter_indices`) +updated with the values in `updates` using `update_computation`. See also [`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). `scatter(operand, scatter_indices, updates, update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)` -|Arguments | Type | Semantics | -|------------------|------------------------|----------------------------------| -|`operand` | `XlaOp` | Tensor to be scattered into. | -|`scatter_indices` | `XlaOp` | Tensor containing the starting | -: : : indices of the slices that must : -: : : be scattered to. : -|`updates` | `XlaOp` | Tensor containing the values that| -: : : must be used for scattering. : -|`update_computation`| `XlaComputation` | Computation to be used for | -: : : combining the existing values in : -: : : the input tensor and the updates : -: : : during scatter. This computation : -: : : should be of type `T, T -> T`. : -|`index_vector_dim`| `int64` | The dimension in | -: : : `scatter_indices` that contains : -: : : the starting indices. : -|`update_window_dims`| `ArraySlice` | The set of dimensions in | -: : : `updates` shape that are _window : -: : : dimensions_. : -|`inserted_window_dims`| `ArraySlice`| The set of _window dimensions_ | -: : : that must be inserted into : -: : : `updates` shape. : -|`scatter_dims_to_operand_dims`| `ArraySlice` | A dimensions map from | -: : : the scatter indices to the : -: : : operand index space. 
This array : -: : : is interpreted as mapping `i` to : -: : : `scatter_dims_to_operand_dims[i]`: -: : : . It has to be one-to-one and : -: : : total. : +Arguments | Type | Semantics +------------------------------ | ------------------- | --------- +`operand` | `XlaOp` | Array to be scattered into. +`scatter_indices` | `XlaOp` | Array containing the starting indices of the slices that must be scattered to. +`updates` | `XlaOp` | Array containing the values that must be used for scattering. +`update_computation` | `XlaComputation` | Computation to be used for combining the existing values in the input array and the updates during scatter. This computation should be of type `T, T -> T`. +`index_vector_dim` | `int64` | The dimension in `scatter_indices` that contains the starting indices. +`update_window_dims` | `ArraySlice` | The set of dimensions in `updates` shape that are _window dimensions_. +`inserted_window_dims` | `ArraySlice` | The set of _window dimensions_ that must be inserted into `updates` shape. +`scatter_dims_to_operand_dims` | `ArraySlice` | A dimensions map from the scatter indices to the operand index space. This array is interpreted as mapping `i` to `scatter_dims_to_operand_dims[i]` . It has to be one-to-one and total. If `index_vector_dim` is equal to `scatter_indices.rank` we implicitly consider `scatter_indices` to have a trailing `1` dimension. @@ -1977,78 +2030,77 @@ order. The arguments of scatter should follow these constraints: - - `updates` tensor must be of rank `update_window_dims.size + - scatter_indices.rank - 1`. +- `updates` array must be of rank `update_window_dims.size + + scatter_indices.rank - 1`. + +- Bounds of dimension `i` in `updates` must conform to the following: - - Bounds of dimension `i` in `updates` must conform to the following: - - If `i` is present in `update_window_dims` (i.e. 
equal to - `update_window_dims`[`k`] for some `k`), then the bound of dimension - `i` in `updates` must not exceed the corresponding bound of `operand` - after accounting for the `inserted_window_dims` (i.e. + - If `i` is present in `update_window_dims` (i.e. equal to + `update_window_dims`[`k`] for some `k`), then the bound of dimension `i` + in `updates` must not exceed the corresponding bound of `operand` after + accounting for the `inserted_window_dims` (i.e. `adjusted_window_bounds`[`k`], where `adjusted_window_bounds` contains the bounds of `operand` with the bounds at indices `inserted_window_dims` removed). - - If `i` is present in `update_scatter_dims` (i.e. equal to + - If `i` is present in `update_scatter_dims` (i.e. equal to `update_scatter_dims`[`k`] for some `k`), then the bound of dimension `i` in `updates` must be equal to the corresponding bound of `scatter_indices`, skipping `index_vector_dim` (i.e. `scatter_indices.shape.dims`[`k`], if `k` < `index_vector_dim` and `scatter_indices.shape.dims`[`k+1`] otherwise). - - `update_window_dims` must be in ascending order, not have any repeating +- `update_window_dims` must be in ascending order, not have any repeating dimension numbers, and be in the range `[0, updates.rank)`. - - `inserted_window_dims` must be in ascending order, not have any - repeating dimension numbers, and be in the range `[0, operand.rank)`. +- `inserted_window_dims` must be in ascending order, not have any repeating + dimension numbers, and be in the range `[0, operand.rank)`. - - `scatter_dims_to_operand_dims.size` must be equal to +- `scatter_dims_to_operand_dims.size` must be equal to `scatter_indices`[`index_vector_dim`], and its values must be in the range `[0, operand.rank)`. -For a given index `U` in the `updates` tensor, the corresponding index `I` in -the `operand` tensor into which this update has to be applied is computed as -follows: - - 1. Let `G` = { `U`[`k`] for `k` in `update_scatter_dims` }. 
Use `G` to look up - an index vector `S` in the `scatter_indices` tensor such that `S`[`i`] = - `scatter_indices`[Combine(`G`, `i`)] where Combine(A, b) inserts b at - positions `index_vector_dim` into A. - 2. Create an index `S``in` into `operand` using `S` by scattering - `S` using the `scatter_dims_to_operand_dims` map. More formally: - 1. `S``in`[`scatter_dims_to_operand_dims`[`k`]] = `S`[`k`] if - `k` < `scatter_dims_to_operand_dims.size`. - 2. `S``in`[`_`] = `0` otherwise. - 3. Create an index `W``in` into `operand` by scattering the indices - at `update_window_dims` in `U` according to `inserted_window_dims`. - More formally: - 1. `W``in`[`window_dims_to_operand_dims`(`k`)] = `U`[`k`] if - `k` < `update_window_dims.size`, where `window_dims_to_operand_dims` - is the monotonic function with domain [`0`, `update_window_dims.size`) - and range [`0`, `operand.rank`) \\ `inserted_window_dims`. (For - example, if `update_window_dims.size` is `4`, `operand.rank` is `6`, - and `inserted_window_dims` is {`0`, `2`} then - `window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, - `3`→`5`}). - 2. `W``in`[`_`] = `0` otherwise. - 4. `I` is `W``in` + `S``in` where + is element-wise - addition. +For a given index `U` in the `updates` array, the corresponding index `I` in the +`operand` array into which this update has to be applied is computed as follows: + +1. Let `G` = { `U`[`k`] for `k` in `update_scatter_dims` }. Use `G` to look up + an index vector `S` in the `scatter_indices` array such that `S`[`i`] = + `scatter_indices`[Combine(`G`, `i`)] where Combine(A, b) inserts b at + positions `index_vector_dim` into A. +2. Create an index `S``in` into `operand` using `S` by scattering + `S` using the `scatter_dims_to_operand_dims` map. More formally: + 1. `S``in`[`scatter_dims_to_operand_dims`[`k`]] = `S`[`k`] if + `k` < `scatter_dims_to_operand_dims.size`. + 2. `S``in`[`_`] = `0` otherwise. +3. 
Create an index `W``in` into `operand` by scattering the indices + at `update_window_dims` in `U` according to `inserted_window_dims`. More + formally: + 1. `W``in`[`window_dims_to_operand_dims`(`k`)] = `U`[`k`] if `k` + < `update_window_dims.size`, where `window_dims_to_operand_dims` is the + monotonic function with domain [`0`, `update_window_dims.size`) and + range [`0`, `operand.rank`) \\ `inserted_window_dims`. (For example, if + `update_window_dims.size` is `4`, `operand.rank` is `6`, and + `inserted_window_dims` is {`0`, `2`} then `window_dims_to_operand_dims` + is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}). + 2. `W``in`[`_`] = `0` otherwise. +4. `I` is `W``in` + `S``in` where + is element-wise + addition. In summary, the scatter operation can be defined as follows. - - Initialize `output` with `operand`, i.e. for all indices `O` in the - `operand` tensor:\ - `output`[`O`] = `operand`[`O`] - - For every index `U` in the `updates` tensor and the corresponding index `O` - in the `operand` tensor:\ - `output`[`O`] = `update_computation`(`output`[`O`], `updates`[`U`]) +- Initialize `output` with `operand`, i.e. for all indices `O` in the + `operand` array: \ + `output`[`O`] = `operand`[`O`] +- For every index `U` in the `updates` array and the corresponding index `O` + in the `operand` array: \ + `output`[`O`] = `update_computation`(`output`[`O`], `updates`[`U`]) The order in which updates are applied is non-deterministic. So, when multiple indices in `updates` refer to the same index in `operand`, the corresponding value in `output` will be non-deterministic. Note that the first parameter that is passed into the `update_computation` will -always be the current value from the `output` tensor and the second parameter -will always be the value from the `updates` tensor. This is important +always be the current value from the `output` array and the second parameter +will always be the value from the `updates` array. 
This is important specifically for cases when the `update_computation` is _not commutative_. Informally, the scatter op can be viewed as an _inverse_ of the gather op, i.e. @@ -2080,10 +2132,9 @@ shape of the output array. The array `pred` must have the same dimensionality as For each element `P` of `pred`, the corresponding element of the output array is taken from `on_true` if the value of `P` is `true`, and from `on_false` if the -value of `P` is `false`. As a restricted form of [broadcasting] -(broadcasting.md), `pred` can be a scalar of type `PRED`. In this case, the -output array is taken wholly from `on_true` if `pred` is `true`, and from -`on_false` if `pred` is `false`. +value of `P` is `false`. As a restricted form of [broadcasting](broadcasting.md), +`pred` can be a scalar of type `PRED`. In this case, the output array is taken +wholly from `on_true` if `pred` is `true`, and from `on_false` if `pred` is `false`. Example with non-scalar `pred`: @@ -2181,7 +2232,7 @@ addition `scatter` function produces the output element of value 8 (2 + 6).
+ src="./images/ops_scatter_to_selected_window_element.png">
The evaluation order of the `scatter` function is arbitrary and may be @@ -2228,7 +2279,7 @@ The execution order of the 4 instructions for each channel (`Recv`, `RecvDone`, `Send`, `SendDone`) is as below.
- +
* `Recv` happens before `Send` @@ -2241,7 +2292,7 @@ communicates via channel instructions, there must not be cycles across the computations. For example, below schedules lead to deadlocks.
- +
## Slice @@ -2299,9 +2350,9 @@ See also [`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). There are two versions of the Sort instruction: a single-operand and a -two-operand version. +multi-operand version. -`Sort(operand)` +`Sort(operand, dimension)` Arguments | Type | Semantics ----------- | ------- | -------------------- @@ -2315,25 +2366,26 @@ row independently. If the operand's elements have floating point type, and the operand contains NaN elements, the order of elements in the output is implementation-defined. -`Sort(key, value)` +`Sort(keys, values, ... values, dimension)` -Sorts both the key and the value operands. The keys are sorted as in the -single-operand version. The values are sorted according to the order of their -corresponding keys. For example, if the inputs are `keys = [3, 1]` and -`values = [42, 50]`, then the output of the sort is the tuple -`{[1, 3], [50, 42]}`. +Sorts both the key and one or more value operands. The keys are sorted as in the +single-operand version. Each of the values inputs is sorted according to the +order of the corresponding keys. For example, if the three inputs are `keys = +[3, 1]`, `values0 = [42, 50]`, `values1 = [-3.0, 1.1]`, then the output of the +sort is the tuple `{[1, 3], [50, 42], [1.1, -3.0]}`. The sort is not guaranteed to be stable, that is, if the keys array contains -duplicates, the order of their corresponding values may not be preserved. +duplicates, the order of values corresponding to these keys may not be +preserved. -Arguments | Type | Semantics ------------ | ------- | ------------------- -`keys` | `XlaOp` | The sort keys. -`values` | `XlaOp` | The values to sort. -`dimension` | `int64` | The dimension along which to sort. +Arguments | Type | Semantics +----------- | ---------------------- | ---------------------------------- +`keys` | `XlaOp` | The sort keys. +`values` | Sequence of N `XlaOp`s | The values to sort. 
+`dimension` | `int64` | The dimension along which to sort. -The `keys` and `values` must have the same dimensions, but may have different -element types. +The `keys` and each of the `values` inputs must have the same dimensions, but +may have different element types. ## Transpose @@ -2422,5 +2474,5 @@ while (result(0) < 1000) { ```
- +
diff --git a/tensorflow/compiler/xla/g3doc/overview.md b/tensorflow/compiler/xla/g3doc/overview.md new file mode 100644 index 0000000000000000000000000000000000000000..d3428b7276131e8f406f60cfea9a9346c5478433 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/overview.md @@ -0,0 +1,98 @@ +# XLA Overview + +
+ +
+ +> Note: XLA is still under development. Some use cases will not +> see improvements in speed or decreased memory usage. + +XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear +algebra that optimizes TensorFlow computations. The results are improvements in +speed, memory usage, and portability on server and mobile platforms. Initially, +most users will not see large benefits from XLA, but are welcome to experiment +by using XLA via [just-in-time (JIT) compilation](./jit.md) or +[ahead-of-time (AOT) compilation](./tfcompile.md). Developers targeting new +hardware accelerators are especially encouraged to try out XLA. + +The XLA framework is experimental and in active development. In particular, +while it is unlikely that the semantics of existing operations will change, it +is expected that more operations will be added to cover important use cases. The +team welcomes feedback from the community about missing functionality and +community contributions via GitHub. + +## Why did we build XLA? + +We had several objectives for XLA to work with TensorFlow: + +* *Improve execution speed.* Compile subgraphs to reduce the execution time of + short-lived Ops to eliminate overhead from the TensorFlow runtime, fuse + pipelined operations to reduce memory overhead, and specialize to known + tensor shapes to allow for more aggressive constant propagation. + +* *Improve memory usage.* Analyze and schedule memory usage, in principle + eliminating many intermediate storage buffers. + +* *Reduce reliance on custom Ops.* Remove the need for many custom Ops by + improving the performance of automatically fused low-level Ops to match the + performance of custom Ops that were fused by hand. + +* *Reduce mobile footprint.* Eliminate the TensorFlow runtime by ahead-of-time + compiling the subgraph and emitting an object/header file pair that can be + linked directly into another application. 
The results can reduce the + footprint for mobile inference by several orders of magnitude. + +* *Improve portability.* Make it relatively easy to write a new backend for + novel hardware, at which point a large fraction of TensorFlow programs will + run unmodified on that hardware. This is in contrast with the approach of + specializing individual monolithic Ops for new hardware, which requires + TensorFlow programs to be rewritten to make use of those Ops. + +## How does XLA work? + +The input language to XLA is called "HLO IR", or just HLO (High Level +Optimizer). The semantics of HLO are described on the +[Operation Semantics](./operation_semantics.md) page. It +is most convenient to think of HLO as a +[compiler IR](https://en.wikipedia.org/wiki/Intermediate_representation). + +XLA takes graphs ("computations") defined in HLO and compiles them into machine +instructions for various architectures. XLA is modular in the sense that it is +easy to slot in an alternative backend to +[target some novel HW architecture](./developing_new_backend.md). +The CPU backend for x64 and ARM64 as well as the NVIDIA GPU backend are in the +TensorFlow source tree. + +The following diagram shows the compilation process in XLA: + +
+ +
+ +XLA comes with several optimizations and analysis passes that are +target-independent, such as +[CSE](https://en.wikipedia.org/wiki/Common_subexpression_elimination), +target-independent operation fusion, and buffer analysis for allocating runtime +memory for the computation. + +After the target-independent step, XLA sends the HLO computation to a backend. +The backend can perform further HLO-level optimizations, this time with +target-specific information and needs in mind. For example, the XLA GPU backend may +perform operation fusion beneficial specifically for the GPU programming model +and determine how to partition the computation into streams. At this stage, +backends may also pattern-match certain operations or combinations thereof to +optimized library calls. + +The next step is target-specific code generation. The CPU and GPU backends +included with XLA use [LLVM](http://llvm.org) for low-level IR, optimization, +and code-generation. These backends emit the LLVM IR necessary to represent the +XLA HLO computation in an efficient manner, and then invoke LLVM to emit native +code from this LLVM IR. + +The GPU backend currently supports NVIDIA GPUs via the LLVM NVPTX backend; the +CPU backend supports multiple CPU ISAs. + +## Supported Platforms + +XLA currently supports [JIT compilation](./jit.md) on x86-64 and NVIDIA GPUs, and +[AOT compilation](./tfcompile.md) for x86-64 and ARM. diff --git a/tensorflow/compiler/xla/g3doc/shapes.md b/tensorflow/compiler/xla/g3doc/shapes.md new file mode 100644 index 0000000000000000000000000000000000000000..39e74ff307cde49ef378a1201cb074dce4ababf0 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/shapes.md @@ -0,0 +1,150 @@ +# Shapes and Layout + +The XLA `Shape` proto +([xla_data.proto](https://www.tensorflow.org/code/tensorflow/compiler/xla/xla_data.proto)) +describes the rank, size, and data type of an N-dimensional array (*array* in +short).
+ +## Terminology, Notation, and Conventions + +* The rank of an array is equal to the number of dimensions. The *true rank* + of an array is the number of dimensions which have a size greater than 1. + +* Dimensions are numbered from `0` up to `N-1` for an `N` dimensional array. + The dimension numbers are arbitrary labels for convenience. The order of + these dimension numbers does not imply a particular minor/major ordering in + the layout of the shape. The layout is determined by the `Layout` proto. + +* By convention, dimensions are listed in increasing order of dimension + number. For example, for a 3-dimensional array of size `[A x B x C]`, + dimension 0 has size `A`, dimension 1 has size `B` and dimension 2 has size + `C`. + + Some utilities in XLA also support negative indexing, similarly to Python; + dimension -1 is the last dimension (equivalent to `N-1` for an `N` + dimensional array). For example, for the 3-dimensional array described + above, dimension -1 has size `C`, dimension -2 has size `B` and so on. + +* Two, three, and four dimensional arrays often have specific letters + associated with dimensions. For example, for a 2D array: + + * dimension 0: `y` + * dimension 1: `x` + + For a 3D array: + + * dimension 0: `z` + * dimension 1: `y` + * dimension 2: `x` + + For a 4D array: + + * dimension 0: `p` + * dimension 1: `z` + * dimension 2: `y` + * dimension 3: `x` + +* Functions in the XLA API which take dimensions do so in increasing order of + dimension number. This matches the ordering used when passing dimensions as + an `initializer_list`; e.g. + + `ShapeUtil::MakeShape(F32, {A, B, C, D})` + + Will create a shape whose dimension size array consists of the sequence + `[A, B, C, D]`. + +## Layout + +The `Layout` proto describes how an array is represented in memory. 
The `Layout` +proto includes the following fields: + +``` +message Layout { + repeated int64 minor_to_major = 1; + repeated int64 padded_dimensions = 2; + optional PaddingValue padding_value = 3; +} +``` + +### Minor-to-major dimension ordering + +The only required field is `minor_to_major`. This field describes the +minor-to-major ordering of the dimensions within a shape. Values in +`minor_to_major` are an ordering of the dimensions of the array (`0` to `N-1` +for an `N` dimensional array) with the first value being the most-minor +dimension up to the last value which is the most-major dimension. The most-minor +dimension is the dimension which changes most rapidly when stepping through the +elements of the array laid out in linear memory. + +For example, consider the following 2D array of size `[2 x 3]`: + +``` +a b c +d e f +``` + +Here dimension `0` is size 2, and dimension `1` is size 3. If the +`minor_to_major` field in the layout is `[0, 1]` then dimension `0` is the +most-minor dimension and dimension `1` is the most-major dimension. This +corresponds to the following layout in linear memory: + +``` +a d b e c f +``` + +This minor-to-major dimension order of `0` up to `N-1` is akin to *column-major* +(at rank 2). Assuming a monotonic ordering of dimensions, another name we may +use to refer to this layout in the code is simply "dim 0 is minor". + +On the other hand, if the `minor_to_major` field in the layout is `[1, 0]` then +the layout in linear memory is: + +``` +a b c d e f +``` + +A minor-to-major dimension order of `N-1` down to `0` for an `N` dimensional +array is akin to *row-major* (at rank 2). Assuming a monotonic ordering of +dimensions, another name we may use to refer to this layout in the code is +simply "dim 0 is major". + +#### Default minor-to-major ordering + +The default layout for newly created Shapes is "dimension order is +major-to-minor" (akin to row-major at rank 2). 
+ +### Padding + +Padding is defined in the optional `padded_dimensions` and `padding_value` +fields. The field `padded_dimensions` describes the sizes (widths) to which each +dimension is padded. If present, the number of elements in `padded_dimensions` +must equal the rank of the shape. + +For example, given the `[2 x 3]` array defined above, if `padded_dimensions` is +`[3, 5]` then dimension 0 is padded to a width of 3 and dimension 1 is padded to +a width of 5. The layout in linear memory (assuming a padding value of 0 and +column-major layout) is: + +``` +a d 0 b e 0 c f 0 0 0 0 0 0 0 +``` + +This is equivalent to the layout of the following array with the same +minor-to-major dimension order: + +``` +a b c 0 0 +d e f 0 0 +0 0 0 0 0 +``` + +### Indexing into arrays + +The class `IndexUtil` in +[index_util.h](https://www.tensorflow.org/code/tensorflow/compiler/xla/index_util.h) +provides utilities for converting between multidimensional indices and linear +indices given a shape and layout. Multidimensional indices include an `int64` +index for each dimension. Linear indices are a single `int64` value which +indexes into the buffer holding the array. See `shape_util.h` and +`layout_util.h` in the same directory for utilities that simplify creation and +manipulation of shapes and layouts. diff --git a/tensorflow/compiler/xla/g3doc/tfcompile.md b/tensorflow/compiler/xla/g3doc/tfcompile.md new file mode 100644 index 0000000000000000000000000000000000000000..5ee09fd302ba0edf84a7c99bb369586067141bef --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/tfcompile.md @@ -0,0 +1,281 @@ +# Using AOT compilation + +## What is tfcompile? + +`tfcompile` is a standalone tool that ahead-of-time (AOT) compiles TensorFlow +graphs into executable code. It can reduce total binary size, and also avoid +some runtime overheads. A typical use-case of `tfcompile` is to compile an +inference graph into executable code for mobile devices.
+ +The TensorFlow graph is normally executed by the TensorFlow runtime. This incurs +some runtime overhead for execution of each node in the graph. This also leads +to a larger total binary size, since the code for the TensorFlow runtime needs +to be available, in addition to the graph itself. The executable code produced +by `tfcompile` does not use the TensorFlow runtime, and only has dependencies on +kernels that are actually used in the computation. + +The compiler is built on top of the XLA framework. The code bridging TensorFlow +to the XLA framework resides under +[tensorflow/compiler](https://www.tensorflow.org/code/tensorflow/compiler/), +which also includes support for [just-in-time (JIT) compilation](jit.md) of +TensorFlow graphs. + +## What does tfcompile do? + +`tfcompile` takes a subgraph, identified by the TensorFlow concepts of +feeds and fetches, and generates a function that implements that subgraph. +The `feeds` are the input arguments for the function, and the `fetches` are the +output arguments for the function. All inputs must be fully specified by the +feeds; the resulting pruned subgraph cannot contain Placeholder or Variable +nodes. It is common to specify all Placeholders and Variables as feeds, which +ensures the resulting subgraph no longer contains these nodes. The generated +function is packaged as a `cc_library`, with a header file exporting the +function signature, and an object file containing the implementation. The user +writes code to invoke the generated function as appropriate. + +## Using tfcompile + +This section details high level steps for generating an executable binary with +`tfcompile` from a TensorFlow subgraph. 
The steps are: + +* Step 1: Configure the subgraph to compile +* Step 2: Use the `tf_library` build macro to compile the subgraph +* Step 3: Write code to invoke the subgraph +* Step 4: Create the final binary + +### Step 1: Configure the subgraph to compile + +Identify the feeds and fetches that correspond to the input and output +arguments for the generated function. Then configure the `feeds` and `fetches` +in a [`tensorflow.tf2xla.Config`](https://www.tensorflow.org/code/tensorflow/compiler/tf2xla/tf2xla.proto) +proto. + +```textproto +# Each feed is a positional input argument for the generated function. The order +# of each entry matches the order of each input argument. Here “x_hold” and “y_hold” +# refer to the names of placeholder nodes defined in the graph. +feed { + id { node_name: "x_hold" } + shape { + dim { size: 2 } + dim { size: 3 } + } +} +feed { + id { node_name: "y_hold" } + shape { + dim { size: 3 } + dim { size: 2 } + } +} + +# Each fetch is a positional output argument for the generated function. The order +# of each entry matches the order of each output argument. Here “x_y_prod” +# refers to the name of a matmul node defined in the graph. +fetch { + id { node_name: "x_y_prod" } +} +``` + +### Step 2: Use tf_library build macro to compile the subgraph + +This step converts the graph into a `cc_library` using the `tf_library` build +macro. The `cc_library` consists of an object file containing the code generated +from the graph, along with a header file that gives access to the generated +code. `tf_library` utilizes `tfcompile` to compile the TensorFlow graph into +executable code. + +```build +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") + +# Use the tf_library macro to compile your graph into executable code. 
+tf_library( + # name is used to generate the following underlying build rules: + # <name> : cc_library packaging the generated header and object files + # <name>_test : cc_test containing a simple test and benchmark + # <name>_benchmark : cc_binary containing a stand-alone benchmark with minimal deps; + # can be run on a mobile device + name = "test_graph_tfmatmul", + # cpp_class specifies the name of the generated C++ class, with namespaces allowed. + # The class will be generated in the given namespace(s), or if no namespaces are + # given, within the global namespace. + cpp_class = "foo::bar::MatMulComp", + # graph is the input GraphDef proto, by default expected in binary format. To + # use the text format instead, just use the ‘.pbtxt’ suffix. A subgraph will be + # created from this input graph, with feeds as inputs and fetches as outputs. + # No Placeholder or Variable ops may exist in this subgraph. + graph = "test_graph_tfmatmul.pb", + # config is the input Config proto, by default expected in binary format. To + # use the text format instead, use the ‘.pbtxt’ suffix. This is where the + # feeds and fetches were specified above, in the previous step. + config = "test_graph_tfmatmul.config.pbtxt", +) +``` + +> To generate the GraphDef proto (test_graph_tfmatmul.pb) for this example, run +> [make_test_graphs.py](https://www.tensorflow.org/code/tensorflow/compiler/aot/tests/make_test_graphs.py) +> and specify the output location with the --out_dir flag. + +Typical graphs contain [`Variables`](https://www.tensorflow.org/guide/variables) +representing the weights that are learned via training, but `tfcompile` cannot +compile a subgraph that contains `Variables`. The +[freeze_graph.py](https://www.tensorflow.org/code/tensorflow/python/tools/freeze_graph.py) +tool converts variables into constants, using values stored in a checkpoint +file. As a convenience, the `tf_library` macro supports the `freeze_checkpoint` +argument, which runs the tool. 
For more examples see +[tensorflow/compiler/aot/tests/BUILD](https://www.tensorflow.org/code/tensorflow/compiler/aot/tests/BUILD). + +> Constants that show up in the compiled subgraph are compiled directly into the +> generated code. To pass the constants into the generated function, rather than +> having them compiled-in, simply pass them in as feeds. + +For details on the `tf_library` build macro, see +[tfcompile.bzl](https://www.tensorflow.org/code/tensorflow/compiler/aot/tfcompile.bzl). + +For details on the underlying `tfcompile` tool, see +[tfcompile_main.cc](https://www.tensorflow.org/code/tensorflow/compiler/aot/tfcompile_main.cc). + +### Step 3: Write code to invoke the subgraph + +This step uses the header file (`test_graph_tfmatmul.h`) generated by the +`tf_library` build macro in the previous step to invoke the generated code. The +header file is located in the `bazel-genfiles` directory corresponding to the +build package, and is named based on the name attribute set on the `tf_library` +build macro. For example, the header generated for `test_graph_tfmatmul` would +be `test_graph_tfmatmul.h`. Below is an abbreviated version of what is +generated. The generated file, in `bazel-genfiles`, contains additional useful +comments. + +```c++ +namespace foo { +namespace bar { + +// MatMulComp represents a computation previously specified in a +// TensorFlow graph, now compiled into executable code. +class MatMulComp { + public: + // AllocMode controls the buffer allocation mode. + enum class AllocMode { + ARGS_RESULTS_AND_TEMPS, // Allocate arg, result and temp buffers + RESULTS_AND_TEMPS_ONLY, // Only allocate result and temp buffers + }; + + MatMulComp(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS); + ~MatMulComp(); + + // Runs the computation, with inputs read from arg buffers, and outputs + // written to result buffers. Returns true on success and false on failure. + bool Run(); + + // Arg methods for managing input buffers. 
Buffers are in row-major order. + // There is a set of methods for each positional argument. + void** args(); + + void set_arg0_data(float* data); + float* arg0_data(); + float& arg0(size_t dim0, size_t dim1); + + void set_arg1_data(float* data); + float* arg1_data(); + float& arg1(size_t dim0, size_t dim1); + + // Result methods for managing output buffers. Buffers are in row-major order. + // Must only be called after a successful Run call. There is a set of methods + // for each positional result. + void** results(); + + + float* result0_data(); + float& result0(size_t dim0, size_t dim1); +}; + +} // end namespace bar +} // end namespace foo +``` + +The generated C++ class is called `MatMulComp` in the `foo::bar` namespace, +because that was the `cpp_class` specified in the `tf_library` macro. All +generated classes have a similar API, with the only difference being the methods +to handle arg and result buffers. Those methods differ based on the number and +types of the buffers, which were specified by the `feed` and `fetch` arguments +to the `tf_library` macro. + +There are three types of buffers managed within the generated class: `args` +representing the inputs, `results` representing the outputs, and `temps` +representing temporary buffers used internally to perform the computation. By +default, each instance of the generated class allocates and manages all of these +buffers for you. The `AllocMode` constructor argument may be used to change this +behavior. All buffers are aligned to 64-byte boundaries. + +The generated C++ class is just a wrapper around the low-level code generated by +XLA. 
+ +Example of invoking the generated function based on +[`tfcompile_test.cc`](https://www.tensorflow.org/code/tensorflow/compiler/aot/tests/tfcompile_test.cc): + +```c++ +#define EIGEN_USE_THREADS +#define EIGEN_USE_CUSTOM_THREAD_POOL + +#include <iostream> +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" // generated + +int main(int argc, char** argv) { + Eigen::ThreadPool tp(2); // Size the thread pool as appropriate. + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + + foo::bar::MatMulComp matmul; + matmul.set_thread_pool(&device); + + // Set up args and run the computation. + const float args[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::copy(args + 0, args + 6, matmul.arg0_data()); + std::copy(args + 6, args + 12, matmul.arg1_data()); + matmul.Run(); + + // Check result + if (matmul.result0(0, 0) == 58) { + std::cout << "Success" << std::endl; + } else { + std::cout << "Failed. Expected value 58 at 0,0. Got:" + << matmul.result0(0, 0) << std::endl; + } + + return 0; +} +``` + +### Step 4: Create the final binary + +This step combines the library generated by `tf_library` in step 2 and the code +written in step 3 to create a final binary. Below is an example `bazel` BUILD +file. + +```build +# Example of linking your binary +# Also see //tensorflow/compiler/aot/tests/BUILD +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") + +# The same tf_library call from step 2 above. +tf_library( + name = "test_graph_tfmatmul", + ... +) + +# The executable code generated by tf_library can then be linked into your code. 
+cc_binary( + name = "my_binary", + srcs = [ + "my_code.cc", # include test_graph_tfmatmul.h to access the generated header + ], + deps = [ + ":test_graph_tfmatmul", # link in the generated object file + "//third_party/eigen3", + ], + linkopts = [ + "-lpthread", + ] +) +``` diff --git a/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..2a83092805be5efdd7b9ab54449b2bcc6a2ec481 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb @@ -0,0 +1,373 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "The XLA compile API", + "version": "0.3.2", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "metadata": { + "colab_type": "text", + "id": "f4TSNCvpENrW" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors." + ] + }, + { + "metadata": { + "cellView": "form", + "colab_type": "code", + "id": "vamNSA0vEP-m", + "colab": {} + }, + "cell_type": "code", + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." 
+ ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "e1oSi4lHFt3z" + }, + "cell_type": "markdown", + "source": [ + "# The XLA compile API" + ] + }, + { + "metadata": { + "colab_type": "text", + "id": "b7noD9NjFRL-" + }, + "cell_type": "markdown", + "source": [ + "\n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
" + ] + }, + { + "metadata": { + "colab_type": "text", + "id": "v9YbsuLZaBXy" + }, + "cell_type": "markdown", + "source": [ + "\n", + "\n", + "Import TensorFlow and the XLA library. XLA contains `xla.compile()`, an experimental API that compiles part or all of a model with [XLA](https://www.tensorflow.org/extend/xla/)." + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "45kUPj5ZFrRa", + "colab": {} + }, + "cell_type": "code", + "source": [ + "import tensorflow as tf\n", + "\n", + "from tensorflow.contrib.compiler import xla" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "GZVNiRmTDV-5" + }, + "cell_type": "markdown", + "source": [ + "Define some necessary constants and prepare the MNIST dataset." + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "f37TSEGvGX4_", + "colab": {} + }, + "cell_type": "code", + "source": [ + "# Size of each input image, 28 x 28 pixels\n", + "IMAGE_SIZE = 28 * 28\n", + "# Number of distinct number labels, [0..9]\n", + "NUM_CLASSES = 10\n", + "# Number of examples in each training batch (step)\n", + "TRAIN_BATCH_SIZE = 100\n", + "# Number of training steps to run\n", + "TRAIN_STEPS = 1000" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "code", + "id": "TiVXchblG5hK", + "colab": {} + }, + "cell_type": "code", + "source": [ + "# Loads MNIST dataset.\n", + "train, test = tf.keras.datasets.mnist.load_data()\n", + "train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()\n", + "test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)\n", + "\n", + "iterator = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)\n", + "images, labels = iterator.get_next()\n", + "images = tf.reshape(images, [-1, IMAGE_SIZE])\n", + "images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + 
"colab_type": "text", + "id": "x_ZehpZP-SfS" + }, + "cell_type": "markdown", + "source": [ + "# Define the model constructing function\n", + "\n", + "Following code block contains a function that constructs a simple model with one dense layer, including both forward and backward propagation.\n", + "\n", + "When called, it returns two values. `y` is a `tf.Tensor` representing predicted probability of each target class, `train_step` is a `tf.Operation` that increments `global_step` and applies variable update." + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "ZbhJl_WvGa3g", + "colab": {} + }, + "cell_type": "code", + "source": [ + "def build_mnist_model(x, y_):\n", + " y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)\n", + "\n", + " cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)\n", + " train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)\n", + "\n", + " return y, train_step" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "7Jh3lyQHDfM9" + }, + "cell_type": "markdown", + "source": [ + "# Enable XLA\n", + "\n", + "Use `xla.compile` with the `build_mnist_model` function to enable XLA. Following code block wraps the model with `xla.compile()`, which allows the target function with provided inputs to be executed by XLA." + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "kYpCXCdRHNuN", + "colab": {} + }, + "cell_type": "code", + "source": [ + "[y] = xla.compile(build_mnist_model, inputs=[images, labels])" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "4giQh62IrZGF" + }, + "cell_type": "markdown", + "source": [ + "When compiling the graph, XLA replaces all the graph nodes constructed in the target function with a few XLA ops.\n", + "\n", + "xla.compile does not return any\n", + "`tf.Operation` nodes that can be executed independently from the generated XLA ops. 
Instead, returned `tf.Operation` nodes from the target function are added as control dependencies of all returned `tf.Tensor` values. This triggers execution of the `tf.Operation` nodes when the returned tensors are evaluated.\n", + "\n", + "In pseudo-code, xla.compile's implementation looks as follows:\n", + "\n", + "---\n", + "```\n", + "# Ask Tensorflow to execute code in XLA-friendly manner\n", + "\n", + "y, train_step = build_mnist_model(images, labels)\n", + "with tf.control_dependencies([train_step]):\n", + " y = tf.identity(y)\n", + "\n", + "# Ask Tensorflow to STOP executing code in XLA-friendly manner\n", + "```\n", + "---\n", + "\n", + "xla.compile() always returns a list of `tf.Tensor`s (even if there is only one element)." + ] + }, + { + "metadata": { + "colab_type": "text", + "id": "TPGas4jjFLZl" + }, + "cell_type": "markdown", + "source": [ + "If you print the constructed graph now, you will see that it is not much different from a normal TensorFlow graph, and you won't be able to find the XLA ops mentioned before. This is because the actual compilation happens later when you try to execute the graph with `sess.run()`. At that time, TensorFlow triggers a series of graph rewrite passes that actually generate XLA ops, which compile and execute the computation when all inputs are ready."
+ ] + }, + { + "metadata": { + "colab_type": "text", + "id": "EZD1m_n1DxAF" + }, + "cell_type": "markdown", + "source": [ + "# Train and test the model" + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "qe28bAHNHUG2", + "colab": {} + }, + "cell_type": "code", + "source": [ + "# Creates session and initialize all variables.\n", + "# xla.compile() doesn't work with Keras model.fit() API or TF eager mode yet.\n", + "sess = tf.Session()\n", + "sess.run(tf.global_variables_initializer())" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "qgsKmz3n2UiW" + }, + "cell_type": "markdown", + "source": [ + "Following code block trains model. Evaluating `y` also triggers its control dependency node `train_step`, which updates model variables." + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "_GxF6jTRHVuA", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "fbf299ca-02d5-4e95-f9fe-8f3c0432d132" + }, + "cell_type": "code", + "source": [ + "# Feeds training dataset\n", + "sess.run(iterator.make_initializer(train_ds))\n", + "\n", + "# Runs TRAIN_STEPS steps\n", + "for i in range(TRAIN_STEPS):\n", + " sess.run(y)\n", + "\n", + "print(\"Model trained for %s steps.\" % TRAIN_STEPS)" + ], + "execution_count": 21, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Model trained for 1000 steps.\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "dHlQlRSRHXD1", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, + "outputId": "9c3677a2-ec84-406f-9d2c-d722844f3093" + }, + "cell_type": "code", + "source": [ + "# Tests trained model\n", + "\n", + "# Feeds testing dataset\n", + "sess.run(iterator.make_initializer(test_ds))\n", + "\n", + "# Calculates accuracy\n", + "correct_prediction = tf.equal(tf.argmax(y, 1), labels)\n", + "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", + 
"print(\"Prediction accuracy after training: %s\" % sess.run(accuracy))" + ], + "execution_count": 22, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Prediction accuracy after training: 0.91\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "ynJQIuzjHYOb", + "colab": {} + }, + "cell_type": "code", + "source": [ + "# Cleans up session\n", + "sess.close()" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 3fadabcf5207097aa875d654320b930b1ed94ad3..2a0241af3ef359c4d1c6c1ab9319b5b293110f7a 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -29,8 +29,6 @@ namespace xla { /* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex( const Shape& shape, absl::Span multi_index) { DCHECK_EQ(shape.dimensions_size(), multi_index.size()); - // Padding and nested layouts not supported yet. - DCHECK_EQ(0, shape.layout().padded_dimensions_size()); for (size_t i = 0; i < multi_index.size(); ++i) { DCHECK_GE(multi_index[i], 0); @@ -94,8 +92,6 @@ namespace xla { /* static */ std::vector IndexUtil::LinearIndexToMultidimensionalIndex( const Shape& shape, int64 linear_index) { - // Padding and nested layouts not supported yet. 
- DCHECK_EQ(0, shape.layout().padded_dimensions_size()); DCHECK_GE(linear_index, 0); DCHECK_LT(linear_index, ShapeUtil::ElementsIn(shape)); @@ -133,18 +129,12 @@ namespace xla { /* static */ int64 IndexUtil::GetDimensionStride(const Shape& shape, int64 dimension) { - int64 pdim_size = LayoutUtil::PaddedDimensions(shape).size(); int64 stride = 1; - DCHECK(pdim_size == 0 || pdim_size == shape.dimensions_size()); for (auto dim : LayoutUtil::MinorToMajor(shape)) { if (dim == dimension) { break; } - if (pdim_size == 0) { - stride *= shape.dimensions(dim); - } else { - stride *= LayoutUtil::PaddedDimension(shape, dim); - } + stride *= shape.dimensions()[dim]; } return stride; } diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index 2979cf87dde92893ce2151cb09b46c8db8473b31..d76f61eb62c0fc89d6bc3ca2033e8c7170f30e78 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" @@ -61,8 +62,7 @@ class IndexUtil { static bool BumpIndices(const Shape& shape, absl::Span indices); // Calculates the stride size (in number of elements, not byte size) of a - // given logical shape dimension (from 0 to rank-1). If available, padded - // dimensions are used. + // given logical shape dimension (from 0 to rank-1). 
// Example: // GetDimensionStride(F32[5,8,10,4]{3,2,1,0}, 1) == // sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10 diff --git a/tensorflow/compiler/xla/index_util_test.cc b/tensorflow/compiler/xla/index_util_test.cc index 93522d2ca87a7eba8d3c7533785c54e63ce507b0..fa94d0afb4c9280b8f8fa9642c1b0ab7285ee6f3 100644 --- a/tensorflow/compiler/xla/index_util_test.cc +++ b/tensorflow/compiler/xla/index_util_test.cc @@ -24,8 +24,7 @@ limitations under the License. namespace xla { namespace { -void SetMinorToMajorLayout(Shape* shape, - std::initializer_list dimensions) { +void SetMinorToMajorLayout(Shape* shape, std::vector dimensions) { shape->mutable_layout()->clear_minor_to_major(); for (auto dimension : dimensions) { shape->mutable_layout()->add_minor_to_major(dimension); @@ -122,7 +121,7 @@ TEST(IndexUtilTest, LinearToMultiToLinear) { std::vector linear_indexes = {0, 1439999999, 1145567336, 43883404, 617295214, 1117613654}; - std::vector> minor_to_major_orders; + std::vector> minor_to_major_orders; minor_to_major_orders.push_back({6, 5, 4, 3, 2, 1, 0}); minor_to_major_orders.push_back({0, 1, 2, 3, 4, 5, 6}); minor_to_major_orders.push_back({4, 5, 1, 2, 6, 0, 3}); diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index d310335618ded7b581e6ed632223218585bb791f..dbb81381acde645f08639737b6e7b6f6ad971f9b 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -65,6 +65,12 @@ void SetDefaultLayoutToContainer( return layout; } +/* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) { + std::vector layout(rank); + std::iota(layout.rbegin(), layout.rend(), static_cast(0)); + return MakeLayout(layout); +} + /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( absl::Span major_to_minor) { Layout layout; @@ -156,18 +162,23 @@ Layout CreateDefaultLayoutForRank(int64 rank) { LayoutUtil::SetToDefaultLayout(program_shape->mutable_result()); } -/* static */ Status 
LayoutUtil::ValidateLayoutInShape(const Shape& shape) { +/* static */ Status LayoutUtil::ValidateLayoutInShape( + const Shape& shape, bool allow_missing_layouts) { if (ShapeUtil::IsTuple(shape)) { // Tuple shape. if (shape.has_layout()) { return InvalidArgument("tuple should not have a layout field"); } for (auto& element_shape : shape.tuple_shapes()) { - TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); + TF_RETURN_IF_ERROR( + ValidateLayoutInShape(element_shape, allow_missing_layouts)); } return Status::OK(); } else if (ShapeUtil::IsArray(shape)) { if (!shape.has_layout()) { + if (allow_missing_layouts) { + return Status::OK(); + } return InvalidArgument("shape %s does not have a layout", ShapeUtil::HumanString(shape)); } @@ -190,8 +201,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } if (!ShapeUtil::IsArray(shape)) { - if (layout.minor_to_major_size() != 0 || - layout.padded_dimensions_size() != 0) { + if (layout.minor_to_major_size() != 0) { return InvalidArgument( "shape of primitive type %s should not have a non-trivial layout", PrimitiveType_Name(shape.element_type())); @@ -199,10 +209,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return Status::OK(); } - if (layout.format() == INVALID_FORMAT) { + if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) { return InvalidArgument( - "Layout does not have a valid format: layout {%s}, shape {%s}", - layout.ShortDebugString(), shape.ShortDebugString()); + "Layout has an invalid format (%d) in layout {%s}, shape {%s}", + layout.format(), layout.ShortDebugString(), shape.ShortDebugString()); } if (layout.format() == DENSE) { @@ -230,28 +240,6 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } dimensions_in_layout[dim] = true; } - - if (layout.padded_dimensions_size() > 0) { - if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) { - return InvalidArgument( - "layout has %d padded dimensions, but shape is rank %d", - layout.padded_dimensions_size(), 
ShapeUtil::Rank(shape)); - } - for (int i = 0; i < layout.padded_dimensions_size(); ++i) { - if (layout.padded_dimensions(i) < shape.dimensions(i)) { - return InvalidArgument( - "for dimension %d, dimension padding (%d) is smaller than " - "the dimension size (%d) of the shape", - i, layout.padded_dimensions(i), shape.dimensions(i)); - } - } - } - } - - if (layout.format() == SPARSE) { - if (!layout.padded_dimensions().empty()) { - return InvalidArgument("Sparse layout has padded dimensions"); - } } return Status::OK(); @@ -292,38 +280,6 @@ Layout CreateDefaultLayoutForRank(int64 rank) { layout.minor_to_major().end(), std::greater()); } -/* static */ bool LayoutUtil::IsPadded(const Shape& shape) { - if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) || - shape.layout().padded_dimensions_size() == 0) { - return false; - } - CHECK(IsDenseArray(shape)) << shape.ShortDebugString(); - CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size()); - for (int64 i = 0; i < shape.dimensions_size(); ++i) { - if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) { - return true; - } - } - return false; -} - -/* static */ absl::Span LayoutUtil::PaddedDimensions( - const Shape& shape) { - CHECK(IsDenseArray(shape)); - return AsInt64Slice(shape.layout().padded_dimensions()); -} - -/* static */ int64 LayoutUtil::PaddedDimension(const Shape& shape, - int64 index) { - CHECK(IsDenseArray(shape)); - return shape.layout().padded_dimensions(index); -} - -/* static */ PaddingValue LayoutUtil::GetPaddingValue(const Shape& shape) { - CHECK(IsDenseArray(shape)); - return shape.layout().padding_value(); -} - /* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) { return ShapeUtil::IsArray(shape) && shape.has_layout() && IsSparse(shape.layout()); @@ -502,14 +458,14 @@ std::ostream& operator<<(std::ostream& out, const Layout& layout) { for (int64 minor_to_major : layout.minor_to_major()) { hash_value = Hash64Combine(hash_value, hash()(minor_to_major)); } + 
hash_value = Hash64Combine(hash_value, layout.max_sparse_elements()); - for (int64 padded_dim : layout.padded_dimensions()) { - hash_value = Hash64Combine(hash_value, hash()(padded_dim)); + for (Tile tile : layout.tiles()) { + for (int64 tile_dim : tile.dimensions()) { + hash_value = Hash64Combine(hash_value, hash()(tile_dim)); + } } - - hash_value = - Hash64Combine(hash_value, hash()(layout.padding_value())); - hash_value = Hash64Combine(hash_value, layout.max_sparse_elements()); + hash_value = Hash64Combine(hash_value, layout.element_size_in_bits()); return hash_value; } diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index b78883c2d870043032306637730c4666665125a8..6c298e57252449ce3f1f9055436e918f2d9f17f1 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -40,6 +41,10 @@ class LayoutUtil { static Layout MakeLayoutFromMajorToMinor( absl::Span major_to_minor); + // Returns a layout with descending (i.e. {n, n-1, ..., 0}) minor-to-major + // dimensions. + static Layout MakeDescendingLayout(int64 rank); + // Creates a sparse layout with the given maximum number of elements. (This is // a convenience function for protobuf construction.) static Layout MakeSparseLayout(int64 max_sparse_elements); @@ -64,8 +69,11 @@ class LayoutUtil { // default. static void SetToDefaultLayout(ProgramShape* program_shape); - // Validates that the layout within the given shape is correct. - static Status ValidateLayoutInShape(const Shape& shape); + // Validates that the layout within the given shape is correct. The check + // is performed for all subshapes as well. 
If missing layouts are allowed + // the check does not fail on array shapes without layouts. + static Status ValidateLayoutInShape(const Shape& shape, + bool allow_missing_layouts = false); // Validates that the provided layout satisfies invariants for the given // shape. @@ -97,23 +105,6 @@ class LayoutUtil { // more minor, and so on until dimension N-1 which is the minor. static bool IsMonotonicWithDim0Major(const Layout& layout); - // Returns whether the layout of the given shape has padding (a - // padded_dimension value in Layout is greater than the corresponding - // dimension size). - static bool IsPadded(const Shape& shape); - - // Returns the padded_dimensions array for the given Shape. Requires that the - // shape is an array and has a dense layout. - static absl::Span PaddedDimensions(const Shape& shape); - - // Returns the given index of the padded_dimensions array for the given Shape. - // Requires that the shape is an array and has a dense layout. - static int64 PaddedDimension(const Shape& shape, int64 index); - - // Returns the padding_value for the given Shape. Requires that the shape is - // an array and has a dense layout. - static PaddingValue GetPaddingValue(const Shape& shape); - // Returns whether the given Shape is an array (i.e. not a tuple) and has a // sparse format layout. 
static bool IsSparseArray(const Shape& shape); diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index f25dae6ff411133c74502039f441060f1329ffd4..12ce2d2d7c6fa8c590035f9ff2af50001ccf80d8 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -304,30 +304,6 @@ TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) { shape.tuple_shapes(1).layout())); } -TEST_F(LayoutUtilTest, IsPadded) { - Shape shape_without_layout = ShapeUtil::MakeShape(F32, {2, 3, 4}); - LayoutUtil::ClearLayout(&shape_without_layout); - EXPECT_FALSE(LayoutUtil::IsPadded(shape_without_layout)); - - Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4}); - LayoutUtil::SetToDefaultLayout(&shape_with_layout); - EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_layout)); - - // Add padding equal to the dimension sizes. In this case the padding is a - // nop. - Shape shape_with_degenerate_padding = ShapeUtil::MakeShape(F32, {2, 3, 4}); - shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(2); - shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(3); - shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(4); - EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_degenerate_padding)); - - Shape shape_with_padding = ShapeUtil::MakeShape(F32, {2, 3, 4}); - shape_with_padding.mutable_layout()->add_padded_dimensions(2); - shape_with_padding.mutable_layout()->add_padded_dimensions(14); - shape_with_padding.mutable_layout()->add_padded_dimensions(42); - EXPECT_TRUE(LayoutUtil::IsPadded(shape_with_padding)); -} - TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}), LayoutUtil::GetDefaultLayoutForR2())); @@ -352,5 +328,92 @@ TEST_F(LayoutUtilTest, StreamOut) { EXPECT_EQ(oss.str(), "{0,1,2}"); } +TEST_F(LayoutUtilTest, ValidateLayout_ValidArrayLayout) { + Shape shape = 
ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1}); + auto status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); + EXPECT_TRUE(status.ok()); + status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); + EXPECT_TRUE(status.ok()); +} + +TEST_F(LayoutUtilTest, ValidateLayout_InvalidArrayLayout) { + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + *shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2}); + auto status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("layout minor_to_major field " + "contains 3 elements, but shape is rank 2")); + status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("layout minor_to_major field " + "contains 3 elements, but shape is rank 2")); +} + +TEST_F(LayoutUtilTest, ValidateLayout_MissingArrayLayout) { + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + LayoutUtil::ClearLayout(&shape); + auto status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("shape f32[2,3] does not have a layout")); + status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); + EXPECT_TRUE(status.ok()); +} + +TEST_F(LayoutUtilTest, ValidateLayout_TupleWithLayout) { + Shape shape = ShapeUtil::MakeTupleShape({}); + *shape.mutable_layout() = LayoutUtil::MakeLayout({0}); + auto status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("tuple should not have a layout field")); + status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); + EXPECT_FALSE(status.ok()); + 
EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("tuple should not have a layout field")); +} + +TEST_F(LayoutUtilTest, ValidateLayout_TupleSubshapesWithMissingLayouts) { + Shape sub_1_1_1 = ShapeUtil::MakeShape(F32, {1, 2}); + Shape sub_1_1 = ShapeUtil::MakeTupleShape({sub_1_1_1}); + Shape sub_1_2 = ShapeUtil::MakeShape(F32, {1, 2}); + LayoutUtil::ClearLayout(&sub_1_2); + Shape sub_1 = ShapeUtil::MakeTupleShape({sub_1_1, sub_1_2}); + Shape sub_2_1 = ShapeUtil::MakeShape(F32, {9}); + LayoutUtil::ClearLayout(&sub_2_1); + Shape sub_2 = ShapeUtil::MakeTupleShape({sub_2_1}); + Shape shape = ShapeUtil::MakeTupleShape({sub_1, sub_2}); + + auto status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/false); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("shape f32[1,2] does not have a layout")); + status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); + EXPECT_TRUE(status.ok()); + + // Add invalid layout on one of sub-shapes. + *shape.mutable_tuple_shapes(1)->mutable_tuple_shapes(0)->mutable_layout() = + LayoutUtil::MakeLayout({0, 2, 3}); + + status = + LayoutUtil::ValidateLayoutInShape(shape, /*allow_missing_layouts=*/true); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("layout minor_to_major field " + "contains 3 elements, but shape is rank 1")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD deleted file mode 100644 index 3e79129aafd234e5eab05d205f2017b54057795e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ /dev/null @@ -1,82 +0,0 @@ -# Legacy command-line flags for the XLA libraries. - -# Please do not add more flags to this package. - -# The XLA libraries were written in an environment that allowed command-line -# flags to be scattered freely throughout the libraries. 
This model, while -# initially convenient, leads to a proliferation in unused command-line flags -# in tests and binaries, and serious problems in servers, where one might wish -# parameters to be different in independent RPC calls to the same routine. -# -# Please don't add more flags. If you're a library author, pass options and -# parameters explicitly through the library's interface. - -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "parse_flags_from_env", - srcs = ["parse_flags_from_env.cc"], - hdrs = ["parse_flags_from_env.h"], - deps = - [ - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "parse_flags_from_env_test", - srcs = ["parse_flags_from_env_test.cc"], - deps = - [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "@com_google_absl//absl/strings:str_format", - ], -) - -cc_library( - name = "debug_options_flags", - srcs = [ - "debug_options_flags.cc", - "debug_options_parsers.h", - ], - hdrs = ["debug_options_flags.h"], - deps = - [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "@com_google_absl//absl/strings", - ], -) - -tf_cc_test( - name = "debug_options_parsers_test", - size = "small", - srcs = [ - "debug_options_parsers.h", - "debug_options_parsers_test.cc", - ], - deps = - [ - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - ], -) diff --git 
a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc deleted file mode 100644 index 2a4e49b05aa0d1eed2197095694cfc6aa8814983..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc +++ /dev/null @@ -1,206 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This module exports ParseFlagsFromEnv(), which allows other modules to parse -// flags from an environtment variable, or a file named by the environment -// variable. - -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -static const char kEnvVar[] = "TF_XLA_FLAGS"; // environment variable queried -static const char kWS[] = " \t\r\n"; // whitespace - -// The following struct represents an argv[]-style array, parsed -// from data gleaned from the environment. 
-// -// As usual, an anonymous namespace is advisable to avoid -// constructor/destructor collisions with other "private" types -// in the same named namespace. -namespace { -struct EnvArgv { - EnvArgv() : initialized(false), argc(0) {} - bool initialized; // whether the other fields have been set. - int argc; // elements used in argv[] - std::vector argv; // flag arguments parsed from environment string. - std::vector argv_save; // saved values from argv[] to avoid leaks -}; -} // anonymous namespace - -// Append the string s0[0, .., s0len-1] concatenated with s1[0, .., s1len-1] as -// a newly allocated nul-terminated string to the array *a. If s0==nullptr, a -// nullptr is appended without increasing a->argc. -static void AppendToEnvArgv(const char* s0, size_t s0len, const char* s1, - size_t s1len, EnvArgv* a) { - if (s0 == nullptr) { - a->argv.push_back(nullptr); - a->argv_save.push_back(nullptr); - } else { - string s = string(s0, s0len) + string(s1, s1len); - char* str = strdup(s.c_str()); - a->argv.push_back(str); - a->argv_save.push_back(str); - a->argc++; - } -} - -// Like s.find_first_of(x, pos), but return s.size() when find_first_of() would -// return string::npos. This avoids if-statements elsewhere. -static size_t FindFirstOf(const string& s, const char* x, size_t pos) { - size_t result = s.find_first_of(x, pos); - return result == string::npos ? s.size() : result; -} - -// Like s.find_first_not_of(x, pos), but return s.size() when -// find_first_not_of() would return string::npos. This avoids if-statements -// elsewhere. -static size_t FindFirstNotOf(const string& s, const char* x, size_t pos) { - size_t result = s.find_first_not_of(x, pos); - return result == string::npos ? s.size() : result; -} - -// Given a string containing flags, parse them into the XLA command line flags. -// The parse is best effort, and gives up on the first syntax error. 
-static void ParseArgvFromString(const string& flag_str, EnvArgv* a) { - size_t b = FindFirstNotOf(flag_str, kWS, 0); - while (b != flag_str.size() && flag_str[b] == '-') { - // b is the index of the start of a flag. - // Set e to the index just past the end of the flag. - size_t e = b; - while (e != flag_str.size() && isascii(flag_str[e]) && - (strchr("-_", flag_str[e]) != nullptr || isalnum(flag_str[e]))) { - e++; - } - if (e != flag_str.size() && flag_str[e] == '=' && - e + 1 != flag_str.size() && strchr("'\"", flag_str[e + 1]) != nullptr) { - // A flag of the form --flag="something in double or single quotes" - int c; - e++; // point just past '=' - size_t eflag = e; - char quote = flag_str[e]; - e++; // point just past quote - // Put in value the string with quotes removed. - string value; - for (; e != flag_str.size() && (c = flag_str[e]) != quote; e++) { - if (quote == '"' && c == '\\' && e + 1 != flag_str.size()) { - // Handle backslash in double quoted strings. They are literal in - // single-quoted strings. - e++; - c = flag_str[e]; - } - value += c; - } - if (e != flag_str.size()) { // skip final " or ' - e++; - } - AppendToEnvArgv(flag_str.data() + b, eflag - b, value.data(), - value.size(), a); - } else { // A flag without a quoted value. - e = FindFirstOf(flag_str, kWS, e); - AppendToEnvArgv(flag_str.data() + b, e - b, "", 0, a); - } - b = FindFirstNotOf(flag_str, kWS, e); - } -} - -// Call ParseArgvFromString(..., a) on a string derived from the setting of an -// environment variable kEnvVar, or a file it points to. 
-static void SetArgvFromEnv(EnvArgv* a) { - if (!a->initialized) { - static const char kDummyArgv[] = ""; - AppendToEnvArgv(kDummyArgv, strlen(kDummyArgv), nullptr, 0, - a); // dummy argv[0] - const char* env = getenv(kEnvVar); - if (env == nullptr || env[0] == '\0') { - // nothing - } else if (env[strspn(env, kWS)] == '-') { // flags in env var value - ParseArgvFromString(env, a); - } else { // assume it's a file name - FILE* fp = fopen(env, "r"); - if (fp != nullptr) { - string str; - char buf[512]; - int n; - while ((n = fread(buf, 1, sizeof(buf), fp)) > 0) { - str.append(buf, n); - } - fclose(fp); - ParseArgvFromString(str, a); - } - } - AppendToEnvArgv(nullptr, 0, nullptr, 0, a); // add trailing nullptr to *a. - a->initialized = true; - } -} - -// The simulated argv[] parsed from the environment. -static EnvArgv* env_argv; - -// Used to protect accesses to env_argv. -static tensorflow::mutex env_argv_mu(tensorflow::LINKER_INITIALIZED); - -// Call Flags::Parse(argc, argv, flag_list) against any as yet unrecognized -// flags passed in from the environment. -bool ParseFlagsFromEnv(const std::vector& flag_list) { - env_argv_mu.lock(); - if (env_argv == nullptr) { - env_argv = new EnvArgv; - } - SetArgvFromEnv(env_argv); // a no-op if already initialized - bool result = - tensorflow::Flags::Parse(&env_argv->argc, &env_argv->argv[0], flag_list); - env_argv_mu.unlock(); - return result; -} - -// Testing only. -// Reset the env_argv struct so that subsequent calls to ParseFlagsFromEnv() -// will parse the environment variable (or the file it points to) anew, and set -// *pargc, and *pargv to point to the internal locations of the argc and argv -// constructed from the environment. 
-void ResetFlagsFromEnvForTesting(int** pargc, std::vector** pargv) { - env_argv_mu.lock(); - if (env_argv == nullptr) { - env_argv = new EnvArgv; - } - if (!env_argv->argv_save.empty()) { - for (int i = 0; env_argv->argv_save[i] != nullptr; i++) { - free(env_argv->argv_save[i]); - } - } - env_argv->initialized = false; - env_argv->argc = 0; - env_argv->argv.clear(); - env_argv->argv_save.clear(); - env_argv_mu.unlock(); - *pargc = &env_argv->argc; - *pargv = &env_argv->argv; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h deleted file mode 100644 index b54482ad2ba2224c781861341a80ceb878ffd343..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_ - -// This module exports ParseFlagsFromEnv(), which allows other modules to parse -// flags from the environtment variable TF_XLA_FLAGS, or (if the first -// non-whitespace in the variable value is not '-'), a file named by that -// environment variable. 
The accepted syntax is that flags arguments are of -// the form --flag=value or (for boolean flags) --flag, and are whitespace -// separated. The may be one of: -// - -// in which case the effective value is the string itself -// - in which case the effective value is the -// string with the single-quotes removed -// - in which case the effective value if the -// string with the double-quotes removed, and escaped sequences of -// replaced by . -// -// Flags values inconsistent with the type of the flag will be rejected by the -// flag parser. -// -// Examples: -// TF_XLA_FLAGS="--foo=bar --wombat='value with a space'" -// -// TF_XLA_FLAGS=/tmp/flagfile -// where /tmp/flagfile might contain -// --some_flag="This is a string containing a \" and a '." -// --another_flag=wombats - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Call tensorflow::Flags::Parse(argc, argv, flag_list) against any as yet -// unrecognized flags passed in from the environment, and return its -// return value. -bool ParseFlagsFromEnv(const std::vector& flag_list); - -// Used only for testing. Not to be used by clients. -void ResetFlagsFromEnvForTesting(int** pargc, std::vector** pargv); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_ diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 656ce720a13d5c9622e9dc05ae04ddcac8cbeee5..8f480c1f1079b4e1a5be53958ebdf6e004ad9ebe 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -22,16 +22,17 @@ limitations under the License. 
#include #include +#include "absl/base/casts.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -62,6 +63,14 @@ void ConvertEndianShort(char* bytes, int64 size) { } } +// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be +// able to transparently access the raw 16-bit value contained within. +template +T GetRawValue(T val) { + return val; +} +uint16 GetRawValue(Eigen::half val) { return val.x; } + } // namespace LiteralBase::~LiteralBase() {} @@ -283,13 +292,17 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, if (!proto.has_shape()) { return InvalidArgument("LiteralProto has no shape"); } - if (!LayoutUtil::HasLayout(proto.shape())) { + Shape shape(proto.shape()); + if (ShapeUtil::HasPrimitiveType(shape, OPAQUE)) { + return InvalidArgument("Literal shape cannot include OPAQUE sub-shape"); + } + if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("LiteralProto has no layout"); } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); - Literal literal(proto.shape()); + Literal literal(shape); TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { @@ -1009,167 +1022,143 @@ void LiteralBase::Piece::SortSparseElementsInternal() { namespace { -void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - 
bool print_layout, std::vector* pieces) { - const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - CHECK(LayoutUtil::HasLayout(literal.shape())); - CHECK(LayoutUtil::HasLayout(subshape)); +string ShapeToString(bool print_layout, const Shape& shape) { + return print_layout ? ShapeUtil::HumanStringWithLayout(shape) + : ShapeUtil::HumanString(shape); +} - auto shape_to_string = [print_layout](const Shape& shape) { - if (print_layout) { - return ShapeUtil::HumanStringWithLayout(shape); - } else { - return ShapeUtil::HumanString(shape); - } - }; +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, + bool print_layout, std::vector* pieces); - // TODO(b/32894291): refactor this code to reduce code duplication. - if (ShapeUtil::IsTuple(subshape)) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" (\n"); - std::vector tuple_pieces; - for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { - ShapeIndex element_index = shape_index; - element_index.push_back(i); - std::vector element_pieces; - ToStringHelper(literal, element_index, print_layout, &element_pieces); - tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); +void TupleToStringHelper(const LiteralBase& literal, + const ShapeIndex& shape_index, bool print_layout, + std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + pieces->push_back(ShapeToString(print_layout, subshape)); + pieces->push_back(" (\n"); + std::vector tuple_pieces; + for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { + ShapeIndex element_index = shape_index; + element_index.push_back(i); + std::vector element_pieces; + ToStringHelper(literal, element_index, print_layout, &element_pieces); + tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); + } + pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); + pieces->push_back("\n)"); +} + +void SparseArrayToStringHelper(const LiteralBase& literal, 
+ const Shape& subshape, bool print_layout, + std::vector* pieces) { + pieces->push_back(ShapeToString(print_layout, subshape)); + pieces->push_back("{"); + int64 rank = ShapeUtil::Rank(subshape); + int64 num_elements = literal.sparse_element_count(); + for (int64 i = 0; i < num_elements; ++i) { + if (i > 0) { + pieces->push_back(", "); } - pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); - pieces->push_back("\n)"); - return; - } - - if (ShapeUtil::IsToken(subshape)) { - pieces->push_back("token"); - return; - } - - if (LayoutUtil::IsSparseArray(subshape)) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back("{"); - int64 rank = ShapeUtil::Rank(subshape); - int64 num_elements = literal.sparse_element_count(); - for (int64 i = 0; i < num_elements; ++i) { - if (i > 0) { - pieces->push_back(", "); - } - if (rank == 1) { - pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); - pieces->push_back(": "); - } else { - pieces->push_back("["); - pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", ")); - pieces->push_back("]: "); - } - pieces->push_back(literal.GetSparseElementAsString(i)); + if (rank == 1) { + pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); + pieces->push_back(": "); + } else { + pieces->push_back("["); + pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", ")); + pieces->push_back("]: "); } - pieces->push_back("}"); - return; + pieces->push_back(literal.GetSparseElementAsString(i)); } + pieces->push_back("}"); +} - CHECK(LayoutUtil::IsDenseArray(subshape)); - - auto element_to_string = [&](absl::Span indices) -> string { - PrimitiveType element_type = subshape.element_type(); - if (element_type == PRED) { - // We display predicates in a densely packed form. - return literal.Get(indices, shape_index) ? "1" : "0"; - } - return ((!indices.empty() && indices.back() > 0) ? 
", " : "") + - literal.GetAsString(indices, shape_index); - }; +void DenseArrayToStringHelper(const LiteralBase& literal, + const ShapeIndex& shape_index, bool print_layout, + std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + int64 rank = ShapeUtil::Rank(subshape); + + std::function dimensions, std::vector*)> + to_string_recursive = [&](absl::Span dimensions, + std::vector* accum_indices) { + // dimensions.size() decreases by 1 at each recursive call, + // and accum_indices->size() increases by 1. + // Their sum is equal to the rank of the tensor. + CHECK_EQ(rank, dimensions.size() + accum_indices->size()); + + auto brace_to_string = [&](string brace) -> string { + // Handle 1D tensor + if (rank == 1) { + return brace; + } + // Handle the innermost tensor of a 2D+ tensor. + if (dimensions.size() == 1 && brace == "{") { + return StrCat(" ", brace, dimensions[0] <= 1 ? "" : " "); + } + if (dimensions.size() == 1 && brace == "}") { + return StrCat(dimensions[0] <= 1 ? "" : " ", brace); + } + // Handle the non-innermost tensors of a 2D+ tensor. 
+ if (brace == "{") { + if (rank > 3 && !accum_indices->empty() && + accum_indices->size() < rank) { + int index = accum_indices->size() - 1; + int value = accum_indices->back(); + return StrCat(brace, " /*i", index, "=", value, "*/\n"); + } + return StrCat(brace, "\n"); + } + return StrCat("\n", brace); + }; - if (ShapeUtil::Rank(subshape) == 0) { - pieces->push_back(literal.GetAsString({}, shape_index)); - } else if (ShapeUtil::Rank(subshape) == 1) { - pieces->push_back("{"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(element_to_string({i0})); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 2) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(" { "); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(element_to_string({i0, i1})); - } - pieces->push_back(" "); - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n"); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 3) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(i0 > 0 ? ",\n{" : "{"); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(i1 > 0 ? 
",\n { " : " { "); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(element_to_string({i0, i1, i2})); - } - pieces->push_back(" }"); - } - pieces->push_back(" }"); - } - pieces->push_back("\n}"); - } else if (ShapeUtil::Rank(subshape) == 4) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(" {"); - for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { - pieces->push_back(element_to_string({i0, i1, i2, i3})); + if (dimensions.empty()) { + // Display predicates as 0s and 1s so that the string is more dense. + string elem; + if (subshape.element_type() == PRED && rank > 0) { + elem = literal.Get(*accum_indices, shape_index) ? "1" : "0"; + } else { + elem = literal.GetAsString(*accum_indices, shape_index); } - pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n"); - } - pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? 
" }\n" : " },\n"); - } - pieces->push_back("}"); - } else if (ShapeUtil::Rank(subshape) == 5) { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {\n"); - for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); - for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); - for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(StrFormat(" { /*i2=%d*/\n", i2)); - for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { - pieces->push_back(" {"); - for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { - pieces->push_back(element_to_string({i0, i1, i2, i3, i4})); + pieces->push_back(elem); + } else { + pieces->push_back(brace_to_string("{")); + for (int i = 0; i < dimensions[0]; ++i) { + std::vector cloned_indices(*accum_indices); + cloned_indices.push_back(i); + to_string_recursive(dimensions.subspan(1), &cloned_indices); + if (i < dimensions[0] - 1) { + pieces->push_back(","); + pieces->push_back(dimensions.size() > 1 ? "\n" : " "); } - pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n" - : "},\n"); } - pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n" - : " },\n"); + pieces->push_back(brace_to_string("}")); } - pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" - : " },\n"); - } - pieces->push_back(i0 == subshape.dimensions(0) - 1 ? 
" }\n" : " },\n"); - } - pieces->push_back("}"); + }; + + if (rank > 1) { + pieces->push_back(ShapeToString(print_layout, subshape)); + pieces->push_back(" "); + } + std::vector indices = {}; + std::vector dimensions(subshape.dimensions().begin(), + subshape.dimensions().end()); + to_string_recursive(dimensions, &indices); +} + +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, + bool print_layout, std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + CHECK(LayoutUtil::HasLayout(literal.shape())); + CHECK(LayoutUtil::HasLayout(subshape)); + if (ShapeUtil::IsTuple(subshape)) { + TupleToStringHelper(literal, shape_index, print_layout, pieces); + } else if (ShapeUtil::IsToken(subshape)) { + pieces->push_back("token"); + } else if (LayoutUtil::IsSparseArray(subshape)) { + SparseArrayToStringHelper(literal, subshape, print_layout, pieces); } else { - pieces->push_back(shape_to_string(subshape)); - pieces->push_back(" {"); - literal.EachCellAsString( - [&](absl::Span indices, const string& value) { - pieces->push_back(" "); - pieces->push_back(value); - }); - pieces->push_back("}"); + CHECK(LayoutUtil::IsDenseArray(subshape)); + DenseArrayToStringHelper(literal, shape_index, print_layout, pieces); } } @@ -1226,16 +1215,32 @@ Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) { } template -typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), +typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) && + !std::is_same::value), Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { - return tensorflow::bit_cast(src); + return absl::bit_cast(GetRawValue(src)); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); } +template +typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) && + std::is_same::value), + Literal>::type +BitcastBetweenNativeTypes(const 
LiteralBase& src_literal) { + // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly + // cast to unsigned short and then use raw_uint16_to_half. + auto converter = [](NativeSrcT src) { + return Eigen::half_impl::raw_uint16_to_half( + absl::bit_cast(GetRawValue(src))); + }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + // This template specialization is here to make the compiler happy. bit_cast has // a static check that the types are the same size. This specialization should // never be used because the source and destination types are checked for @@ -1432,10 +1437,14 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { return EqualElementsInternal(other, &multi_index); case U8: return EqualElementsInternal(other, &multi_index); + case S16: + return EqualElementsInternal(other, &multi_index); case S32: return EqualElementsInternal(other, &multi_index); case S64: return EqualElementsInternal(other, &multi_index); + case U16: + return EqualElementsInternal(other, &multi_index); case U32: return EqualElementsInternal(other, &multi_index); case U64: @@ -1504,6 +1513,11 @@ bool LiteralBase::IsAll(int8 value) const { return AllElementsEqualValue(piece.data(), value); } return false; + case U16: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; case U32: if (value >= 0) { return AllElementsEqualValue(piece.data(), value); @@ -1516,6 +1530,8 @@ bool LiteralBase::IsAll(int8 value) const { return false; case S8: return AllElementsEqualValue(piece.data(), value); + case S16: + return AllElementsEqualValue(piece.data(), value); case S32: return AllElementsEqualValue(piece.data(), value); case S64: @@ -1737,12 +1753,16 @@ bool LiteralBase::IsZero(absl::Span indices) const { switch (shape().element_type()) { case U8: return Get(indices) == 0; + case U16: + return Get(indices) == 0; case U32: return Get(indices) == 0; case U64: return Get(indices) == 0; 
case S8: return Get(indices) == 0; + case S16: + return Get(indices) == 0; case S32: return Get(indices) == 0; case S64: @@ -1775,7 +1795,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest, } // namespace void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { - *proto->mutable_shape() = subshape(); + *proto->mutable_shape() = subshape().ToProto(); switch (subshape().element_type()) { case PRED: CopyToRepeatedField(proto->mutable_preds(), data()); @@ -1800,6 +1820,20 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case S64: CopyToRepeatedField(proto->mutable_s64s(), data()); break; + case U16: + *proto->mutable_u16s() = string( + reinterpret_cast(data().data()), size_bytes()); + if (!kLittleEndian) { + ConvertEndianShort(proto->mutable_u16s()); + } + break; + case S16: + *proto->mutable_s16s() = string( + reinterpret_cast(data().data()), size_bytes()); + if (!kLittleEndian) { + ConvertEndianShort(proto->mutable_s16s()); + } + break; case F16: *proto->mutable_f16s() = string( reinterpret_cast(data().data()), size_bytes()); @@ -1867,8 +1901,9 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { // These conditions should have been checked in // MutableLiteralBase::CreateFromProto. 
TF_RET_CHECK(proto.has_shape()); - TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); - TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); + Shape shape(proto.shape()); + TF_RET_CHECK(LayoutUtil::HasLayout(shape)); + TF_RET_CHECK(ShapeUtil::Equal(shape, subshape())); if (LayoutUtil::IsSparseArray(subshape())) { // Compute the number of elements (indices) in the sparse shape and reserve @@ -1914,6 +1949,22 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { case U64: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); break; + case S16: { + const string& s(proto.s16s()); + TF_RET_CHECK(data().size() * sizeof(int16_t) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + } + } break; + case U16: { + const string& s(proto.u16s()); + TF_RET_CHECK(data().size() * sizeof(uint16_t) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); + } + } break; case F16: { const string& s(proto.f16s()); TF_RET_CHECK(data().size() * sizeof(half) == s.size()); @@ -1992,7 +2043,7 @@ string LiteralBase::GetR1U8AsString() const { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(shape().element_type(), U8); - return string(tensorflow::bit_cast(data().data()), + return string(absl::bit_cast(data().data()), ShapeUtil::ElementsIn(shape())); } diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 3cd3541fe1596600b4f0b43e3011e1f0322ac8fe..fa9a71af4ceb998a7a289443cbef70eb52cb1a11 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -301,7 +301,7 @@ class LiteralBase { // // Note: It's an antipattern to use this method then immediately call // MutableLiteralBase::Populate on the result (since that results in zero - // initialization, then reinitialization. 
Conside if a call to + // initialization, then reinitialization. Consider if a call to // absl::make_unique(shape), followed by the call to // MutableLiteralBase::Populate can be used instead. static Literal CreateFromShape(const Shape& shape); @@ -979,9 +979,8 @@ inline void MutableLiteralBase::PopulateR1(absl::Span values) { CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); CHECK_EQ(shape().element_type(), primitive_util::NativeToPrimitiveType()); - for (int64 i = 0; i < values.size(); ++i) { - Set({i}, values[i]); - } + auto data_span = data(); + std::copy(values.begin(), values.end(), data_span.begin()); } template diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 3d8725ed7051cafc97987f25a96004fa876dfdd3..b044f0ad73f13a0599e77f1f43888bc974e31f73 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -19,11 +19,11 @@ limitations under the License. #include #include +#include "absl/base/casts.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/env.h" using absl::StrAppend; @@ -34,14 +34,22 @@ namespace xla { namespace literal_comparison { namespace { +// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be +// able to transparently access the raw 16-bit value contained within. +template +T GetRawValue(T val) { + return val; +} +uint16 GetRawValue(Eigen::half val) { return val.x; } + // Helper function for comparing a floating point type, FloatT, bitwise equal // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT // -- on miscompare, a nice error message is given in the AssertionFailure. 
template Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, absl::Span multi_index) { - auto ulhs = tensorflow::bit_cast(lhs); - auto urhs = tensorflow::bit_cast(rhs); + auto ulhs = absl::bit_cast(GetRawValue(lhs)); + auto urhs = absl::bit_cast(GetRawValue(rhs)); auto lhs_double = static_cast(lhs); auto rhs_double = static_cast(rhs); if (ulhs != urhs) { @@ -133,8 +141,10 @@ int64 RecursiveElementCount(const Shape& shape) { total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); } return total; - } else { + } else if (ShapeUtil::IsArray(shape)) { return ShapeUtil::ElementsIn(shape); + } else { + return 0; } } diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index dd5b54e4c99998f676419cf98a3da16593338829..49363ad802ddb9520f89b53257216bc7ddaf8ff5 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/base/casts.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -28,7 +29,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -133,7 +133,7 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - EXPECT_EQ("{101}", pred_vec.ToString()); + EXPECT_EQ("{1, 0, 1}", pred_vec.ToString()); } TEST_F(LiteralUtilTest, R2ToString) { @@ -150,12 +150,58 @@ TEST_F(LiteralUtilTest, R3ToString) { const auto literal = LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); const string expected = R"(s32[3,2,1] { -{ { 1 }, - { 2 } }, -{ { 3 }, - { 4 } }, -{ { 5 }, - { 6 } } +{ + {1}, + {2} +}, +{ + {3}, + {4} +}, +{ + {5}, + {6} +} +})"; + EXPECT_EQ(expected, literal.ToString()); +} + +TEST_F(LiteralUtilTest, R6ToString) { + const auto literal = + LiteralUtil::CreateFromDimensions(S32, {2, 2, 1, 1, 1, 2}); + const string expected = R"(s32[2,2,1,1,1,2] { +{ /*i0=0*/ +{ /*i1=0*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +}, +{ /*i1=1*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +} +}, +{ /*i0=1*/ +{ /*i1=0*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +}, +{ /*i1=1*/ +{ /*i2=0*/ +{ /*i3=0*/ + { 0, 0 } +} +} +} +} })"; EXPECT_EQ(expected, literal.ToString()); } @@ -190,12 +236,16 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2)); string result = literal.ToString(); const string expected = R"(f32[2,3,2] { -{ { 1, 2 }, +{ + { 1, 2 }, { 3, 4 }, - { 5, 6 } }, -{ { 7, 8 }, + { 5, 6 } +}, +{ + { 7, 8 }, { 9, 10 }, - { 11, 12 } } + { 11, 12 } +} })"; EXPECT_EQ(expected, result); } @@ -247,18 +297,18 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2)); 
string result = literal.ToString(); const string expected = R"(f32[1,2,3,2] { - { /*i0=0*/ - { /*i1=0*/ - {1, 2}, - {1001, 1002}, - {2001, 2002} - }, - { /*i1=1*/ - {1, 2}, - {1001, 1002}, - {2001, 2002} - } - } +{ /*i0=0*/ +{ /*i1=0*/ + { 1, 2 }, + { 1001, 1002 }, + { 2001, 2002 } +}, +{ /*i1=1*/ + { 1, 2 }, + { 1001, 1002 }, + { 2001, 2002 } +} +} })"; EXPECT_EQ(expected, result); } @@ -268,30 +318,30 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { ElementsAre(2, 2, 3, 3)); string result = literal_r4_2x2x3x3_dim0major_.ToString(); const string expected = R"(f32[2,2,3,3] { - { /*i0=0*/ - { /*i1=0*/ - {1, 2, 3}, - {4, 5, 6}, - {7, 8, 9} - }, - { /*i1=1*/ - {11, 12, 13}, - {14, 15, 16}, - {17, 18, 19} - } - }, - { /*i0=1*/ - { /*i1=0*/ - {101, 102, 103}, - {104, 105, 106}, - {107, 108, 109} - }, - { /*i1=1*/ - {201, 202, 203}, - {204, 205, 206}, - {207, 208, 209} - } - } +{ /*i0=0*/ +{ /*i1=0*/ + { 1, 2, 3 }, + { 4, 5, 6 }, + { 7, 8, 9 } +}, +{ /*i1=1*/ + { 11, 12, 13 }, + { 14, 15, 16 }, + { 17, 18, 19 } +} +}, +{ /*i0=1*/ +{ /*i1=0*/ + { 101, 102, 103 }, + { 104, 105, 106 }, + { 107, 108, 109 } +}, +{ /*i1=1*/ + { 201, 202, 203 }, + { 204, 205, 206 }, + { 207, 208, 209 } +} +} })"; EXPECT_EQ(expected, result); } @@ -1312,11 +1362,10 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { TEST_F(LiteralUtilTest, BitcastConvert) { auto original = LiteralUtil::CreateR1( - {tensorflow::bit_cast(2.5f), - tensorflow::bit_cast(-42.25f), - tensorflow::bit_cast(100.f), 0xbeef}); + {absl::bit_cast(2.5f), absl::bit_cast(-42.25f), + absl::bit_cast(100.f), 0xbeef}); auto expected = LiteralUtil::CreateR1( - {2.5f, -42.25f, 100.0f, tensorflow::bit_cast(0xbeef)}); + {2.5f, -42.25f, 100.0f, absl::bit_cast(0xbeef)}); TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32)); } @@ -1328,13 +1377,26 @@ TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { absl::StrContains(status.error_message(), "bit widths are different")); } +// Sets the layout of the given 
ShapeProto to the default. +void SetDefaultLayoutOnProto(ShapeProto* shape_proto) { + CHECK(ShapeUtil::IsArrayPrimitiveType(shape_proto->element_type())); + shape_proto->mutable_layout()->set_format(DENSE); + auto* minor_to_major = + shape_proto->mutable_layout()->mutable_minor_to_major(); + minor_to_major->Resize(shape_proto->dimensions_size(), 0); + const int64 size = minor_to_major->size(); + for (int64 i = 0; i < size; ++i) { + minor_to_major->Set(i, size - 1 - i); + } +} + TEST_F(LiteralUtilTest, CopyFromProto_Bool) { LiteralProto p; p.mutable_shape()->set_element_type(PRED); for (int len = 0; len < 25; ++len) { p.mutable_shape()->clear_dimensions(); p.mutable_shape()->add_dimensions(len); - LayoutUtil::SetToDefaultLayout(p.mutable_shape()); + SetDefaultLayoutOnProto(p.mutable_shape()); p.clear_preds(); for (int i = 0; i < len; ++i) { p.add_preds((i % 2) == (len % 2)); @@ -1360,7 +1422,7 @@ TEST_F(LiteralUtilTest, ToProto_f16) { EXPECT_EQ(4, m.data().size()); LiteralProto p = m.ToProto(); - EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape())); + EXPECT_EQ(4, ShapeUtil::ElementsIn(Shape(p.shape()))); EXPECT_EQ(8, p.f16s().size()); const char* d = p.f16s().data(); EXPECT_EQ(d[0], 0); @@ -1383,7 +1445,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { p.mutable_shape()->set_element_type(F16); p.mutable_shape()->clear_dimensions(); p.mutable_shape()->add_dimensions(4); - LayoutUtil::SetToDefaultLayout(p.mutable_shape()); + SetDefaultLayoutOnProto(p.mutable_shape()); p.clear_f16s(); p.set_f16s(half_vals, 8); TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); @@ -1395,6 +1457,28 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { EXPECT_EQ(h1, r[3]); } +TEST_F(LiteralUtilTest, CopyFromProto_u16) { + uint16 u1(0xabcd); + uint16 u2(0x1234); + + const unsigned char uint16_vals[8] = {0xcd, 0xab, 0x34, 0x12, + 0x34, 0x12, 0xcd, 0xab}; + LiteralProto p; + p.mutable_shape()->set_element_type(U16); + p.mutable_shape()->clear_dimensions(); + 
p.mutable_shape()->add_dimensions(4); + SetDefaultLayoutOnProto(p.mutable_shape()); + p.clear_u16s(); + p.set_u16s(uint16_vals, 8); + TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); + auto r = literal.data(); + ASSERT_EQ(4, r.size()); + EXPECT_EQ(u1, r[0]); + EXPECT_EQ(u2, r[1]); + EXPECT_EQ(u2, r[2]); + EXPECT_EQ(u1, r[3]); +} + TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); @@ -1516,9 +1600,9 @@ TEST_F(LiteralUtilTest, DecomposeTuple) { Literal nested_tuple = LiteralUtil::MakeTuple( {&tuple_elements[0], &tuple_elements[1], &nil_literal}); - EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape())); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple(nested_tuple.shape())); std::vector elements = nested_tuple.DecomposeTuple(); - EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(nested_tuple.shape())); ASSERT_EQ(elements.size(), 3); @@ -1569,7 +1653,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { EXPECT_EQ(literal.Get({1}, /*shape_index=*/{2, 1}), 44.0); for (const Literal& element : elements) { - EXPECT_TRUE(ShapeUtil::IsNil(element.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(element.shape())); } } @@ -1685,7 +1769,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { TEST_F(LiteralUtilTest, InvalidProtoNoValues) { // Proto contains a shape, but no values. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto(); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.error_message(), @@ -1706,7 +1790,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoShape) { TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { // Proto contains values in wrong container. 
LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto(); proto.add_preds(false); proto.add_preds(true); proto.add_preds(false); @@ -1719,7 +1803,7 @@ TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) { // Proto contains too few values. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}); + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}).ToProto(); proto.add_f32s(1.0); proto.add_f32s(2.0); proto.add_f32s(3.0); @@ -1732,7 +1816,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) { TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) { // Proto contains too many values. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}); + *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}).ToProto(); proto.add_s32s(42); proto.add_s32s(-10); proto.add_s32s(100); @@ -1745,8 +1829,8 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) { TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) { // Proto shape missing layout. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}); - LayoutUtil::ClearLayout(proto.mutable_shape()); + *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}).ToProto(); + proto.mutable_shape()->clear_layout(); proto.add_preds(true); proto.add_preds(false); proto.add_preds(true); @@ -1759,11 +1843,13 @@ TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) { TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { // Proto has the too few tuple elements. 
LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}); + *proto.mutable_shape() = + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}) + .ToProto(); LiteralProto* element0 = proto.add_tuple_literals(); *element0->mutable_shape() = - ShapeUtil::GetTupleElementShape(proto.shape(), 0); + ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto(); element0->add_preds(false); element0->add_preds(true); @@ -1775,19 +1861,21 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { // Proto has the too many tuple elements. LiteralProto proto; - *proto.mutable_shape() = ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}); + *proto.mutable_shape() = + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}) + .ToProto(); LiteralProto* element0 = proto.add_tuple_literals(); *element0->mutable_shape() = - ShapeUtil::GetTupleElementShape(proto.shape(), 0); + ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto(); element0->add_preds(false); element0->add_preds(true); LiteralProto* element1 = proto.add_tuple_literals(); *element1->mutable_shape() = - ShapeUtil::GetTupleElementShape(proto.shape(), 1); + ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 1).ToProto(); element1->add_f32s(42.0); LiteralProto* element2 = proto.add_tuple_literals(); - *element2->mutable_shape() = ShapeUtil::MakeShape(F32, {}); + *element2->mutable_shape() = ShapeUtil::MakeShape(F32, {}).ToProto(); element2->add_f32s(123.0); Status status = Literal::CreateFromProto(proto).status(); diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 0cb1ae35f4ad31f091063d78ed32c1463be8ee0a..bb5e5e61000d0aca6ab052ac87d2fbcd96e55f70 100644 --- 
a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/parse_flags_from_env.cc b/tensorflow/compiler/xla/parse_flags_from_env.cc new file mode 100644 index 0000000000000000000000000000000000000000..5b568888d14f21c1330556d017eafba6c8dd2228 --- /dev/null +++ b/tensorflow/compiler/xla/parse_flags_from_env.cc @@ -0,0 +1,234 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This module exports ParseFlagsFromEnvAndDieIfUnknown(), which allows other +// modules to parse flags from an environtment variable, or a file named by the +// environment variable. 
+ +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { + +static const char kWS[] = " \t\r\n"; // whitespace + +// The following struct represents an argv[]-style array, parsed +// from data gleaned from the environment. +// +// As usual, an anonymous namespace is advisable to avoid +// constructor/destructor collisions with other "private" types +// in the same named namespace. +namespace { + +// Functor which deletes objects by calling `free`. Necessary to free strdup'ed +// strings created by AppendToEnvArgv. +struct FreeDeleter { + void operator()(char* ptr) { free(ptr); } +}; + +struct EnvArgv { + EnvArgv() : initialized(false), argc(0) {} + bool initialized; // whether the other fields have been set. + int argc; // elements used in argv[] + std::vector argv; // flag arguments parsed from environment string. + // saved values from argv[] to avoid leaks + std::vector> argv_save; +}; +} // anonymous namespace + +// Append the string s0[0, .., s0len-1] concatenated with s1[0, .., s1len-1] as +// a newly allocated nul-terminated string to the array *a. If s0==nullptr, a +// nullptr is appended without increasing a->argc. 
+static void AppendToEnvArgv(const char* s0, size_t s0len, const char* s1, + size_t s1len, EnvArgv* a) { + if (s0 == nullptr) { + a->argv.push_back(nullptr); + a->argv_save.push_back(nullptr); + } else { + string s = string(s0, s0len) + string(s1, s1len); + char* str = strdup(s.c_str()); + a->argv.push_back(str); + a->argv_save.emplace_back(str); + a->argc++; + } +} + +// Like s.find_first_of(x, pos), but return s.size() when find_first_of() would +// return string::npos. This avoids if-statements elsewhere. +static size_t FindFirstOf(const string& s, const char* x, size_t pos) { + size_t result = s.find_first_of(x, pos); + return result == string::npos ? s.size() : result; +} + +// Like s.find_first_not_of(x, pos), but return s.size() when +// find_first_not_of() would return string::npos. This avoids if-statements +// elsewhere. +static size_t FindFirstNotOf(const string& s, const char* x, size_t pos) { + size_t result = s.find_first_not_of(x, pos); + return result == string::npos ? s.size() : result; +} + +// Given a string containing flags, parse them into the XLA command line flags. +// The parse is best effort, and gives up on the first syntax error. +static void ParseArgvFromString(const string& flag_str, EnvArgv* a) { + size_t b = FindFirstNotOf(flag_str, kWS, 0); + while (b != flag_str.size() && flag_str[b] == '-') { + // b is the index of the start of a flag. + // Set e to the index just past the end of the flag. + size_t e = b; + while (e != flag_str.size() && isascii(flag_str[e]) && + (strchr("-_", flag_str[e]) != nullptr || isalnum(flag_str[e]))) { + e++; + } + if (e != flag_str.size() && flag_str[e] == '=' && + e + 1 != flag_str.size() && strchr("'\"", flag_str[e + 1]) != nullptr) { + // A flag of the form --flag="something in double or single quotes" + int c; + e++; // point just past '=' + size_t eflag = e; + char quote = flag_str[e]; + e++; // point just past quote + // Put in value the string with quotes removed. 
+ string value; + for (; e != flag_str.size() && (c = flag_str[e]) != quote; e++) { + if (quote == '"' && c == '\\' && e + 1 != flag_str.size()) { + // Handle backslash in double quoted strings. They are literal in + // single-quoted strings. + e++; + c = flag_str[e]; + } + value += c; + } + if (e != flag_str.size()) { // skip final " or ' + e++; + } + AppendToEnvArgv(flag_str.data() + b, eflag - b, value.data(), + value.size(), a); + } else { // A flag without a quoted value. + e = FindFirstOf(flag_str, kWS, e); + AppendToEnvArgv(flag_str.data() + b, e - b, "", 0, a); + } + b = FindFirstNotOf(flag_str, kWS, e); + } +} + +// Call ParseArgvFromString(..., a) on a string derived from the setting of the +// environment variable `envvar`, or a file it points to. +static void SetArgvFromEnv(absl::string_view envvar, EnvArgv* a) { + if (!a->initialized) { + static const char kDummyArgv[] = ""; + AppendToEnvArgv(kDummyArgv, strlen(kDummyArgv), nullptr, 0, + a); // dummy argv[0] + const char* env = getenv(string(envvar).c_str()); + if (env == nullptr || env[0] == '\0') { + // nothing + } else if (env[strspn(env, kWS)] == '-') { // flags in env var value + ParseArgvFromString(env, a); + } else { // assume it's a file name + FILE* fp = fopen(env, "r"); + if (fp != nullptr) { + string str; + char buf[512]; + int n; + while ((n = fread(buf, 1, sizeof(buf), fp)) > 0) { + str.append(buf, n); + } + fclose(fp); + ParseArgvFromString(str, a); + } + } + AppendToEnvArgv(nullptr, 0, nullptr, 0, a); // add trailing nullptr to *a. + a->initialized = true; + } +} + +// The simulated argv[] parsed from the environment, one for each different +// environment variable we've seen. +static std::unordered_map& EnvArgvs() { + static auto* env_argvs = new std::unordered_map(); + return *env_argvs; +} + +// Used to protect accesses to env_argvs. 
+static tensorflow::mutex env_argv_mu(tensorflow::LINKER_INITIALIZED); + +bool ParseFlagsFromEnvAndDieIfUnknown( + absl::string_view envvar, const std::vector& flag_list) { + tensorflow::mutex_lock lock(env_argv_mu); + auto* env_argv = &EnvArgvs()[string(envvar)]; + SetArgvFromEnv(envvar, env_argv); // a no-op if already initialized + bool result = + tensorflow::Flags::Parse(&env_argv->argc, &env_argv->argv[0], flag_list); + + // There's always at least one unparsed argc, namely the fake argv[0]. + if (result && env_argv->argc != 1) { + // Skip the first argv, which is the fake argv[0]. + auto unknown_flags = absl::MakeSpan(env_argv->argv); + unknown_flags.remove_prefix(1); + + // Some flags are set on XLA_FLAGS, others on TF_XLA_FLAGS. If we find an + // unrecognized flag, suggest the alternative. + string alternate_envvar; + if (envvar == "TF_XLA_FLAGS") { + alternate_envvar = "XLA_FLAGS"; + } else if (envvar == "XLA_FLAGS") { + alternate_envvar = "TF_XLA_FLAGS"; + } + string did_you_mean; + if (!alternate_envvar.empty()) { + did_you_mean = absl::StrFormat( + "\nPerhaps you meant to specify these on the %s envvar?", + alternate_envvar); + } + + LOG(FATAL) << "Unknown flag" << (unknown_flags.size() > 1 ? "s" : "") + << " in " << envvar << ": " << absl::StrJoin(unknown_flags, " ") + << did_you_mean; + return false; + } + return result; +} + +// Testing only. +// +// Resets the env_argv struct so that subsequent calls to +// ParseFlagsFromEnvAndDieIfUnknown() will parse the environment variable (or +// the file it points to) anew, and set *pargc, and *pargv to point to the +// internal locations of the argc and argv constructed from the environment. 
+void ResetFlagsFromEnvForTesting(absl::string_view envvar, int** pargc, + std::vector** pargv) { + tensorflow::mutex_lock lock(env_argv_mu); + EnvArgvs().erase(string(envvar)); + auto& env_argv = EnvArgvs()[string(envvar)]; + *pargc = &env_argv.argc; + *pargv = &env_argv.argv; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/parse_flags_from_env.h b/tensorflow/compiler/xla/parse_flags_from_env.h new file mode 100644 index 0000000000000000000000000000000000000000..76940a4299ac50138222333ff250a264cc941288 --- /dev/null +++ b/tensorflow/compiler/xla/parse_flags_from_env.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ +#define TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ + +// This module exports ParseFlagsFromEnvAndDieIfUnknown(), which allows other +// modules to parse flags from an environtment variable, or (if the first +// non-whitespace in the variable value is not '-'), a file named by that +// environment variable. +// +// The accepted syntax is that flags arguments are of the form --flag=value or +// (for boolean flags) --flag, and are whitespace separated. 
The may be +// one of: +// +// - +// in which case the effective value is the string itself +// - in which case the effective value is the +// string with the single-quotes removed +// - in which case the effective value if the +// string with the double-quotes removed, and escaped sequences of +// replaced by . +// +// Flags values inconsistent with the type of the flag will be rejected by the +// flag parser. +// +// Examples: +// +// - TF_XLA_FLAGS="--foo=bar --wombat='value with a space'" +// - TF_XLA_FLAGS=/tmp/flagfile +// +// where /tmp/flagfile might contain +// +// --some_flag="This is a string containing a \" and a '." +// --another_flag=wombats + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { + +// Calls tensorflow::Flags::Parse(argc, argv, flag_list) against any as yet +// unrecognized flags passed in the environment variable `envvar`, and returns +// its return value. +// +// Raises a fatal error if any flags in `envvar` were not recognized. +bool ParseFlagsFromEnvAndDieIfUnknown( + absl::string_view envvar, const std::vector& flag_list); + +// Used only for testing. Not to be used by clients. 
+void ResetFlagsFromEnvForTesting(absl::string_view envvar, int** pargc, + std::vector** pargv); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PARSE_FLAGS_FROM_ENV_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/parse_flags_from_env_test.cc similarity index 89% rename from tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc rename to tensorflow/compiler/xla/parse_flags_from_env_test.cc index 138c0c852e2bb0527d171f25b4d96cedc5671516..3465552ebbf52140fb954b247d99d3c6afe7fcde 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/parse_flags_from_env_test.cc @@ -15,7 +15,7 @@ limitations under the License. // Test for parse_flags_from_env.cc -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/parse_flags_from_env.h" #include #include @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/core/util/command_line_flags.h" namespace xla { -namespace legacy_flags { // Test that XLA flags can be set from the environment. // Failure messages are accompanied by the text in msg[]. @@ -38,20 +37,7 @@ static void TestParseFlagsFromEnv(const char* msg) { // Initialize module under test. int* pargc; std::vector* pargv; - ResetFlagsFromEnvForTesting(&pargc, &pargv); - - // Ensure that environment variable can be parsed when - // no flags are expected. - std::vector empty_flag_list; - bool parsed_ok = ParseFlagsFromEnv(empty_flag_list); - CHECK(parsed_ok) << msg; - const std::vector& argv_first = *pargv; - CHECK_NE(argv_first[0], nullptr) << msg; - int i = 0; - while (argv_first[i] != nullptr) { - i++; - } - CHECK_EQ(i, *pargc) << msg; + ResetFlagsFromEnvForTesting("TF_XLA_FLAGS", &pargc, &pargv); // Check that actual flags can be parsed. 
bool simple = false; @@ -66,7 +52,7 @@ static void TestParseFlagsFromEnv(const char* msg) { tensorflow::Flag("single_quoted", &single_quoted, ""), tensorflow::Flag("double_quoted", &double_quoted, ""), }; - parsed_ok = ParseFlagsFromEnv(flag_list); + bool parsed_ok = ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list); CHECK_EQ(*pargc, 1) << msg; const std::vector& argv_second = *pargv; CHECK_NE(argv_second[0], nullptr) << msg; @@ -159,12 +145,11 @@ TEST(ParseFlagsFromEnv, EnvAndFlag) { } } -} // namespace legacy_flags } // namespace xla int main(int argc, char* argv[]) { // Save name of binary so that it may invoke itself. - xla::legacy_flags::binary_name = argv[0]; + xla::binary_name = argv[0]; bool recursing = false; xla::int32 int_flag = 1; const std::vector flag_list = { @@ -173,7 +158,8 @@ int main(int argc, char* argv[]) { tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - bool parse_ok = xla::legacy_flags::ParseFlagsFromEnv(flag_list); + bool parse_ok = + xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", flag_list); if (!parse_ok) { LOG(QFATAL) << "can't parse from environment\n" << usage; } diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index b507a2ef79f1d7e9ae632744675dddf574490805..ac342bf40fbc0052acbb09a346b9d062561ed06b 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -40,16 +40,6 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1, namespace { -string SanitizeFilename(const string& file_name) { - string safe_file_name = file_name; - for (char& c : safe_file_name) { - if (c == '/' || c == '\\') { - c = '_'; - } - } - return safe_file_name; -} - std::pair>*> GetDirectoryExpanders() { static auto* mutex = new tensorflow::mutex; diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 
f0d84646b9f01ad3ad209073f13b7b3ec21635d1..63ac1c6649210cbae9e238a74e0a45fb8ee4da63 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") py_library( name = "xla_client", @@ -50,7 +51,14 @@ cc_library( srcs = ["local_computation_builder.cc"], hdrs = ["local_computation_builder.h"], deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", @@ -59,8 +67,11 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/core:framework_lite", + "//tensorflow/compiler/xrt:xrt_proto", + "//tensorflow/compiler/xrt/cc:xrt_ops", + "//tensorflow/core:framework", "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", @@ -72,6 +83,7 @@ tf_py_wrap_cc( srcs = ["xla.i"], swig_includes = [ "local_computation_builder.i", + "//tensorflow/python:platform/base.i", ], deps = [ ":local_computation_builder", @@ -80,5 +92,7 @@ tf_py_wrap_cc( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:cpu_plugin", - ], + ] + if_cuda_is_configured([ + "//tensorflow/compiler/xla/service:gpu_plugin", + ]), ) diff --git 
a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index cd5fd330298fb0ff158e232dac121f8ffb271218..6e2ee866321a070d55a7221c7c68024ceaa93448 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -14,16 +14,42 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/local_computation_builder.h" + +#include +#include +#include + #include "absl/memory/memory.h" +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h" +#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h" +#include "tensorflow/compiler/xrt/xrt.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" namespace xla { namespace swig { +// TODO(b/118641336): Factor out XRT parts into a small c++ library of their +// own. 
+ // TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of // device handles instead of needing to set the number of replicas at XLA // service initialization time. @@ -31,6 +57,12 @@ tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED); int g_replica_count GUARDED_BY(g_local_client_mutex) = 1; LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr; +string* GetPlatformNameString() { + static string* platform_name_string PT_GUARDED_BY(g_local_client_mutex) = + new string("Host"); + return platform_name_string; +} + Status InitializeReplicaCount(int replica_count) { if (replica_count < 1) { return InvalidArgument("Replica count must be >= 1; got %d.", @@ -47,17 +79,33 @@ Status InitializeReplicaCount(int replica_count) { return Status::OK(); } +Status InitializePlatformName(const string& platform_name) { + string* g_platform_name = GetPlatformNameString(); + tensorflow::mutex_lock lock(g_local_client_mutex); + if (g_local_client != nullptr) { + return FailedPrecondition( + "Attempted to set the platform name to %s, but a local XLA service was " + "previously created with a platform name of %s.", + platform_name, *g_platform_name); + } + TF_RETURN_IF_ERROR(PlatformUtil::GetPlatform(platform_name).status()); + *g_platform_name = platform_name; + return Status::OK(); +} + int GetReplicaCount() { tensorflow::mutex_lock lock(g_local_client_mutex); return g_replica_count; } LocalClient* GetOrCreateLocalClient() { + string* platform_name = GetPlatformNameString(); tensorflow::mutex_lock lock(g_local_client_mutex); if (g_local_client != nullptr) { return g_local_client; } LocalClientOptions options; + options.set_platform(PlatformUtil::GetPlatform(*platform_name).ValueOrDie()); options.set_number_of_replicas(g_replica_count); g_local_client = ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie(); CHECK(g_local_client != nullptr); @@ -91,6 +139,33 @@ StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, 
return client->TransferFromOutfeedLocal(shape, device_ordinal); } +static StatusOr ToBuffer(LocalClient* client, + int device_ordinal, + const Literal& arg) { + return client->LiteralToShapedBuffer(arg, device_ordinal, + client->backend().memory_allocator()); +} + +/* static */ +StatusOr LocalShapedBuffer::FromLiteral( + const Literal& argument, const absl::optional& shape_with_layout, + int replica_number) { + LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(int device_ordinal, + client->ReplicaNumberToDeviceOrdinal(replica_number)); + VLOG(1) << "Creating shaped buffer from literal on replica/ordinal: " + << replica_number << "/" << device_ordinal; + StatusOr buf = [&] { + if (shape_with_layout) { + Literal relaid = argument.Relayout(shape_with_layout.value()); + return ToBuffer(client, device_ordinal, relaid); + } + return ToBuffer(client, device_ordinal, argument); + }(); + TF_RETURN_IF_ERROR(buf.status()); + return new LocalShapedBuffer(std::move(buf).ValueOrDie()); +} + LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer) : shaped_buffer_(std::move(shaped_buffer)) {} @@ -100,11 +175,20 @@ const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const { ShapedBuffer LocalShapedBuffer::Release() { return shaped_buffer_.release(); } +const Shape& LocalShapedBuffer::shape() const { + return shaped_buffer()->on_device_shape(); +} + +StatusOr LocalShapedBuffer::ToLiteral() const { + LocalClient* client = GetOrCreateLocalClient(); + return client->ShapedBufferToLiteral(*shaped_buffer()); +} + LocalShapedBufferTuple::LocalShapedBufferTuple( std::vector elements) : elements_(std::move(elements)) { for (auto* element : elements_) { - DCHECK(element != nullptr); + CHECK(element != nullptr); } } @@ -126,157 +210,316 @@ StatusOr LocalShapedBufferTuple::Release(int i) { return element; } -int LocalShapedBufferTuple::size() const { return elements_.size(); } +int64 LocalShapedBufferTuple::size() const { return elements_.size(); } + 
+XrtAllocation::XrtAllocation(int64 handle, Shape shape, + const string& session_target) + : handle_(handle), shape_(shape), session_target_(session_target) {} + +XrtAllocation::~XrtAllocation() { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto allocation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto release = + tensorflow::ops::XRTReleaseAllocationHandle(root, allocation_handle); + if (!root.status().ok()) { + LOG(ERROR) << root.status(); + return; + } -static StatusOr ToBuffer(LocalClient* client, - int device_ordinal, - const Literal& arg) { - return client->LiteralToShapedBuffer(arg, device_ordinal, - client->backend().memory_allocator()); + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({allocation_handle, handle()}); + std::vector outputs; + auto status = session.Run(inputs, {}, {release}, &outputs); + if (!status.ok()) { + LOG(ERROR) << status; + return; + } } /* static */ -StatusOr LocalShapedBuffer::FromLiteral( - const Literal& argument, const absl::optional& shape_with_layout) { - LocalClient* client = GetOrCreateLocalClient(); - StatusOr buf = [&] { - if (shape_with_layout) { - Literal relaid = argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, /*device_ordinal=*/0, relaid); +StatusOr XrtAllocation::FromLiteral( + const Literal& argument, const string& session_target) { + xrt::XLAAllocation alloc; + alloc.set_device_ordinal(0); + *alloc.mutable_value() = argument.ToProto(); + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto literal_string = + tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto literal_handle = tensorflow::ops::XRTAllocate(root, literal_string); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({literal_string, alloc.SerializeAsString()}); + std::vector 
outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {literal_handle}, &outputs)); + + int64 handle = outputs[0].scalar()(); + return new XrtAllocation(handle, argument.shape(), session_target); +} + +const int64 XrtAllocation::handle() const { return handle_; } + +const Shape& XrtAllocation::shape() const { return shape_; } + +StatusOr XrtAllocation::ToLiteral() const { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto allocation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto read_literal = tensorflow::ops::XRTReadLiteral(root, allocation_handle); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({allocation_handle, handle()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {read_literal}, &outputs)); + + xla::LiteralProto response; + TF_RET_CHECK(response.ParseFromString(outputs[0].scalar()())); + return Literal::CreateFromProto(response); +} + +XrtAllocationTuple::XrtAllocationTuple(std::vector elements) + : elements_(std::move(elements)) { + for (auto* element : elements_) { + CHECK(element != nullptr); + } +} + +XrtAllocationTuple::~XrtAllocationTuple() { + for (XrtAllocation* element : elements_) { + if (element != nullptr) { + delete element; } - return ToBuffer(client, /*device_ordinal=*/0, argument); - }(); - TF_RETURN_IF_ERROR(buf.status()); - return new LocalShapedBuffer(std::move(buf).ValueOrDie()); + } } -StatusOr LocalShapedBuffer::ToLiteral() const { - LocalClient* client = GetOrCreateLocalClient(); - return client->ShapedBufferToLiteral(*shaped_buffer()); +StatusOr XrtAllocationTuple::Release(int i) { + XrtAllocation* element = elements_[i]; + if (element == nullptr) { + return InvalidArgument("Attempted to release already-released element %d.", + i); + } + elements_[i] = nullptr; + return element; } +int64 XrtAllocationTuple::size() const { return elements_.size(); } + 
CompiledLocalComputation::CompiledLocalComputation( std::unique_ptr executable) : executable_(std::move(executable)) {} -StatusOr CompiledLocalComputation::Execute( - const std::vector& arguments, - const std::vector>& shapes_with_layout) { +StatusOr CompiledLocalComputation::Execute( + absl::Span argument_handles) { LocalClient* client = GetOrCreateLocalClient(); + StatusOr device_ordinal_status = client->ReplicaNumberToDeviceOrdinal(0); + StatusOr result_buffer_status; + if (!device_ordinal_status.ok()) { + result_buffer_status = device_ordinal_status.status(); + } else { + const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(3) << "Replica 0 mapped to device ordinal for execution: " + << device_ordinal; + + std::vector argument_buffers; + argument_buffers.reserve(argument_handles.size()); + for (auto& handle : argument_handles) { + argument_buffers.push_back(handle->shaped_buffer()); + } + + DeviceAssignment device_assignment = + client->backend() + .computation_placer() + ->AssignDevices(1, /*computation_count=*/1) + .ConsumeValueOrDie(); - VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas."; + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); - // Each replica populates a StatusOr result, but only replica zero actually - // retrieves its literal value. 
- std::vector> results(GetReplicaCount()); - { + result_buffer_status = executable_->Run(argument_buffers, options); + } + + if (!result_buffer_status.ok()) { + return InternalError( + "Failed running replica 0 (other replicas may have failed as well): " + "%s.", + result_buffer_status.status().ToString()); + } + return new LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie()); +} + +StatusOr CompiledLocalComputation::ExecutePerReplica( + absl::Span> argument_handles) { + LocalClient* client = GetOrCreateLocalClient(); + const int num_replicas = GetReplicaCount(); + + if (argument_handles.size() != num_replicas) { + return InvalidArgument( + "Attempted to execute with %d replicas when replica count is %d", + argument_handles.size(), num_replicas); + } + + VLOG(1) << "Executing with " << num_replicas << " replicas."; + + // Each replica populates a StatusOr result, but only the output value of + // replica zero is returned. + std::vector> results(num_replicas); + auto execute = [this, client, num_replicas, &argument_handles, + &results](int replica) { + StatusOr device_ordinal_status = + client->ReplicaNumberToDeviceOrdinal(replica); + if (!device_ordinal_status.ok()) { + results[replica] = device_ordinal_status.status(); + return; + } + const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(3) << "Replica " << replica + << " mapped to device ordinal for execution: " << device_ordinal; + + std::vector argument_buffers; + argument_buffers.reserve(argument_handles[replica].size()); + for (auto& handle : argument_handles[replica]) { + argument_buffers.push_back(handle->shaped_buffer()); + } + + DeviceAssignment device_assignment = + client->backend() + .computation_placer() + ->AssignDevices(num_replicas, /*computation_count=*/1) + .ConsumeValueOrDie(); + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_intra_op_thread_pool( + 
client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + StatusOr result_buffer_status = + executable_->Run(argument_buffers, options); + + results[replica] = std::move(result_buffer_status); + }; + + if (num_replicas == 1) { + // Fast-path if there is only one replica — run the computation on the + // current thread. + execute(0); + } else { + // TODO(phawkins): don't recreate the threadpool for each execution. tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", - GetReplicaCount()); - - for (int replica = 0; replica < GetReplicaCount(); ++replica) { - pool.Schedule( - [this, client, replica, &arguments, &shapes_with_layout, &results] { - StatusOr device_ordinal_status = - client->ReplicaNumberToDeviceOrdinal(replica); - if (!device_ordinal_status.ok()) { - results[replica] = device_ordinal_status.status(); - return; - } - const int device_ordinal = device_ordinal_status.ValueOrDie(); - VLOG(3) << "Replica " << replica - << " mapped to device ordinal for execution: " - << device_ordinal; - - // Transfer arguments in - std::vector scoped_buffers; - scoped_buffers.reserve(arguments.size()); - for (int i = 0; i < arguments.size(); ++i) { - const Literal& argument = arguments[i]; - const absl::optional& shape_with_layout = - shapes_with_layout[i]; - - StatusOr pushed; - if (shape_with_layout) { - Literal relaid = argument.Relayout(shape_with_layout.value()); - pushed = ToBuffer(client, device_ordinal, relaid); - } else { - pushed = ToBuffer(client, device_ordinal, argument); - } - if (!pushed.ok()) { - results[replica] = pushed.status(); - return; - } - - scoped_buffers.push_back(std::move(pushed).ValueOrDie()); - } - - // Execute - std::vector argument_buffers; - argument_buffers.reserve(scoped_buffers.size()); - for (auto& buffer : scoped_buffers) { - argument_buffers.push_back(&buffer); - } - - DeviceAssignment device_assignment = - client->backend() - .computation_placer() - 
->AssignDevices(GetReplicaCount(), /*computation_count=*/1) - .ConsumeValueOrDie(); - - ExecutableRunOptions options; - options.set_device_ordinal(device_ordinal); - options.set_allocator(client->backend().memory_allocator()); - options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - options.set_device_assignment(&device_assignment); - StatusOr result_buffer_status = - executable_->Run(argument_buffers, options); - if (!result_buffer_status.ok()) { - results[replica] = result_buffer_status.status(); - return; - } - - // Transfer result out - results[replica] = client->ShapedBufferToLiteral( - std::move(result_buffer_status).ValueOrDie()); - }); + num_replicas - 1); + + for (int replica = 0; replica < num_replicas - 1; ++replica) { + pool.Schedule([&execute, replica] { execute(replica); }); } + execute(num_replicas - 1); } - for (int replica = 0; replica < GetReplicaCount(); ++replica) { - const auto& statusor = results[replica]; + std::vector wrapped_results(num_replicas); + for (int replica = 0; replica < num_replicas; ++replica) { + auto& statusor = results[replica]; if (!statusor.ok()) { return InternalError( "Failed running replica %d (other replicas may have failed as well): " "%s.", replica, statusor.status().ToString()); } + wrapped_results[replica] = + new LocalShapedBuffer(std::move(statusor).ValueOrDie()); } - return std::move(results[0]); + return new LocalShapedBufferTuple(std::move(wrapped_results)); } -LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( - absl::Span argument_handles) { - LocalClient* client = GetOrCreateLocalClient(); +static StatusOr GetReturnValueShape(const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + computation.GetProgramShape()); + return std::move(*program_shape.mutable_result()); +} - std::vector argument_buffers; - argument_buffers.reserve(argument_handles.size()); - for (auto& handle : argument_handles) { - 
argument_buffers.push_back(handle->shaped_buffer()); +CompiledXrtComputation::CompiledXrtComputation( + const ProgramShape& program_shape, int64 handle, + const string& session_target) + : program_shape_(program_shape), + handle_(handle), + session_target_(session_target) {} + +CompiledXrtComputation::~CompiledXrtComputation() { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto computation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto release = + tensorflow::ops::XRTReleaseCompilationHandle(root, computation_handle); + if (!root.status().ok()) { + LOG(ERROR) << root.status(); + return; } - // Execute - ExecutableRunOptions options; - options.set_allocator(client->backend().memory_allocator()); - options.set_intra_op_thread_pool( - client->backend().eigen_intra_op_thread_pool_device()); - ScopedShapedBuffer result_buffer = - executable_->Run(argument_buffers, options).ConsumeValueOrDie(); + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({computation_handle, handle()}); + std::vector outputs; + auto status = session.Run(inputs, {}, {release}, &outputs); + if (!status.ok()) { + LOG(ERROR) << status; + return; + } +} - return new LocalShapedBuffer(std::move(result_buffer)); +StatusOr CompiledXrtComputation::Execute( + absl::Span argument_handles) { + const int num_expected_arguments = program_shape().parameters().size(); + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + std::vector arguments; + arguments.reserve(num_expected_arguments); + for (int i = 0; i < num_expected_arguments; ++i) { + arguments.push_back( + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64)); + } + auto computation_handle = + tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto execution_config = + tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto execute = tensorflow::ops::XRTExecute(root, computation_handle, + 
execution_config, arguments); + TF_RETURN_IF_ERROR(root.status()); + + TF_RET_CHECK(argument_handles.size() == arguments.size()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(false); + e.set_release_compilation_handle(false); + + tensorflow::ClientSession session(root, session_target_); + tensorflow::ClientSession::FeedType inputs; + for (int i = 0; i < arguments.size(); ++i) { + inputs.insert({arguments[i], argument_handles[i]->handle()}); + } + inputs.insert({computation_handle, handle()}); + inputs.insert({execution_config, e.SerializeAsString()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {execute}, &outputs)); + + int64 output = outputs[0].scalar()(); + return new XrtAllocation(output, program_shape().result(), session_target_); } +const ProgramShape& CompiledXrtComputation::program_shape() const { + return program_shape_; +} + +int64 CompiledXrtComputation::handle() const { return handle_; } + LocalComputation::LocalComputation(XlaComputation computation) : computation_(std::move(computation)) {} @@ -300,6 +543,37 @@ StatusOr LocalComputation::Compile( return new CompiledLocalComputation(std::move(local_executable)); } +StatusOr LocalComputation::CompileForXrt( + const std::vector& argument_shapes, const string& session_target) { + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto program = tensorflow::ops::Placeholder(root, tensorflow::DT_STRING); + auto compile = tensorflow::ops::XRTCompile(root, program); + TF_RETURN_IF_ERROR(root.status()); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + ProgramShape shapes; + for (auto& shape : argument_shapes) { + *shapes.add_parameters() = shape; + } + TF_ASSIGN_OR_RETURN(*shapes.mutable_result(), GetReturnValueShape()); + LayoutUtil::SetToDefaultLayout(&shapes); + *config->mutable_program_shape() = shapes.ToProto(); + auto snapshot = computation().Snapshot().ValueOrDie(); + *c.mutable_hlo_snapshot() = *snapshot; + + tensorflow::ClientSession 
session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + inputs.insert({program, c.SerializeAsString()}); + std::vector outputs; + TF_RETURN_IF_ERROR(session.Run(inputs, {compile.handle}, &outputs)); + + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + computation().GetProgramShape()); + int64 handle = outputs[0].scalar()(); + return new CompiledXrtComputation(program_shape, handle, session_target); +} + const XlaComputation& LocalComputation::computation() const { return computation_; } @@ -314,9 +588,7 @@ string LocalComputation::GetSerializedProto() const { } StatusOr LocalComputation::GetReturnValueShape() const { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - computation_.GetProgramShape()); - return std::move(*program_shape.mutable_result()); + return swig::GetReturnValueShape(computation_); } LocalOp::LocalOp(const XlaOp& op) : op_(op) {} @@ -343,6 +615,12 @@ LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, return xla::Parameter(&builder_, parameter_number, shape, name); } +StatusOr LocalComputationBuilder::BuildWithRoot( + const LocalOp& root) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build(root.op())); + return new LocalComputation(std::move(computation)); +} + StatusOr LocalComputationBuilder::GetShape(const LocalOp& operand) { return builder_.GetShape(operand.op()); } @@ -371,6 +649,12 @@ LocalOp LocalComputationBuilder::Broadcast( return xla::Broadcast(operand.op(), broadcast_sizes); } +LocalOp LocalComputationBuilder::BroadcastInDim( + const LocalOp& operand, absl::Span out_dim_sizes, + absl::Span broadcast_dimensions) { + return xla::BroadcastInDim(operand.op(), out_dim_sizes, broadcast_dimensions); +} + LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, const LocalOp& padding_value, const PaddingConfig& padding_config) { @@ -532,10 +816,13 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( const LocalComputation& local_computation, absl::Span 
window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span> padding) { return xla::ReduceWindowWithGeneralPadding( operand.op(), init_value.op(), local_computation.computation(), - window_dimensions, window_strides, padding); + window_dimensions, window_strides, base_dilations, window_dilations, + padding); } LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, @@ -569,13 +856,13 @@ StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { } LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { - return xla::Sort(operand.op(), absl::nullopt, dimension); + return xla::Sort(operand.op(), {}, dimension); } LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, const LocalOp& values, int64 dimension) { - return xla::Sort(keys.op(), values.op(), dimension); + return xla::Sort(keys.op(), {values.op()}, dimension); } StatusOr LocalComputationBuilder::BuildConstantSubGraph( @@ -674,23 +961,29 @@ void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) { delete local_shaped_buffer; } +void DeleteXrtAllocation(XrtAllocation* allocation) { delete allocation; } + void DeleteCompiledLocalComputation(CompiledLocalComputation* computation) { delete computation; } +void DeleteCompiledXrtComputation(CompiledXrtComputation* computation) { + delete computation; +} + void DeleteLocalComputation(LocalComputation* computation) { delete computation; } StatusOr DestructureLocalShapedBufferTuple( LocalShapedBuffer* local_shaped_buffer) { - if (!ShapeUtil::IsTuple( - local_shaped_buffer->shaped_buffer()->on_device_shape())) { + const Shape tuple_shape = local_shaped_buffer->shape(); + + if (!ShapeUtil::IsTuple(tuple_shape)) { return InvalidArgument( "Attemped to destructure a LocalShapedBuffer that did not have a tuple " "shape; shape: %s", - ShapeUtil::HumanString( - local_shaped_buffer->shaped_buffer()->on_device_shape())); + 
ShapeUtil::HumanString(tuple_shape)); } DeviceMemoryAllocator* allocator = @@ -702,7 +995,6 @@ StatusOr DestructureLocalShapedBufferTuple( int device_ordinal = tuple_buffer.device_ordinal(); ShapeTree& shape_tree = tuple_buffer.buffers(); - const Shape& tuple_shape = tuple_buffer.on_device_shape(); std::vector results; for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { // Create a shaped buffer for this destructured tuple element. @@ -730,5 +1022,47 @@ StatusOr DestructureLocalShapedBufferTuple( return new LocalShapedBufferTuple(std::move(results)); } +StatusOr DestructureXrtAllocationTuple( + XrtAllocation* allocation, const string& session_target) { + const Shape& tuple_shape = allocation->shape(); + + if (!ShapeUtil::IsTuple(tuple_shape)) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString(tuple_shape)); + } + + tensorflow::Scope root = tensorflow::Scope::NewRootScope(); + auto base_handle = tensorflow::ops::Placeholder(root, tensorflow::DT_INT64); + auto shape_index = tensorflow::ops::Placeholder(root, tensorflow::DT_INT32); + auto subtuple = tensorflow::ops::XRTSubTuple(root, base_handle, shape_index); + TF_RETURN_IF_ERROR(root.status()); + + tensorflow::ClientSession session(root, session_target); + tensorflow::ClientSession::FeedType inputs; + std::vector results; + for (int32 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + inputs.clear(); + inputs.insert({base_handle, allocation->handle()}); + inputs.insert({shape_index, {i}}); + std::vector outputs; + auto status = session.Run(inputs, {subtuple}, &outputs); + if (!status.ok()) { + // Clean up before returning non-ok status. 
+ for (int j = 0; j < results.size(); ++j) { + delete results[j]; + } + return status; + } + const int64 subtuple_handle = outputs[0].scalar()(); + const Shape& subtuple_shape = + ShapeUtil::GetTupleElementShape(tuple_shape, i); + results.push_back( + new XrtAllocation(subtuple_handle, subtuple_shape, session_target)); + } + return new XrtAllocationTuple(std::move(results)); +} + } // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 2166bb6721ca380f3180a8802e4922f2e9e45945..149e44570df5c6a3df88bbe2ffa779be47842d82 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -16,7 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ +#include +#include + #include "absl/types/span.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -34,6 +39,12 @@ namespace swig { // returned. Status InitializeReplicaCount(int replica_count); +// Initializes the platform name that XLA will be initialized with (when +// first obtaining a handle to the local XLA service). If this is called after +// the handle to the local XLA service has been established, then an error is +// returned. +Status InitializePlatformName(const string& platform_name); + // Returns the replica count that is currently set, regardless of whether the // local XLA service has been instantiated yet or not. 
int GetReplicaCount(); @@ -54,18 +65,19 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number); StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, int replica_number); -// Wraps a ScopedShapedBuffer produced by copying a literal "to -// device," i.e. copying a literal to a scoped buffer via the local -// client. +// Represents a reference to literals that live in a device-allocated buffer via +// XLA. Specifically, wraps a ScopedShapedBuffer produced by transferring a +// literal to device via the local client. class LocalShapedBuffer { public: static StatusOr FromLiteral( - const Literal& argument, const absl::optional& shape_with_layout); + const Literal& argument, const absl::optional& shape_with_layout, + int replica_number); LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); - const ScopedShapedBuffer* shaped_buffer() const; - StatusOr ToLiteral() const; + const Shape& shape() const; + const ScopedShapedBuffer* shaped_buffer() const; // Transfers ownership of the encapsulated ShapedBuffer to the caller, // analogous to std::unique_ptr::release(). @@ -92,7 +104,7 @@ class LocalShapedBufferTuple { StatusOr Release(int i); // Returns the number of elements in the destructured tuple. - int size() const; + int64 size() const; private: std::vector elements_; @@ -103,31 +115,99 @@ class LocalShapedBufferTuple { StatusOr DestructureLocalShapedBufferTuple( LocalShapedBuffer* local_shaped_buffer); -// Wraps a LocalExecutable produced by compiling a -// LocalComputation. The Execute method forwards to that of the -// underlying LocalExecutable, and additionally handles tranferring -// arguments and return values in and back out of the client library's -// local client. This class is intended to be made available to Python -// via SWIG. +// Represents a reference to literals that live in a device-allocated buffer via +// XRT. 
Specifically, wraps an int64 handle produced by running the allocation +// graph, and an XLA shape to track the referent's shape. +class XrtAllocation { + public: + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which allocation and deallocation + // graphs are run. + static StatusOr FromLiteral(const Literal& argument, + const string& session_target); + + XrtAllocation(int64 handle, Shape shape, const string& session_target); + ~XrtAllocation(); + StatusOr ToLiteral() const; + const Shape& shape() const; + const int64 handle() const; + + private: + const int64 handle_; + const Shape shape_; + const string session_target_; +}; + +// Result of a tuple destructuring operation on an XrtAllocation. +class XrtAllocationTuple { + public: + // Note: any XrtAllocation elements that are not Release()'d will be + // deallocated in the destructor. + explicit XrtAllocationTuple(std::vector elements); + + ~XrtAllocationTuple(); + + // Releases the ith element to the caller. Further attempts to release the ith + // element will return an invalid argument error. + StatusOr Release(int i); + + // Returns the number of elements in the destructured tuple. + int64 size() const; + + private: + std::vector elements_; +}; + +// Destructures a tuple-valued XrtAllocation into its constitutent elements +// in XrtAllocationTuple form. +// +// Accepts a `session_target` argument, used in constructing the +// `tensorflow::ClientSession` instance in which the sub-tupling graph is run, +// and passed along in constructing each constituent XrtAllocation. +StatusOr DestructureXrtAllocationTuple( + XrtAllocation* allocation, const string& session_target); + +// Represents a compiled computation that can be executed given handles to +// device-allocated literals. Specifically, wraps an XLA LocalExecutable. 
class CompiledLocalComputation { public: CompiledLocalComputation(std::unique_ptr executable); - // Execute the computation with the given argument literals, and - // with optionally-specified argument layouts. The literals will be - // re-laid out according to the corresponding elements of - // shapes_with_layout. - StatusOr Execute( - const std::vector& arguments, - const std::vector >& shapes_with_layout); - - LocalShapedBuffer* ExecuteWithShapedBuffers( + StatusOr Execute( absl::Span argument_handles); + // Execute on many replicas. Takes a sequence of argument lists (one argument + // list per replica) and returns a tuple of results (one result per replica). + // The number of argument lists must be equal to the replica count. + StatusOr ExecutePerReplica( + absl::Span > argument_handles); + private: std::unique_ptr executable_; }; +// Represents a compiled computation that can be executed given handles to +// device-allocated literals. Specifically, wraps an XRT computation handle. +class CompiledXrtComputation { + public: + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which the execution graph is run. + CompiledXrtComputation(const ProgramShape& program_shape, int64 handle, + const string& session_target); + ~CompiledXrtComputation(); + + StatusOr Execute( + absl::Span argument_handles); + + const ProgramShape& program_shape() const; + int64 handle() const; + + private: + const ProgramShape program_shape_; + const int64 handle_; + const string session_target_; +}; + // Wraps a XlaComputation produced by a LocalComputationBuilder. The // Compile method compiles the computation to a (local) executable via // the client library's local client. 
This class is intended to be @@ -140,6 +220,11 @@ class LocalComputation { const std::vector& argument_shapes, const ExecutableBuildOptions* build_options); + // Accepts a `session_target` argument, used in constructing the + // `tensorflow::ClientSession` instance in which the compilation graph is run. + StatusOr CompileForXrt( + const std::vector& argument_shapes, const string& session_target); + const XlaComputation& computation() const; // Returns the HloModuleProto contained in the XlaComputation in the @@ -183,6 +268,9 @@ class LocalComputationBuilder { // Returns an owned LocalComputation to the caller on success. StatusOr Build(); + // Returns an owned LocalComputation to the caller on success with given root. + StatusOr BuildWithRoot(const LocalOp& root); + LocalOp Parameter(int64 parameter_number, const Shape& shape, const string& name); @@ -201,6 +289,10 @@ class LocalComputationBuilder { LocalOp Broadcast(const LocalOp& operand, absl::Span broadcast_sizes); + LocalOp BroadcastInDim(const LocalOp& operand, + absl::Span out_dim_sizes, + absl::Span broadcast_dimensions); + LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, const PaddingConfig& padding_config); @@ -278,6 +370,8 @@ class LocalComputationBuilder { const LocalComputation& local_computation, absl::Span window_dimensions, absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, absl::Span > padding); LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, @@ -389,7 +483,9 @@ class LocalComputationBuilder { // Functions for freeing resources from the Python side. 
void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer); +void DeleteXrtAllocation(XrtAllocation* allocation); void DeleteCompiledLocalComputation(CompiledLocalComputation* computation); +void DeleteCompiledXrtComputation(CompiledXrtComputation* computation); void DeleteLocalComputation(LocalComputation* computation); } // namespace swig diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 521490e76c138553c5cc6895412eadb35a939881..d23d693c1e5bde43b52959e4397aa311268411bb 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -176,6 +176,81 @@ bool HandleStringAttribute(PyObject* o, tensorflow::ImportNumpy(); %} +// Basic types + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = PyBool_FromLong($1.ConsumeValueOrDie()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + +%typemap(out) Status { + if (!$1.ok()) { + PyErr_SetString( + PyExc_RuntimeError, $1.ToString().c_str()); + SWIG_fail; + } + Py_INCREF(Py_None); + $result = Py_None; +} + +%typemap(in) absl::Span + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + temps.resize(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + PyObject* py_int = numpy::PyNumberToPyInt(o); + if (!py_int) { + PyErr_SetString( + PyExc_TypeError, + "Argument sequence element cannot be converted to int"); + Py_DECREF(o); + SWIG_fail; + } + temps[i] = numpy::PyIntOrPyLongToLong(py_int); + if (temps[i] == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + Py_DECREF(o); + SWIG_fail; + } + Py_DECREF(py_int); + Py_DECREF(o); + } + $1 = temps; +} + +// Computation builder types + +%typemap(in) absl::Span( + std::vector temps) { + if 
(!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + SWIG_fail; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + LocalOp* op; + if ((SWIG_ConvertPtr(o, (void**)&op, $descriptor(xla::swig::LocalOp*), + SWIG_POINTER_EXCEPTION)) == -1) { + SWIG_fail; + } + temps.push_back(*op); + Py_DECREF(o); + } + $1 = temps; +} + +// Computation and buffer/allocation types + %typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); @@ -189,12 +264,12 @@ tensorflow::ImportNumpy(); } } -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::LocalShapedBuffer*) + $typemap(out, xla::swig::CompiledXrtComputation*) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -202,12 +277,12 @@ tensorflow::ImportNumpy(); } } -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::LocalShapedBufferTuple*) + $typemap(out, xla::swig::LocalShapedBuffer*) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -215,23 +290,25 @@ tensorflow::ImportNumpy(); } } - -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { - Literal value = $1.ConsumeValueOrDie(); - $result = numpy::PyObjectFromXlaLiteral(*value); + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::LocalShapedBufferTuple*) + } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); SWIG_fail; } } -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::LocalComputation*) + $typemap(out, xla::swig::XrtAllocation*) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -239,92 +316,86 @@ 
tensorflow::ImportNumpy(); } } -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { - $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::XrtAllocationTuple*) + } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); SWIG_fail; } } -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { - $result = PyBool_FromLong($1.ConsumeValueOrDie()); + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::LocalComputation*) + } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); SWIG_fail; } } -%typemap(out) Status { - if (!$1.ok()) { - PyErr_SetString( - PyExc_RuntimeError, $1.ToString().c_str()); - SWIG_fail; - } - Py_INCREF(Py_None); - $result = Py_None; -} - -// Span - -%typemap(in) absl::Span - (std::vector temps) { +%typemap(in) absl::Span + (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); SWIG_fail; } const int size = PySequence_Size($input); - temps.resize(size); + temps.reserve(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - PyObject* py_int = numpy::PyNumberToPyInt(o); - if (!py_int) { - PyErr_SetString( - PyExc_TypeError, - "Argument sequence element cannot be converted to int"); - Py_DECREF(o); - SWIG_fail; - } - temps[i] = numpy::PyIntOrPyLongToLong(py_int); - if (temps[i] == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - Py_DECREF(o); + LocalShapedBuffer* lsbp; + if ((SWIG_ConvertPtr(o, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*), + SWIG_POINTER_EXCEPTION)) == -1) { SWIG_fail; } - Py_DECREF(py_int); + temps.push_back(lsbp); Py_DECREF(o); } $1 = temps; } -// Span - -%typemap(in) absl::Span( - std::vector temps) { +%typemap(in) absl::Span > + (std::vector > temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is 
not a sequence"); SWIG_fail; } const int size = PySequence_Size($input); + temps.reserve(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - LocalOp* op; - if ((SWIG_ConvertPtr(o, (void**)&op, $descriptor(xla::swig::LocalOp*), - SWIG_POINTER_EXCEPTION)) == -1) { - SWIG_fail; + std::vector vec; + const int vec_size = PySequence_Size(o); + vec.reserve(vec_size); + for (int j = 0; j < vec_size; ++j) { + PyObject* vec_elt = PySequence_GetItem(o, j); + LocalShapedBuffer* lsbp; + if ((SWIG_ConvertPtr(vec_elt, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*), + SWIG_POINTER_EXCEPTION)) == -1) { + Py_DECREF(vec_elt); + Py_DECREF(o); + SWIG_fail; + } + vec.push_back(lsbp); + Py_DECREF(vec_elt); } - temps.push_back(*op); + temps.push_back(vec); Py_DECREF(o); } $1 = temps; } -// LocalShapedBuffer* - -%typemap(in) absl::Span - (std::vector temps) { +%typemap(in) absl::Span + (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); SWIG_fail; @@ -333,12 +404,12 @@ tensorflow::ImportNumpy(); temps.reserve(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - LocalShapedBuffer* lsbp; - if ((SWIG_ConvertPtr(o, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*), + XrtAllocation* xrta; + if ((SWIG_ConvertPtr(o, (void**) &xrta, $descriptor(xla::swig::XrtAllocation*), SWIG_POINTER_EXCEPTION)) == -1) { SWIG_fail; } - temps.push_back(lsbp); + temps.push_back(xrta); Py_DECREF(o); } $1 = temps; @@ -346,6 +417,16 @@ tensorflow::ImportNumpy(); // Literal +%typemap(out) StatusOr { + if ($1.ok()) { + Literal value = $1.ConsumeValueOrDie(); + $result = numpy::PyObjectFromXlaLiteral(*value); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + %typemap(in) const Literal& (StatusOr literal_status) { literal_status = numpy::XlaLiteralFromPyObject($input); if (!literal_status.ok()) { @@ -401,6 +482,19 
@@ tensorflow::ImportNumpy(); // Shape +%typemap(out) const Shape& { + $result = numpy::PyShapeInfoFromXlaShape(*$1); +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + %typemap(in) const Shape& (Shape temp) { StatusOr statusor = numpy::XlaShapeFromPyShape($input); if (!statusor.ok()) { @@ -858,22 +952,22 @@ tensorflow::ImportNumpy(); $1 = NULL; } else { if (!HandleStringAttribute($input, "generate_hlo_graph", [&](string s) { - build_options.set_generate_hlo_graph(std::move(s)); + build_options.mutable_debug_options()->set_xla_generate_hlo_graph(std::move(s)); })) { return nullptr; } if (!HandleStringAttribute($input, "dump_optimized_hlo_proto_to", [&](string s) { - build_options.set_dump_optimized_hlo_proto_to(std::move(s)); + build_options.mutable_debug_options()->set_xla_dump_optimized_hlo_proto_to(std::move(s)); })) { return nullptr; } if (!HandleStringAttribute($input, "dump_unoptimized_hlo_proto_to", [&](string s) { - build_options.set_dump_unoptimized_hlo_proto_to(std::move(s)); + build_options.mutable_debug_options()->set_xla_dump_unoptimized_hlo_proto_to(std::move(s)); })) { return nullptr; } if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) { - build_options.set_dump_per_pass_hlo_proto_to(std::move(s)); + build_options.mutable_debug_options()->set_xla_dump_per_pass_hlo_proto_to(std::move(s)); })) { return nullptr; } @@ -887,7 +981,7 @@ tensorflow::ImportNumpy(); PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.hlo_profile must be a bool or None."); SWIG_fail; } - build_options.set_hlo_profile(o == Py_True); + build_options.mutable_debug_options()->set_xla_hlo_profile(o == Py_True); } Py_DECREF(o); @@ -914,6 +1008,7 @@ tensorflow::ImportNumpy(); %unignore xla; %unignore xla::swig; %unignore xla::swig::InitializeReplicaCount; +%unignore 
xla::swig::InitializePlatformName; %unignore xla::swig::GetReplicaCount; %unignore xla::swig::TransferToInfeedLocal; %unignore xla::swig::TransferToInfeedLocalReplica; @@ -921,20 +1016,32 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalShapedBuffer; %unignore xla::swig::LocalShapedBuffer::FromLiteral; %unignore xla::swig::LocalShapedBuffer::ToLiteral; +%unignore xla::swig::LocalShapedBuffer::shape; %unignore xla::swig::LocalShapedBufferTuple; %unignore xla::swig::LocalShapedBufferTuple::Release; %unignore xla::swig::LocalShapedBufferTuple::size; +%unignore xla::swig::XrtAllocation; +%unignore xla::swig::XrtAllocation::FromLiteral; +%unignore xla::swig::XrtAllocation::ToLiteral; +%unignore xla::swig::XrtAllocation::shape; +%unignore xla::swig::XrtAllocationTuple; +%unignore xla::swig::XrtAllocationTuple::Release; +%unignore xla::swig::XrtAllocationTuple::size; %unignore xla::swig::CompiledLocalComputation; %unignore xla::swig::CompiledLocalComputation::Execute; -%unignore xla::swig::CompiledLocalComputation::ExecuteWithShapedBuffers; +%unignore xla::swig::CompiledLocalComputation::ExecutePerReplica; +%unignore xla::swig::CompiledXrtComputation; +%unignore xla::swig::CompiledXrtComputation::Execute; %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; +%unignore xla::swig::LocalComputation::CompileForXrt; %unignore xla::swig::LocalComputation::GetReturnValueShape; %unignore xla::swig::LocalComputation::GetSerializedProto; %unignore xla::swig::LocalOp; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::Build; +%unignore xla::swig::LocalComputationBuilder::BuildWithRoot; %unignore xla::swig::LocalComputationBuilder::SetOpMetadata; %unignore xla::swig::LocalComputationBuilder::ClearOpMetadata; %unignore xla::swig::LocalComputationBuilder::Parameter; @@ -945,6 +1052,7 @@ tensorflow::ImportNumpy(); %unignore 
xla::swig::LocalComputationBuilder::ConstantLiteral; %unignore xla::swig::LocalComputationBuilder::ConstantR0; %unignore xla::swig::LocalComputationBuilder::Broadcast; +%unignore xla::swig::LocalComputationBuilder::BroadcastInDim; %unignore xla::swig::LocalComputationBuilder::Pad; %unignore xla::swig::LocalComputationBuilder::Reshape; %unignore xla::swig::LocalComputationBuilder::Collapse; @@ -1036,10 +1144,13 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Imag; %unignore xla::swig::LocalComputationBuilder::Conj; %unignore xla::swig::LocalComputationBuilder::Complex; +%unignore xla::swig::DeleteLocalComputation; %unignore xla::swig::DestructureLocalShapedBufferTuple; +%unignore xla::swig::DestructureXrtAllocationTuple; %unignore xla::swig::DeleteLocalShapedBuffer; -%unignore xla::swig::DeleteLocalComputation; +%unignore xla::swig::DeleteXrtAllocation; %unignore xla::swig::DeleteCompiledLocalComputation; +%unignore xla::swig::DeleteCompiledXrtComputation; %thread; %include "tensorflow/compiler/xla/python/local_computation_builder.h" diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index bb303c5678a2cac9a9e78925e857ab25c0c6d9be..c91a2aaf56dfe2127168628c78e0c4b868a28055 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -26,6 +26,9 @@ import os import numpy as np +import six +from six.moves import xrange + from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api from tensorflow.compiler.xla.service import hlo_pb2 @@ -46,6 +49,15 @@ _OP_METADATA_FIELDS = [ OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS) +class BackendType(enum.Enum): + XLA_LOCAL = 1 + XRT = 2 + + +BackendSpec = collections.namedtuple('Backend', ('backend_type', 'target')) +XLA_LOCAL_BACKEND = BackendSpec(BackendType.XLA_LOCAL, 'local') + + def OpMetadataToProto(pyobj): proto = 
xla_data_pb2.OpMetadata() for field in _OP_METADATA_FIELDS: @@ -66,6 +78,13 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): source_line=lineno) +def _maybe_encode_string(s): + if six.PY3: + return s.encode('utf-8') + else: + return s + + class PaddingType(enum.Enum): VALID = 1 SAME = 2 @@ -73,16 +92,32 @@ class PaddingType(enum.Enum): def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, window_strides): - """Maps PaddingType (VALID or SAME) to pad values (list of pairs of ints).""" + """Maps PaddingType or string to pad values (list of pairs of ints).""" + if not isinstance(padding_type, (str, PaddingType)): + msg = 'padding_type must be str or PaddingType, got {}.' + raise TypeError(msg.format(type(padding_type))) + + if isinstance(padding_type, str): + if padding_type.upper() == 'VALID': + padding_type = PaddingType.VALID + elif padding_type.upper() == 'SAME': + padding_type = PaddingType.SAME + else: + msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.' 
+ raise ValueError(msg.format(padding_type)) + if padding_type == PaddingType.VALID: return [(0, 0)] * len(window_strides) - - out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int) - pad_sizes = [max((out_size - 1) * stride + filter_size - in_size, 0) - for out_size, stride, filter_size, in_size - in zip(out_shape, window_strides, rhs_dims, lhs_dims)] - return [(pad_size // 2, pad_size - pad_size // 2) - for pad_size in pad_sizes] + elif padding_type == PaddingType.SAME: + out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int) + pad_sizes = [max((out_size - 1) * stride + filter_size - in_size, 0) + for out_size, stride, filter_size, in_size + in zip(out_shape, window_strides, rhs_dims, lhs_dims)] + return [(pad_size // 2, pad_size - pad_size // 2) + for pad_size in pad_sizes] + else: + msg = 'Unexpected PaddingType value: {}' + raise ValueError(msg.format(padding_type)) _UNARY_OPS = [ @@ -187,38 +222,66 @@ class LocalBuffer(object): means the referent is in device memory. 
""" - def __init__(self, c_local_shaped_buffer): - self.c_local_shaped_buffer = c_local_shaped_buffer - self._delete = c_api.DeleteLocalShapedBuffer + def __init__(self, c_buffer, backend, replica): + self.c_buffer = c_buffer + self._backend = backend + self._replica = replica + if backend.backend_type == BackendType.XRT: + self._delete = c_api.DeleteXrtAllocation + else: + self._delete = c_api.DeleteLocalShapedBuffer @staticmethod - def from_pyval(pyval, layout_fn=None): + def from_pyval(pyval, replica=0, backend=XLA_LOCAL_BACKEND): + """Allocate and copy to XLA the given python value.""" pyval = require_numpy_array_layout(pyval) - if layout_fn: - shape = Shape.from_pyval(pyval) - shape = shape.map_leaves(layout_fn) + num_replicas = get_replica_count() + if not 0 <= replica < num_replicas: + raise ValueError( + 'Attempt to place buffer on replica {} when the replica count is {}' + .format(replica, num_replicas)) + if backend.backend_type == BackendType.XRT: + if replica != 0: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') + cbuf = c_api.XrtAllocation.FromLiteral( + pyval, _maybe_encode_string(backend.target)) else: - shape = None - return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(pyval, shape)) + cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None, replica) + return LocalBuffer(cbuf, backend, replica) def to_py(self): - return self.c_local_shaped_buffer.ToLiteral() + return self.c_buffer.ToLiteral() + + def shape(self): + return _wrap_shape(self.c_buffer.shape()) + + def replica(self): + return self._replica def delete(self): - if self.c_local_shaped_buffer is not None: - self._delete(self.c_local_shaped_buffer) - self.c_local_shaped_buffer = None + if self.c_buffer is not None: + self._delete(self.c_buffer) + self.c_buffer = None def destructure(self): - assert self.c_local_shaped_buffer is not None - result = c_api.DestructureLocalShapedBufferTuple(self.c_local_shaped_buffer) - 
self.c_local_shaped_buffer = None + """Assuming a tuple buffer, unpack it into constituent tuple elements.""" + assert self.c_buffer is not None + if self._backend.backend_type == BackendType.XRT: + result = c_api.DestructureXrtAllocationTuple( + self.c_buffer, _maybe_encode_string(self._backend.target)) + else: + result = c_api.DestructureLocalShapedBufferTuple(self.c_buffer) + self.delete() size = result.size() - destructured = tuple(LocalBuffer(result.Release(i)) for i in xrange(size)) + destructured = tuple( + LocalBuffer( + result.Release(i), replica=self._replica, backend=self._backend) + for i in xrange(size)) return destructured def is_deleted(self): - return self.c_local_shaped_buffer is None + return self.c_buffer is None def __del__(self): self.delete() @@ -283,6 +346,9 @@ class Shape(object): def __ne__(self, other): return not self == other + def __hash__(self): + return hash((self._dtype, self._dimensions, self._minor_to_major)) + def __repr__(self): return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, ' '_is_tuple={!r}, _minor_to_major={!r})').format( @@ -436,26 +502,37 @@ class LocalComputation(object): ComputationBuilder methods. """ - def __init__(self, c_local_computation, is_compiled): - self.c_local_computation = c_local_computation - self.is_compiled = is_compiled + def __init__(self, c_computation, is_compiled, backend=XLA_LOCAL_BACKEND): + self._c_computation = c_computation + self._backend = backend + self._is_compiled = is_compiled # Ensure a reference to C-based destructor for use in __del__. 
if is_compiled: - assert isinstance(c_local_computation, c_api.CompiledLocalComputation) - self._delete = c_api.DeleteCompiledLocalComputation + if backend.backend_type == BackendType.XRT: + assert isinstance(c_computation, c_api.CompiledXrtComputation) + self._delete = c_api.DeleteCompiledXrtComputation + else: + assert isinstance(c_computation, c_api.CompiledLocalComputation) + self._delete = c_api.DeleteCompiledLocalComputation else: - assert isinstance(c_local_computation, c_api.LocalComputation) + assert isinstance(c_computation, c_api.LocalComputation) self._delete = c_api.DeleteLocalComputation + @property + def computation(self): + if self._is_compiled: + raise ValueError( + 'Attempt to read the XLA computation of a compiled LocalComputation.') + return self._c_computation + def GetProto(self): """Get the HloModuleProto proto object in this local computation. Returns: An HloModuleProto proto object that has the whole-graph information. """ - - serialized = self.c_local_computation.GetSerializedProto() + serialized = self.computation.GetSerializedProto() proto = hlo_pb2.HloModuleProto.FromString(serialized) return proto @@ -480,10 +557,10 @@ class LocalComputation(object): Returns: A newly *compiled* local computation instance. 
""" - if self.is_compiled: + if self._is_compiled: raise ValueError('Attempt to compile a compiled local XLA computation.') - result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape()) + result_shape = _wrap_shape(self.computation.GetReturnValueShape()) if layout_fn: argument_shapes = [ @@ -491,11 +568,16 @@ class LocalComputation(object): ] result_shape = result_shape.map_leaves(layout_fn) + argument_shapes = list(argument_shapes) + compile_options = compile_options or CompileOptions() compile_options.result_shape = result_shape - return LocalComputation( - self.c_local_computation.Compile(argument_shapes, compile_options), - is_compiled=True) + if self._backend.backend_type == BackendType.XRT: + c = self.computation.CompileForXrt( + argument_shapes, _maybe_encode_string(self._backend.target)) + else: + c = self.computation.Compile(argument_shapes, compile_options) + return LocalComputation(c, is_compiled=True, backend=self._backend) def CompileWithExampleArguments(self, arguments=(), @@ -506,33 +588,89 @@ class LocalComputation(object): compile_options=compile_options, layout_fn=layout_fn) - def Execute(self, arguments=(), layout_fn=None): - """Execute with Python values as arguments and return value.""" - if not self.is_compiled: + def GetReturnValueShape(self): + return _wrap_shape(self._c_computation.GetReturnValueShape()) + + def Execute(self, arguments=(), check_for_deleted_args=True): + """Execute on one replica with LocalBuffer arguments and return value.""" + if check_for_deleted_args and any(arg.is_deleted() for arg in arguments): + raise ValueError('Executing with deleted local buffer argument') + raw_args = [arg.c_buffer for arg in arguments] + output_buffer = self._c_computation.Execute(raw_args) + return LocalBuffer(output_buffer, backend=self._backend, replica=0) + + def ExecutePerReplica(self, arguments=None): + """Execute on many replicas with LocalBuffer arguments and return value. 
+ + Args: + arguments: A sequence of sequences of LocalBuffers. The i'th inner + sequence comprises the arguments for execution on the i'th replica. + + Returns: + A list of the computation's outputs on each replica, as a LocalBuffer. If + a shallow sequence of arguments was passed in for `arguments`, then the + sole, zero'th replica's output is returned instead, as a LocalBuffer. + """ + if not self._is_compiled: raise ValueError('Cannot execute an uncompiled local XLA computation.') - argument_shapes = [Shape.from_pyval(arg) for arg in arguments] - if layout_fn: - argument_shapes = [ - shape.map_leaves(layout_fn) for shape in argument_shapes - ] + if arguments is None: + arguments = ((),) * get_replica_count() + else: + arguments = [list(replica_args) for replica_args in arguments] + + # Check arguments + for replica, replica_args in enumerate(arguments): + for arg in replica_args: + if arg.is_deleted(): + raise ValueError('Executing with deleted local buffer argument') + if arg.replica() != replica: + raise ValueError( + 'Executing on replica {} with argument from replica {}'.format( + replica, arg.replica())) + + # Pull out argument buffer handles + stripped_args = [ + [arg.c_buffer for arg in replica_args] for replica_args in arguments + ] + + # Execute + if self._backend.backend_type == BackendType.XRT: + if len(stripped_args) > 1: + raise NotImplementedError( + 'Multi-replica execution is not yet supported via the XRT backend.') + output_buffers = [self._c_computation.Execute(stripped_args[0])] else: - argument_shapes = [None for shape in argument_shapes] - arguments = tuple(map(require_numpy_array_layout, arguments)) - return self.c_local_computation.Execute(arguments, argument_shapes) + output_buffer_tup = self._c_computation.ExecutePerReplica(stripped_args) + size = output_buffer_tup.size() + output_buffers = [output_buffer_tup.Release(i) for i in xrange(size)] - def ExecuteWithLocalBuffers(self, arguments=()): - """Execute with LocalBuffer arguments and 
return value.""" - if not self.is_compiled: - raise ValueError('Cannot execute an uncompiled local XLA computation.') - arguments = tuple(arguments) - if any(arg.is_deleted() for arg in arguments): - raise ValueError('Executing with deleted local buffer argument') - return LocalBuffer( - self.c_local_computation.ExecuteWithShapedBuffers( - [arg.c_local_shaped_buffer for arg in arguments])) + # Wrap output handles in LocalBuffer instances + return tuple( + LocalBuffer(output_buffer, backend=self._backend, replica=replica) + for replica, output_buffer in enumerate(output_buffers)) + + def ExecuteWithPythonValues(self, arguments=()): + """Execute on one replica with Python values as arguments and output.""" + + def put(arg): + return LocalBuffer.from_pyval(arg, backend=self._backend) + + arguments = [put(arg) for arg in arguments] + return self.Execute(arguments).to_py() + + def ExecuteWithPythonValuesPerReplica(self, arguments): + """Execute on many replicas with Python values as arguments and output.""" + + def put(arg, replica): + return LocalBuffer.from_pyval(arg, replica, backend=self._backend) + + arguments = [[put(arg, replica) + for arg in replica_args] + for replica, replica_args in enumerate(arguments)] + return [out.to_py() for out in self.ExecutePerReplica(arguments)] def __del__(self): - self._delete(self.c_local_computation) + self._delete(self._c_computation) class ComputationBuilder(object): @@ -554,8 +692,13 @@ class ComputationBuilder(object): self._client = c_api.LocalComputationBuilder(name.encode('utf8')) self._parameter_numbering = itertools.count() - def Build(self): - return LocalComputation(self._client.Build(), is_compiled=False) + def Build(self, root=None, backend=XLA_LOCAL_BACKEND): + if root is not None: + return LocalComputation( + self._client.BuildWithRoot(root), is_compiled=False, backend=backend) + else: + return LocalComputation( + self._client.Build(), is_compiled=False, backend=backend) def SetOpMetadata(self, op_metadata): """Set 
metadata for operations that are about to be enqueued.""" @@ -700,6 +843,20 @@ class ComputationBuilder(object): """ return self._client.Broadcast(operand, sizes) + def BroadcastInDim(self, operand, shape, broadcast_dimensions): + """Enqueues a broadcast-in-dimensions operation onto the computation. + + Args: + operand: the operand LocalOp to broadcast. + shape: tuple of integers, the expected output shape. + broadcast_dimensions: tuple of integers identifying which dimensions + of the output are to be broadcast into. + + Returns: + A LocalOp representing the added broadcast-in-dimensions op. + """ + return self._client.BroadcastInDim(operand, shape, broadcast_dimensions) + def Concatenate(self, operands, dimension): """Enqueues a concatenate operation onto the computation. @@ -834,8 +991,8 @@ class ComputationBuilder(object): padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) return self._client.SelectAndScatterWithGeneralPadding( - operand, select.c_local_computation, window_dimensions, window_strides, - pads, source, init_value, scatter.c_local_computation) + operand, select.computation, window_dimensions, window_strides, pads, + source, init_value, scatter.computation) def Select(self, pred, on_true, on_false): """Element-wise selection op. @@ -943,7 +1100,7 @@ class ComputationBuilder(object): Returns: A LocalOp representing the added call op. """ - return self._client.Call(computation_to_apply.c_local_computation, operands) + return self._client.Call(computation_to_apply.computation, operands) def Map(self, operands, computation_to_apply, dimensions): """Enqueues a map operation onto the computation. @@ -956,7 +1113,7 @@ class ComputationBuilder(object): Returns: A LocalOp representing the added Map op. 
""" - return self._client.Map(operands, computation_to_apply.c_local_computation, + return self._client.Map(operands, computation_to_apply.computation, dimensions) def Reduce(self, operand, init_value, computation_to_apply, dimensions): @@ -972,8 +1129,7 @@ class ComputationBuilder(object): A LocalOp representing the added Reduce op. """ return self._client.Reduce(operand, init_value, - computation_to_apply.c_local_computation, - dimensions) + computation_to_apply.computation, dimensions) def ReduceWindow(self, operand, init_value, computation_to_apply, window_dimensions, window_strides, padding): @@ -994,8 +1150,31 @@ class ComputationBuilder(object): padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) return self._client.ReduceWindowWithGeneralPadding( - operand, init_value, computation_to_apply.c_local_computation, - window_dimensions, window_strides, pads) + operand, init_value, computation_to_apply.computation, + window_dimensions, window_strides, (), (), pads) + + def ReduceWindowWithGeneralPadding( + self, operand, init_value, computation_to_apply, window_dimensions, + window_strides, base_dilations, window_dilations, padding): + """Enqueues a windowed reduction operation onto the computation. + + Args: + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). + computation_to_apply: a binary reduction function (Computation). + window_dimensions: dimensions of window (sequence of integers). + window_strides: strides for window (sequence of integers). + base_dilations: dilations for the base (sequence of integers). + window_dilations: dilations for window (sequence of integers). + padding: length-N array-like of pairs of integers of (low, high) padding. + + Returns: + A LocalOp representing the added ReduceWindow op. 
+ """ + return self._client.ReduceWindowWithGeneralPadding( + operand, init_value, computation_to_apply.computation, + window_dimensions, window_strides, base_dilations, window_dilations, + padding) def RngNormal(self, mu, sigma, dims): """Enqueues an RngNormal operation onto the computation. @@ -1039,8 +1218,7 @@ class ComputationBuilder(object): Returns: a LocalOp representing the While operation. """ - return self._client.While(cond.c_local_computation, - body.c_local_computation, init) + return self._client.While(cond.computation, body.computation, init) def Conditional(self, pred, true_operand, true_computation, false_operand, false_computation): @@ -1056,8 +1234,8 @@ class ComputationBuilder(object): Returns: a LocalOp representing the Conditional operation. """ return self._client.Conditional( - pred, true_operand, true_computation.c_local_computation, false_operand, - false_computation.c_local_computation) + pred, true_operand, true_computation.computation, false_operand, + false_computation.computation) def IsConstant(self, operand): """Checks whether the given operand is a compile-time constant. @@ -1124,10 +1302,9 @@ class ComputationBuilder(object): pads = _convert_padding_type_to_pad_values( padding, self.GetShape(lhs).dimensions()[2:], self.GetShape(rhs).dimensions()[2:], window_strides) - dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (), - (), dimension_numbers, - feature_group_count) + return self.ConvGeneralDilated( + lhs, rhs, window_strides, pads, (), (), + dimension_numbers=None, feature_group_count=feature_group_count) def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count=1): @@ -1145,11 +1322,9 @@ class ComputationBuilder(object): Returns: A ComputationdataHandle representing the added ConvWithGeneralPadding op. 
""" - dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, - dimension_numbers, - feature_group_count) + return self.ConvGeneralDilated( + lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, + dimension_numbers=None, feature_group_count=feature_group_count) def _GetConvDimensionNumbers(self, num_spatial_dims): """Create ConvolutionDimensionNumbers proto for convolutions.""" @@ -1167,7 +1342,7 @@ class ComputationBuilder(object): return dimension_numbers def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation, - rhs_dilation, dimension_numbers, + rhs_dilation, dimension_numbers=None, feature_group_count=1): """Enqueues a ConvGeneralDilated operation onto the computation. @@ -1178,10 +1353,11 @@ class ComputationBuilder(object): padding: length-N array-like of pairs of integers of (low, high) padding. lhs_dilation: length-N array-like of integer dilation factors. rhs_dilation: length-N array-like of integer dilation factors. 
- dimension_numbers: either an xla_data_pb2.ConvolutionDimensionNumbers or a - triple (lhs_spec, rhs_spec, out_spec) where each element is a string of - length N+2 identifying by position (1) batch dimensions in lhs, rhs, and - the output with the character 'N', (2) feature dimensions in lhs and the + dimension_numbers: optional, either an + xla_data_pb2.ConvolutionDimensionNumbers proto instance or a tuple + (lhs_spec, rhs_spec, out_spec) where each element is a string of length + N+2 identifying by position (1) batch dimensions in lhs, rhs, and the + output with the character 'N', (2) feature dimensions in lhs and the output with the character 'C', (3) input and output feature dimensions in rhs with the characters 'I' and 'O' respectively, and (4) spatial dimension correspondences between lhs, rhs, and the output using any @@ -1194,13 +1370,16 @@ class ComputationBuilder(object): spatial dimension character labels according to the order in which the labels appear in the rhs_spec string, so that window_strides[0] is matched with the dimension corresponding to the first character - appearing in rhs_spec that is not 'I' or 'O'. + appearing in rhs_spec that is not 'I' or 'O'. By default, use the same + dimension numbering as Conv and ConvWithGeneralPadding. feature_group_count: number of feature groups for grouped convolution. Returns: a LocalOp representing the ConvGenralDilated operation. 
""" - if not isinstance(dimension_numbers, - xla_data_pb2.ConvolutionDimensionNumbers): + if dimension_numbers is None: + dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) + elif not isinstance(dimension_numbers, + xla_data_pb2.ConvolutionDimensionNumbers): lhs_spec, rhs_spec, out_spec = dimension_numbers dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers() @@ -1285,6 +1464,19 @@ def initialize_replica_count(replica_count): c_api.InitializeReplicaCount(replica_count) +def initialize_platform_name(platform_name): + """Initializes the desired platform name to use on XLA service init. + + Args: + platform_name: string name of platform. + + Raises: + A runtime exception if the XLA service has already been initialized. + """ + platform_name = _maybe_encode_string(platform_name) + c_api.InitializePlatformName(platform_name) + + def get_replica_count(): """Returns the current replica count used for the XLA service. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 82103f03132e45ff822ce1ebcc2be47b24f5869f..21b5c93b615ec429a5da0b4ffe89e8f75f59ef1b 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -37,7 +37,7 @@ class LocalComputationTest(unittest.TestCase): def _Execute(self, c, arguments): compiled_c = c.Build().CompileWithExampleArguments(arguments) - return compiled_c.Execute(arguments) + return compiled_c.ExecuteWithPythonValues(arguments) def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): assert expected is not None @@ -355,7 +355,7 @@ class LocalBufferTest(LocalComputationTest): def _Execute(self, c, arguments): compiled_c = c.Build().CompileWithExampleArguments(arguments) arg_buffers = [xla_client.LocalBuffer.from_pyval(arg) for arg in arguments] - result_buffer = compiled_c.ExecuteWithLocalBuffers(arg_buffers) + result_buffer = compiled_c.Execute(arg_buffers) return 
result_buffer.to_py() def testConstantSum(self): @@ -388,7 +388,7 @@ class LocalBufferTest(LocalComputationTest): arg_buffer = xla_client.LocalBuffer.from_pyval(arg) arg_buffer.delete() with self.assertRaises(ValueError): - compiled_c.ExecuteWithLocalBuffers([arg_buffer]) + compiled_c.Execute([arg_buffer]) def testDestructureTupleEmpty(self): t = () @@ -439,6 +439,13 @@ class LocalBufferTest(LocalComputationTest): np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0]) np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1]) + def testShape(self): + pyval = np.array([[1., 2.]], np.float32) + local_buffer = xla_client.LocalBuffer.from_pyval(pyval) + xla_shape = local_buffer.shape() + self.assertEqual(xla_shape.dimensions(), (1, 2,)) + self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) + class SingleOpTest(LocalComputationTest): """Tests for single ops. @@ -478,7 +485,7 @@ class SingleOpTest(LocalComputationTest): x = c.Constant(np.array(template, dtype=src_dtype)) c.ConvertElementType(x, xla_types[dst_dtype]) - result = c.Build().Compile().Execute() + result = c.Build().Compile().ExecuteWithPythonValues() expected = np.array(template, dtype=dst_dtype) self.assertEqual(result.shape, expected.shape) @@ -505,7 +512,7 @@ class SingleOpTest(LocalComputationTest): x = c.Constant(np.array(template, dtype=src_dtype)) c.BitcastConvertType(x, dst_etype) - result = c.Build().Compile().Execute() + result = c.Build().Compile().ExecuteWithPythonValues() expected = np.array(template, src_dtype).view(dst_dtype) self.assertEqual(result.shape, expected.shape) @@ -987,7 +994,7 @@ class SingleOpTest(LocalComputationTest): c.Tuple( c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), c.Constant(NumpyArrayBool([True, False, False, True]))) - result = c.Build().Compile().Execute() + result = c.Build().Compile().ExecuteWithPythonValues() self.assertIsInstance(result, tuple) np.testing.assert_equal(result[0], 42) np.testing.assert_allclose(result[1], 
[1.0, 2.0]) @@ -1007,12 +1014,19 @@ class SingleOpTest(LocalComputationTest): self._ExecuteAndCompareExact( c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]) + def testBroadcastInDim(self): + c = self._NewComputation() + c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [0]) + self._ExecuteAndCompareExact(c, expected=[[1, 1], [2, 2]]) + c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [1]) + self._ExecuteAndCompareExact(c, expected=[[1, 2], [1, 2]]) + def testRngNormal(self): shape = (2, 3) c = self._NewComputation() c.RngNormal(c.Constant(NumpyArrayF32(0.)), c.Constant(NumpyArrayF32(1.)), dims=shape) - result = c.Build().Compile().Execute() + result = c.Build().Compile().ExecuteWithPythonValues() # since the result is random, we just check shape and uniqueness self.assertEqual(result.shape, shape) self.assertEqual(len(np.unique(result)), np.prod(shape)) @@ -1023,7 +1037,7 @@ class SingleOpTest(LocalComputationTest): c = self._NewComputation() c.RngUniform(c.Constant(NumpyArrayF32(lo)), c.Constant(NumpyArrayF32(hi)), dims=shape) - result = c.Build().Compile().Execute() + result = c.Build().Compile().ExecuteWithPythonValues() # since the result is random, we just check shape, uniqueness, and range self.assertEqual(result.shape, shape) self.assertEqual(len(np.unique(result)), np.prod(shape)) @@ -1036,7 +1050,7 @@ class SingleOpTest(LocalComputationTest): c = self._NewComputation() c.RngUniform(c.Constant(NumpyArrayS32(lo)), c.Constant(NumpyArrayS32(hi)), dims=shape) - result = c.Build().Compile().Execute() + result = c.Build().Compile().ExecuteWithPythonValues() # since the result is random, we just check shape, integrality, and range self.assertEqual(result.shape, shape) self.assertEqual(result.dtype, np.int32) @@ -1473,7 +1487,7 @@ class EmbeddedComputationsTest(LocalComputationTest): xla_client.transfer_to_infeed(item) for item in to_infeed: - result = compiled_c.Execute() + result = compiled_c.ExecuteWithPythonValues() 
self.assertEqual(result, item) def testInfeedThenOutfeedS32(self): @@ -1511,5 +1525,20 @@ class ErrorTest(LocalComputationTest): lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2])) +class ComputationRootTest(LocalComputationTest): + """Tests related to setting the root of the computation.""" + + def testComputationRootDifferentFromLastOp(self): + c = self._NewComputation() + x = c.ParameterFromNumpy(NumpyArrayF32(2.0)) + result = c.Add(x, c.ConstantF32Scalar(3.14)) + extra = c.Add(result, c.ConstantF32Scalar(1.618)) # pylint: disable=unused-variable + + arg = NumpyArrayF32(1.0) + compiled_c = c.Build(result).CompileWithExampleArguments([arg]) + ans = compiled_c.ExecuteWithPythonValues([arg]) + np.testing.assert_allclose(ans, 4.14) + + if __name__ == "__main__": unittest.main() diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py index f158f6b2410352432445f669155aff0af5526abf..95b2bf300ec67e9f034f77450416544cb088ae55 100644 --- a/tensorflow/compiler/xla/python_api/xla_shape.py +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -25,9 +25,10 @@ from tensorflow.compiler.xla.python_api import types class Shape(object): - """Wraps a xla_data_pb2.Shape message with a convenient Python type. + """Wraps a xla_data_pb2.ShapeProto message with a convenient Python type. - Provides direct access to the underlying xla_data_pb2.Shape message in the + Provides direct access to the underlying xla_data_pb2.ShapeProto message in + the message attribute, along with accessor wrappers to the message's fields. Avoid direct access to .message unless interacting directly with protobuf APIs like CopyFrom. In other words, prefer hauling the shape around in a Shape, and @@ -48,7 +49,7 @@ class Shape(object): Raises: ValueError: if element_type is TUPLE but dimensions are not Shape objects. 
""" - self.message = xla_data_pb2.Shape() + self.message = xla_data_pb2.ShapeProto() self.message.element_type = element_type if element_type == xla_data_pb2.TUPLE: if not all(isinstance(subshape, Shape) for subshape in dimensions): diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 3abb3855a42b8b5222115262448d359da3a80e87..26affbcceb33110baf41d507173e56f8b1c8c9eb 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -16,7 +16,6 @@ xla_proto_library( use_grpc_plugin = True, visibility = ["//visibility:public"], deps = [ - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", ], ) diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index 4e1435fa30a24c320ddbedb84d37b369a3158a54..d8123a6de28ca532819ece4a75cd0b725f8c1bbd 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -47,11 +47,18 @@ namespace xla { }); } -::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/, - const ExecuteGraphRequest* arg, - ExecuteResponse* result) { +::grpc::Status GRPCService::Compile(::grpc::ServerContext* /*context*/, + const CompileRequest* arg, + CompileResponse* result) { return DelegateRPC( - [this, arg, result]() { return service_->ExecuteGraph(arg, result); }); + [this, arg, result]() { return service_->Compile(arg, result); }); +} + +::grpc::Status GRPCService::Execute(::grpc::ServerContext* /*context*/, + const ExecuteRequest* arg, + ExecuteResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->Execute(arg, result); }); } ::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index ca1b09b648013ad45d806040c5ddcf11d9e5604e..3e586b288a56a22573d0c3b9ae7b2f25fdbf851a 100644 --- 
a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -39,9 +39,13 @@ class GRPCService : public grpc::XlaService::Service { const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - ::grpc::Status ExecuteGraph(::grpc::ServerContext* context, - const ExecuteGraphRequest* arg, - ExecuteResponse* result) override; + ::grpc::Status Compile(::grpc::ServerContext* context, + const CompileRequest* arg, + CompileResponse* result) override; + + ::grpc::Status Execute(::grpc::ServerContext* context, + const ExecuteRequest* arg, + ExecuteResponse* result) override; ::grpc::Status WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc index 7b8ab158e1396d7087a407be180ab44d2e16e121..66abf66cfd6c2f753c5507aa373452ac880e9a29 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.cc +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -62,10 +62,17 @@ Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, }); } -Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) { +Status GRPCStub::Compile(const CompileRequest* request, + CompileResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteGraph(context, *request, response); + return grpc_stub_->Compile(context, *request, response); + }); +} + +Status GRPCStub::Execute(const ExecuteRequest* request, + ExecuteResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->Execute(context, *request, response); }); } diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h index 8dfcb761387d608abbb1f62974f49b976a7ff7ff..f02b401399f3e895153f0b08e325bc9c2c2336ec 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.h +++ 
b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -43,8 +43,11 @@ class GRPCStub : public ServiceInterface { Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - Status ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) override; + Status Compile(const CompileRequest* request, + CompileResponse* response) override; + + Status Execute(const ExecuteRequest* request, + ExecuteResponse* response) override; Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) override; diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto index 551ae895e05586daec0ffcd425f4950f76bdd50d..0ff8adc2acbe5fd21e85027dd63bfb14f5672a7d 100644 --- a/tensorflow/compiler/xla/rpc/xla_service.proto +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -43,7 +43,6 @@ limitations under the License. syntax = "proto3"; import "tensorflow/compiler/xla/xla.proto"; -import "tensorflow/compiler/xla/xla_data.proto"; package xla; @@ -128,11 +127,14 @@ service XlaService { returns (CreateChannelHandleResponse) { } - // Invokes the provided computation with the provided global data passed as - // immutable arguments. The request contains the whole computation graph. + // Compiles the provided computation into executable. Returns the handle of + // the executable. + rpc Compile(CompileRequest) returns (CompileResponse) {} + + // Invokes the provided executable with the provided global data passed as + // immutable arguments. The request contains the handle to the executable. // Returns global data output and execution timing. - rpc ExecuteGraph(ExecuteGraphRequest) returns (ExecuteResponse) { - } + rpc Execute(ExecuteRequest) returns (ExecuteResponse) {} // Invokes the provided list of computations in parallel with the provided // global data for each computation. 
Returns a list of global data output and diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 4797cf333070f2dd371e81c01ad659151cbc216d..4c21ae2a427477caa86fb4130616c38eb3bcf006 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -87,7 +87,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -124,7 +123,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -158,12 +156,12 @@ tf_cc_test( ":bfloat16_propagation", ":bfloat16_support", ":hlo", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], @@ -253,6 +251,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", @@ -280,12 +279,14 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_element_type_converter", - 
"//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -293,7 +294,9 @@ cc_library( name = "hlo", srcs = [ "dfs_hlo_visitor.cc", + "dynamic_parameter_binding.cc", "hlo_computation.cc", + "hlo_input_output_alias_config.cc", "hlo_instruction.cc", "hlo_instructions.cc", "hlo_module.cc", @@ -305,9 +308,11 @@ cc_library( hdrs = [ "dfs_hlo_visitor.h", "dfs_hlo_visitor_with_default.h", + "dynamic_parameter_binding.h", "hlo_clone_context.h", "hlo_computation.h", "hlo_domain_metadata.h", + "hlo_input_output_alias_config.h", "hlo_instruction.h", "hlo_instructions.h", "hlo_module.h", @@ -320,7 +325,6 @@ cc_library( ":hlo_casting_utils", ":hlo_module_config", ":hlo_proto", - ":hlo_reachability", ":name_uniquer", "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:literal", @@ -350,6 +354,25 @@ cc_library( ], ) +tf_cc_test( + name = "dynamic_parameter_binding_test", + srcs = ["dynamic_parameter_binding_test.cc"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_memory_scheduler", + ":hlo_ordering", + ":hlo_parser", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + ], +) + tf_cc_test( name = "dfs_hlo_visitor_with_default_test", srcs = ["dfs_hlo_visitor_with_default_test.cc"], @@ -362,7 +385,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - 
"//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -388,9 +410,36 @@ tf_cc_test( ":hlo", ":pattern_matcher", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "pattern_matcher_gmock", + testonly = 1, + hdrs = ["pattern_matcher_gmock.h"], + deps = [ + ":pattern_matcher", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:test", + ], +) + +tf_cc_test( + name = "pattern_matcher_gmock_test", + srcs = ["pattern_matcher_gmock_test.cc"], + deps = [ + ":hlo", + ":pattern_matcher", + ":pattern_matcher_gmock", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -399,10 +448,12 @@ cc_library( srcs = ["hlo_reachability.cc"], hdrs = ["hlo_reachability.h"], deps = [ + ":hlo", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", ], @@ -417,7 +468,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -463,7 +513,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -516,7 +565,6 @@ tf_cc_test( 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -565,7 +613,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -588,7 +635,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -600,11 +646,11 @@ cc_library( hdrs = ["platform_util.h"], deps = [ ":compiler", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", @@ -644,6 +690,7 @@ cc_library( ":allocation_tracker", ":backend", ":channel_tracker", + ":compilation_cache", ":compiler", ":computation_layout", ":device_memory_allocator", @@ -659,6 +706,7 @@ cc_library( ":source_map_util", ":stream_pool", ":transfer_manager", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:service_interface", @@ -670,7 +718,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - 
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", @@ -727,12 +774,12 @@ cc_library( ":computation_layout", ":platform_util", ":service", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", @@ -808,6 +855,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:ptr_util", + "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", "@com_google_absl//absl/memory", ], @@ -830,6 +878,7 @@ cc_library( ":maybe_owning_device_memory", ":shaped_buffer", ":stream_pool", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:status", @@ -837,7 +886,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", @@ -858,6 +906,7 @@ cc_library( ":executable", ":hlo", ":hlo_module_config", + ":hlo_module_group", ":logical_buffer", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1082,7 +1131,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1099,6 +1147,7 @@ cc_library( ":hlo", ":hlo_dataflow_analysis", ":hlo_proto", + ":hlo_reachability", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1164,7 +1213,6 @@ tf_cc_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1268,6 +1316,25 @@ tf_cc_test( ], ) +tf_cc_test( + name = "hlo_input_output_alias_config_test", + srcs = ["hlo_input_output_alias_config_test.cc"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_memory_scheduler", + ":hlo_ordering", + ":hlo_parser", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + ], +) + cc_library( name = "hlo_memory_scheduler", srcs = ["hlo_memory_scheduler.cc"], @@ -1320,6 +1387,7 @@ cc_library( ":hlo", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -1339,6 +1407,7 @@ cc_library( ":fusion_queue", ":hlo", ":hlo_pass", + ":hlo_reachability", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", @@ -1364,6 +1433,7 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ + ":hlo_reachability", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:hlo", @@ -1404,7 +1474,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", 
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/memory", @@ -1480,7 +1549,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -1523,7 +1591,10 @@ tf_cc_test( ":hlo", ":hlo_casting_utils", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1532,7 +1603,6 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1569,7 +1639,6 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1619,7 +1688,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1671,6 +1740,19 @@ cc_library( ], ) +tf_cc_test( + name = "while_loop_analysis_test", + srcs = 
["while_loop_analysis_test.cc"], + deps = [ + ":while_loop_analysis", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "while_loop_simplifier", srcs = ["while_loop_simplifier.cc"], @@ -1679,9 +1761,11 @@ cc_library( ":call_inliner", ":hlo", ":hlo_pass", + ":hlo_query", + ":pattern_matcher", ":while_loop_analysis", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", - "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", @@ -1693,10 +1777,17 @@ tf_cc_test( name = "while_loop_simplifier_test", srcs = ["while_loop_simplifier_test.cc"], deps = [ + ":algebraic_simplifier", + ":hlo", + ":hlo_cse", + ":hlo_dce", ":hlo_matchers", + ":hlo_pass", + ":hlo_pass_pipeline", + ":tuple_simplifier", ":while_loop_simplifier", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/strings", @@ -1727,7 +1818,7 @@ tf_cc_test( ":hlo_matchers", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) @@ -1755,7 +1846,7 @@ tf_cc_test( ":implicit_broadcast_remover", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) @@ -1800,7 +1891,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", 
"//tensorflow/core:test", ], ) @@ -1820,6 +1910,41 @@ cc_library( ], ) +cc_library( + name = "dynamic_dimension_inference", + srcs = ["dynamic_dimension_inference.cc"], + hdrs = ["dynamic_dimension_inference.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "dynamic_dimension_inference_test", + srcs = ["dynamic_dimension_inference_test.cc"], + deps = [ + ":dynamic_dimension_inference", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "reshape_mover_test", srcs = ["reshape_mover_test.cc"], @@ -1834,7 +1959,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -1977,7 +2102,8 @@ tf_cc_test( srcs = ["hlo_computation_test.cc"], deps = [ ":hlo", - ":hlo_matchers", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -2240,7 +2366,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2303,13 +2428,27 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", ], ) +cc_library( + name = "compilation_cache", + srcs = ["compilation_cache.cc"], + hdrs = ["compilation_cache.h"], + deps = [ + ":executable", + ":hlo_module_config", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + cc_library( name = "layout_assignment", srcs = [ @@ -2379,14 +2518,13 @@ tf_cc_test( ":hlo_graph_dumper", ":hlo_matchers", ":hlo_runner", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -2450,6 +2588,7 @@ tf_cc_test( ":hlo", ":hlo_parser", ":hlo_verifier", + ":layout_assignment", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -2503,7 +2642,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -2559,8 +2697,9 @@ tf_cc_test( ":algebraic_simplifier", ":computation_layout", ":hlo", - ":hlo_matchers", ":layout_assignment", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -2570,8 +2709,8 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/types:span", @@ -2632,7 +2771,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2673,7 +2812,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -2707,12 +2845,13 @@ tf_cc_test( ":hlo_matchers", ":hlo_parser", ":hlo_pass", + ":pattern_matcher", + ":pattern_matcher_gmock", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", 
"//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -2784,10 +2923,9 @@ tf_cc_test( ":hlo_domain_isolator", ":hlo_domain_remover", ":hlo_parser", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/memory", @@ -2820,6 +2958,46 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_get_dimension_size_rewriter", + srcs = ["hlo_get_dimension_size_rewriter.cc"], + hdrs = ["hlo_get_dimension_size_rewriter.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":shape_inference", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + ], +) + +tf_cc_test( + name = "hlo_get_dimension_size_rewriter_test", + srcs = ["hlo_get_dimension_size_rewriter_test.cc"], + deps = [ + ":hlo", + ":hlo_get_dimension_size_rewriter", + ":hlo_matchers", + ":hlo_parser", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "device_memory_allocator", srcs = [ @@ -2878,6 +3056,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", 
"@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:transform_utils", @@ -2975,7 +3154,6 @@ tf_cc_test( deps = [ ":hlo_tfgraph_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", ], @@ -2992,6 +3170,7 @@ cc_library( ":hlo_casting_utils", ":hlo_execution_profile", ":hlo_tfgraph_builder", + ":pattern_matcher", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -3126,6 +3305,7 @@ cc_library( ":buffer_assignment", ":hlo", ":hlo_proto", + ":hlo_verifier", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:util", ], @@ -3188,6 +3368,7 @@ cc_library( ":computation_placer", ":executable", ":hlo", + ":hlo_module_group", ":transfer_manager", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -3215,6 +3396,7 @@ cc_library( ":hlo_profile_printer_data", ":human_readable_profile_builder", "//tensorflow/compiler/xla:types", + "@com_google_absl//absl/strings", ], ) @@ -3251,6 +3433,8 @@ cc_library( ":tuple_util", "//tensorflow/compiler/xla:literal_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", ], ) @@ -3277,10 +3461,11 @@ cc_library( ":hlo", ":hlo_pass", ":tuple_util", + ":while_loop_analysis", ":while_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -3296,7 +3481,7 @@ tf_cc_test( ":while_loop_invariant_code_motion", 
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], ) @@ -3326,7 +3511,7 @@ tf_cc_test( ":while_loop_constant_sinking", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", ], ) @@ -3339,6 +3524,7 @@ cc_library( ":bfloat16_normalization", ":defuser", ":hlo", + ":hlo_memory_scheduler", ":hlo_pass", ":hlo_pass_pipeline", ":implicit_broadcast_remover", @@ -3386,7 +3572,7 @@ tf_cc_test( ":indexed_array_analysis", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:test", ], @@ -3422,6 +3608,9 @@ tf_cc_test( ":hlo_casting_utils", ":hlo_matchers", ":hlo_parser", + ":pattern_matcher", + ":pattern_matcher_gmock", + "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:window_util", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -3471,6 +3660,41 @@ cc_library( ], ) +cc_library( + name = "ar_crs_combiner", + srcs = ["ar_crs_combiner.cc"], + hdrs = ["ar_crs_combiner.h"], + deps = [ + ":call_graph", + ":pattern_matcher", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "ar_crs_combiner_test", + srcs = 
["ar_crs_combiner_test.cc"], + deps = [ + ":ar_crs_combiner", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "map_inliner_test", srcs = ["map_inliner_test.cc"], @@ -3482,7 +3706,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/memory", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 75dae7a7141647d7b7b60b0e07e11c143621ea63..985c5af1c4d89425dd6693585e42e22510fe21f8 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include +#include +#include #include #include #include @@ -67,6 +69,45 @@ bool IsAll(const HloInstruction* op, int8 value) { } } +// Checks whether `op` is a floating-point constant or broadcast of a constant +// of the form +/- 2^k for some integer k positive, negative, or zero. Such +// values are interesting because multiplying by a power of 2 just moves the +// exponent. +bool IsAllFpConstantPowerOf2(const HloInstruction* op) { + // Unwrap the broadcast if necessary. 
+ const HloInstruction* c; + if (!Match(op, m::ConstantEffectiveScalar(&c)) && + !Match(op, m::Broadcast(m::Constant(&c).WithShape( + m::Shape().IsEffectiveScalar())))) { + return false; + } + auto val = [&]() -> absl::optional { + switch (c->shape().element_type()) { + case BF16: + return static_cast(c->literal().GetFirstElement()); + case F16: + return static_cast(c->literal().GetFirstElement()); + case F32: + return c->literal().GetFirstElement(); + case F64: + return c->literal().GetFirstElement(); + default: + // Cowardly refuse to consider complex types. + return absl::nullopt; + } + }(); + if (!val) { + return false; + } + + int exp; + double mantissa = std::frexp(*val, &exp); + // frexp returns a value in the range (-1; -0.5] U [0.5, 1). A return value + // of +/-0.5 therefore indicates that the floating point value is a power of + // 2. + return mantissa == 0.5 || mantissa == -0.5; +} + // Returns whether the given transpose produces a result which is bit-wise // identical to its operand and thus may be replaced with a bitcast. bool TransposeIsBitcast(const HloInstruction* transpose) { @@ -83,7 +124,8 @@ bool TransposeIsBitcast(const HloInstruction* transpose) { // reshape may still be a bitcast. For example, a reshape from [28x28] to [784]. bool ReshapeOrCopyIsBitcast( const HloInstruction* instr, - const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) { + const AlgebraicSimplifierOptions::ValidBitcastCallback& + valid_bitcast_callback) { CHECK(HloOpcode::kReshape == instr->opcode() || HloOpcode::kCopy == instr->opcode()); @@ -94,6 +136,11 @@ bool ReshapeOrCopyIsBitcast( valid_bitcast_callback(operand->shape(), instr->shape()); } +bool IsUnstridedSlice(const HloInstruction* hlo) { + return absl::c_all_of(hlo->slice_strides(), + [](int64 stride) { return stride == 1; }); +} + // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain // algebraic expressions to simplified forms. 
Note: This only supports // simplifications that simply look at the operands of an instruction. For the @@ -107,6 +154,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleAdd(HloInstruction* add) override; + Status HandleAnd(HloInstruction* logical_and) override; + Status HandleBitcast(HloInstruction* bitcast) override; Status HandleBitcastConvert(HloInstruction* bitcast) override; @@ -141,6 +190,12 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleMultiply(HloInstruction* multiply) override; + Status HandleNegate(HloInstruction* negate) override; + + Status HandleNot(HloInstruction* logical_not) override; + + Status HandleOr(HloInstruction* logical_or) override; + Status HandlePad(HloInstruction* pad) override; Status HandlePower(HloInstruction* power) override; @@ -157,6 +212,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) override; + Status HandleSelect(HloInstruction* select) override; + Status HandleSort(HloInstruction* sort) override; Status HandleTranspose(HloInstruction* transpose) override; @@ -169,21 +226,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { const bool changed() const { return changed_; } // Runs the visitor on a computation. 
- static bool Run( - HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction, bool enable_conv_simplification); + static bool Run(HloComputation* computation, + const AlgebraicSimplifierOptions& options); private: - explicit AlgebraicSimplifierVisitor( - HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction, bool enable_conv_simplification) - : computation_(computation), - is_layout_sensitive_(is_layout_sensitive), - valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_strength_reduction_(enable_dot_strength_reduction), - enable_conv_simplification_(enable_conv_simplification) {} + explicit AlgebraicSimplifierVisitor(HloComputation* computation, + const AlgebraicSimplifierOptions& options) + : computation_(computation), options_(options) {} // Transforms Dots where at least one input is a vector or has a degenerate // dimension and converts it into a multiply and reduce. This should enable @@ -222,10 +271,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* new_instruction); // Returns whether the shape of the output of the given instructions are the - // same for the purposes of simplification. If is_layout_sensitive_ is true, - // then this tests shape equality including layout (ShapeUtil::Equal). If - // is_layout_sensitive_ is false, then the tests shape compatibility - // (ShapeUtil::Compatible). + // same for the purposes of simplification. If options_.is_layout_sensitive() + // is true, then this tests shape equality including layout + // (ShapeUtil::Equal). If options_.is_layout_sensitive() is false, then the + // tests shape compatibility (ShapeUtil::Compatible). 
bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const; // Returns whether it was possible to transform `root` to a clamp instruction. @@ -304,26 +353,22 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Tries to use a kDot in place of the given convolution. StatusOr SimplifyConvToDot(HloInstruction* convolution); + // Tries to simplify a slice where the result of the slice is a scalar. + StatusOr TrySimplifyScalarSlice(HloInstruction* slice); + + // Tries to convert slice(reshape(X)) into reshape(slice(X)) + StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; + // The backend-specific options selected for the algebraic simplifier. + const AlgebraicSimplifierOptions& options_; + // Whether algebraic simplification has occurred. bool changed_ = false; - // Whether layout is considered during transformation. - bool is_layout_sensitive_; - - // Callback used to determine if a bitcast is possible. - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_; - - // Disable dot strength reduction on platforms where it causes a slowdown. - bool enable_dot_strength_reduction_; - - // Disable convolution -> dot simplification on platforms where it causes a - // slowdown. - bool enable_conv_simplification_; - // Cached computation for adding two scalar F32. 
HloComputation* scalar_add_computation_ = nullptr; }; @@ -331,19 +376,15 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { } // namespace bool AlgebraicSimplifierVisitor::Run( - HloComputation* computation, bool is_layout_sensitive, - AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction, bool enable_conv_simplification) { - AlgebraicSimplifierVisitor visitor( - computation, is_layout_sensitive, std::move(valid_bitcast_callback), - enable_dot_strength_reduction, enable_conv_simplification); + HloComputation* computation, const AlgebraicSimplifierOptions& options) { + AlgebraicSimplifierVisitor visitor(computation, options); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const { - if (is_layout_sensitive_) { + if (options_.is_layout_sensitive()) { return ShapeUtil::Equal(lhs->shape(), rhs->shape()); } else { return ShapeUtil::Compatible(lhs->shape(), rhs->shape()); @@ -414,6 +455,77 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { sum_of_constants)); } + // A*C + B*C => (A+B)*C + // + // - If A, B, and C are integers, do this unconditionally. Proof of + // correctness: https://rise4fun.com/Alive/u9X. + // + // - If A, B, and C are floating point, do this if C is a scalar constant or + // broadcast of scalar constant and is equal to +/- 2^k for some (possibly + // negative) integer k. + // + // Multiplying by a power of 2 just moves the exponent, so our answer is + // exact modulo rounding of intermediate results so long as + // + // - none of the three products has an exponent which underflows (so the + // result is 0 or denormal), and + // - none of the three products overflows to inf. + // + // Proof: See algebraic_simplifier_proof_distributive_property.py. 
+ // + // We deem these differences in rounding, underflow, and overflow + // acceptable in the ML context. + HloInstruction *b, *c; + if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) && + Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) || + (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) && + Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) && + (ShapeUtil::ElementIsIntegral(add->shape()) || + IsAllFpConstantPowerOf2(c))) { + return ReplaceWithNewInstruction( + add, HloInstruction::CreateBinary( + add->shape(), HloOpcode::kMultiply, + computation_->AddInstruction(HloInstruction::CreateBinary( + add->shape(), HloOpcode::kAdd, a, b)), + c)); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { + HloInstruction *lhs, *rhs; + CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs)))); + // Simplify logical and + if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) && + ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) { + // A && True => A + VLOG(10) << "trying transform [A && True => A]: " + << logical_and->ToString(); + if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(logical_and, lhs)) { + return Status::OK(); + } + // True && A => A + VLOG(10) << "trying transform [True && A => A]: " + << logical_and->ToString(); + if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_and, rhs)) { + return Status::OK(); + } + + // A && False => False + VLOG(10) << "trying transform [A && False => False]: " + << logical_and->ToString(); + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_and, rhs)) { + return Status::OK(); + } + + // False && A => False + VLOG(10) << "trying transform [False && A => False]: " + << logical_and->ToString(); + if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_and, lhs)) { + return Status::OK(); + } + } + return Status::OK(); } @@ -450,8 +562,8 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { return Status::OK(); } - 
if (is_layout_sensitive_ && - ReshapeOrCopyIsBitcast(copy, valid_bitcast_callback_)) { + if (options_.is_layout_sensitive() && + ReshapeOrCopyIsBitcast(copy, options_.valid_bitcast_callback())) { ReplaceWithBitcast(copy); } @@ -487,7 +599,74 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( VLOG(10) << "trying to replace " << concatenate->ToString() << " with " << replacement->ToString(); ReplaceInstructionIfSameShape(concatenate, replacement); - } else if (operands.size() == 2) { + return Status::OK(); + } + + // Check if we can merge "adjacent" slice operands which take slices from the + // same other op. For simplicity we only merge unstrided slices. + int64 concatenate_dimension = concatenate->concatenate_dimension(); + for (int64 i = 0; i < operands.size(); ++i) { + if (operands[i]->opcode() != HloOpcode::kSlice || + !IsUnstridedSlice(operands[i])) { + continue; + } + int64 slice_end = operands[i]->slice_limits(concatenate_dimension); + HloInstruction* slice_operand = operands[i]->mutable_operand(0); + int64 j = i + 1; + while (j < operands.size() && operands[j]->opcode() == HloOpcode::kSlice && + IsUnstridedSlice(operands[j]) && + operands[j]->operand(0) == slice_operand && + operands[j]->slice_starts(concatenate_dimension) == slice_end) { + // Check that all the slice_start values are the same in all other + // dimensions. This implies that the slice_limit values are also the same, + // because operands of concatenate need to have the same shape, and we + // already checked that the slices are unstrided. 
+ bool same_other_starts = true; + for (int64 k = 0; k < operands[j]->slice_starts().size(); ++k) { + if (k == concatenate_dimension) { + continue; + } + if (operands[i]->slice_starts(k) != operands[j]->slice_starts(k)) { + same_other_starts = false; + break; + } + } + if (!same_other_starts) { + break; + } + slice_end = operands[j]->slice_limits(concatenate_dimension); + ++j; + } + if (j - i > 1) { + Shape new_slice_shape = operands[i]->shape(); + new_slice_shape.set_dimensions( + concatenate_dimension, + slice_end - operands[i]->slice_starts(concatenate_dimension)); + auto new_limit_indices = operands[i]->slice_limits(); + new_limit_indices[concatenate_dimension] = slice_end; + auto new_slice_op = + computation_->AddInstruction(HloInstruction::CreateSlice( + new_slice_shape, slice_operand, + /*start_indices=*/operands[i]->slice_starts(), + /*limit_indices=*/new_limit_indices, + /*strides=*/operands[i]->slice_strides())); + std::vector new_operands; + for (int64 k = 0; k < i; ++k) { + new_operands.push_back(operands[k]); + } + new_operands.push_back(new_slice_op); + for (int64 k = j; k < operands.size(); ++k) { + new_operands.push_back(operands[k]); + } + auto replacement = + computation_->AddInstruction(concatenate->CloneWithNewOperands( + concatenate->shape(), new_operands)); + ReplaceInstructionIfSameShape(concatenate, replacement); + return Status::OK(); + } + } + + if (operands.size() == 2) { // A binary concat with a broadcasted scalar as an operand can be converted // into a pad which is simpler to fold into other operations. 
bool is_effective_low_pad = Match( @@ -503,7 +682,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( padding_config_dim->set_edge_padding_high(0); padding_config_dim->set_edge_padding_low(0); padding_config_dim->set_interior_padding(0); - if (dim == concatenate->concatenate_dimension()) { + if (dim == concatenate_dimension) { if (is_effective_low_pad) { padding_config_dim->set_edge_padding_low( operands[0]->shape().dimensions(dim)); @@ -1161,7 +1340,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceInstruction(dot, dot_of_gather_optimized); } - if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { + if (options_.enable_dot_strength_reduction() && + !options_.is_layout_sensitive()) { TF_ASSIGN_OR_RETURN(bool did_strength_reduction, HandleDotStrengthReduction(dot)); if (did_strength_reduction) { @@ -1223,6 +1403,64 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) { + // negate(negate(x)) => x + HloInstruction* x; + if (Match(negate, m::Negate(m::Negate(m::Op(&x)))) && + ReplaceInstructionIfSameShape(negate, x)) { + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleNot(HloInstruction* logical_not) { + // not(not(x)) => x + HloInstruction* x; + if (Match(logical_not, m::Not(m::Not(m::Op(&x)))) && + ReplaceInstructionIfSameShape(logical_not, x)) { + return Status::OK(); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) { + HloInstruction *lhs, *rhs; + CHECK(Match(logical_or, m::Or(m::Op(&lhs), m::Op(&rhs)))); + + // Simplify logical or + if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) && + ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) { + // A || True => True + VLOG(10) << "trying transform [A || True => True]: " + << logical_or->ToString(); + if (IsAll(rhs, 1) && 
ReplaceInstructionIfSameShape(logical_or, rhs)) { + return Status::OK(); + } + // True || A => True + VLOG(10) << "trying transform [True || A => True]: " + << logical_or->ToString(); + if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(logical_or, lhs)) { + return Status::OK(); + } + + // A || False => A + VLOG(10) << "trying transform [A || False => A]: " + << logical_or->ToString(); + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(logical_or, lhs)) { + return Status::OK(); + } + + // False || A => A + VLOG(10) << "trying transform [False || A => A]: " + << logical_or->ToString(); + if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(logical_or, rhs)) { + return Status::OK(); + } + } + + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) { // ln(exp(A)) => A VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString(); @@ -1507,6 +1745,27 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { pad, HloInstruction::CreateBroadcast(pad->shape(), pad->mutable_operand(1), {})); } + + // Interior padding on dimensions of size one has no effect. Removing it + // makes other simplifications possible that require no interior padding. + if (HasInteriorPadding(pad->padding_config())) { + PaddingConfig padding_config = pad->padding_config(); + bool cleared_interior_padding = false; + for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + if (padding_config.dimensions(i).interior_padding() > 0 && + pad->operand(0)->shape().dimensions(i) == 1) { + cleared_interior_padding = true; + padding_config.mutable_dimensions(i)->set_interior_padding(0); + } + } + if (cleared_interior_padding) { + return ReplaceWithNewInstruction( + pad, + HloInstruction::CreatePad(pad->shape(), pad->mutable_operand(0), + pad->mutable_operand(1), padding_config)); + } + } + // Eliminate nop pads (padding all zero), and replace a pad with negative // padding with a pad with non-negative padding followed by a slice.
bool all_zero = true; @@ -1798,8 +2057,8 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { } // Make this a bitcast if possible. - if (is_layout_sensitive_ && - ReshapeOrCopyIsBitcast(reshape, valid_bitcast_callback_)) { + if (options_.is_layout_sensitive() && + ReshapeOrCopyIsBitcast(reshape, options_.valid_bitcast_callback())) { ReplaceWithBitcast(reshape); return Status::OK(); } @@ -1820,18 +2079,165 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { return Status::OK(); } +StatusOr AlgebraicSimplifierVisitor::TrySimplifyScalarSlice( + HloInstruction* slice) { + // Only try to do this for effective scalars. We could do the same for slicing + // out larger pieces of padding (replacing with a broadcast of the padding + // value), but this is probably not worth it. + if (!ShapeUtil::IsEffectiveScalar(slice->shape())) { + return false; + } + + if (slice->operand(0)->opcode() == HloOpcode::kPad) { + VLOG(10) << "Trying to simplify scalar slice of pad"; + // Check there's no internal padding. Again, we could handle that too, since + // everything is statically known, but it's not worth it. + auto pad = Cast(slice->mutable_operand(0)); + auto padding_config = pad->padding_config(); + int64 rank = padding_config.dimensions_size(); + if (HasInteriorPadding(padding_config)) { + VLOG(10) << "Not folding scalar slice of pad, pad has interior padding"; + return false; + } + + // Check whether the scalar we're slicing out falls into the padding. 
+ bool in_padding = [&]() { + for (int64 i = 0; i < rank; ++i) { + int64 start = slice->slice_starts(i); + int64 low = padding_config.dimensions(i).edge_padding_low(); + int64 data = pad->operand(0)->shape().dimensions(i); + if (start >= low && start < low + data) { + return false; + } + } + return true; + }(); + + if (in_padding) { + VLOG(10) << "Folding scalar slice of pad into padding value"; + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + slice, HloInstruction::CreateReshape(slice->shape(), + pad->mutable_padding_value()))); + return true; + } else { + // We already know the output of the slice is scalar. If the padded + // value is scalar, and it's not in the padding, then it's exactly the + // output value. + bool replaced = + ReplaceInstructionIfSameShape(slice, pad->mutable_operand(0)); + if (replaced) { + VLOG(10) << "Folding scalar slice of pad into padded value"; + } else { + VLOG(10) << "Not folding scalar slice of pad into padded value as they " + "have different shapes."; + } + return replaced; + } + } + + if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) { + VLOG(10) << "Trying to simplify scalar slice of concat"; + // Only do this for R1, there's no chance of this being useful otherwise. + if (ShapeUtil::Rank(slice->shape()) != 1) { + VLOG(10) << "Not folding, slice is not rank 1"; + return false; + } + HloConcatenateInstruction* concat = + Cast(slice->mutable_operand(0)); + int64 operand_start = 0; + int64 operand_num = 0; + // Weird loop structure to avoid annoying off-by-one errors. 
+ while (true) { + TF_RET_CHECK(operand_num < concat->operand_count()); + const HloInstruction* operand = concat->operand(operand_num); + int64 next_operand_start = operand_start + operand->shape().dimensions(0); + if (next_operand_start > slice->slice_starts(0)) { + break; + } + operand_start = next_operand_start; + operand_num++; + } + + bool replaced = ReplaceInstructionIfSameShape( + slice, concat->mutable_operand(operand_num)); + if (replaced) { + VLOG(10) << "Folding scalar slice of concat into concat operand"; + } else { + VLOG(10) << "Folding scalar slice of concat into slice of concat operand"; + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + slice, HloInstruction::CreateSlice( + slice->shape(), concat->mutable_operand(operand_num), + {slice->slice_starts(0) - operand_start}, + {slice->slice_starts(0) - operand_start + 1}, + slice->slice_strides()))); + } + return true; + } + + return false; +} + +StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( + HloInstruction* slice) { + CHECK_EQ(slice->opcode(), HloOpcode::kSlice); + if (!IsUnstridedSlice(slice)) { + return false; + } + HloInstruction* reshape = slice->mutable_operand(0); + if (reshape->opcode() != HloOpcode::kReshape) { + return false; + } + HloInstruction* new_slice_operand = reshape->mutable_operand(0); + int64 slice_rank = ShapeUtil::Rank(slice->shape()); + std::vector sliced_dims; + for (int64 i = 0; i < slice_rank; ++i) { + if (slice->slice_starts(i) != 0 || + slice->slice_limits(i) != reshape->shape().dimensions(i)) { + sliced_dims.push_back(i); + } + } + + if (sliced_dims.size() == 1 && sliced_dims[0] == 0 && + slice->slice_starts(0) == 0) { + const Shape& new_slice_shape = new_slice_operand->shape(); + const int64 rank = ShapeUtil::Rank(new_slice_shape); + std::vector new_slice_starts(rank, 0); + std::vector new_slice_stides(rank, 1); + std::vector new_slice_limits(new_slice_shape.dimensions().begin(), + new_slice_shape.dimensions().end()); + int64 slice_elements = 
ShapeUtil::ElementsIn(slice->shape()); + for (int64 i = rank - 1; i >= 0; --i) { + if (slice_elements >= new_slice_limits[i]) { + if (slice_elements % new_slice_limits[i] != 0) { + return false; + } + slice_elements /= new_slice_limits[i]; + } else { + new_slice_limits[i] = slice_elements; + slice_elements = 1; + } + } + HloInstruction* new_slice = + computation_->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(new_slice_shape.element_type(), + new_slice_limits), + new_slice_operand, new_slice_starts, new_slice_limits, + new_slice_stides)); + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + slice, HloInstruction::CreateReshape(slice->shape(), new_slice))); + return true; + } + return false; +} + Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { // Delete no-op slices, i.e. where shape = operand shape. if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) { return Status::OK(); } - auto is_unstrided_slice = [](const HloInstruction* hlo) { - return absl::c_all_of(hlo->slice_strides(), - [](int64 stride) { return stride == 1; }); - }; if (slice->operand(0)->opcode() == HloOpcode::kSlice && - is_unstrided_slice(slice) && is_unstrided_slice(slice->operand(0))) { + IsUnstridedSlice(slice) && IsUnstridedSlice(slice->operand(0))) { HloInstruction* operand_slice = slice->mutable_operand(0); std::vector new_slice_starts = slice->slice_starts(); std::vector new_slice_limits = slice->slice_limits(); @@ -1844,6 +2250,16 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { slice->shape(), operand_slice->mutable_operand(0), new_slice_starts, new_slice_limits, slice->slice_strides())); } + + TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice)); + if (replaced) { + return Status::OK(); + } + + TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice)); + if (replaced) { + return Status::OK(); + } return Status::OK(); } @@ -2057,6 +2473,12 @@ Status 
AlgebraicSimplifierVisitor::HandleReduceWindow( return Status::OK(); } + // Bail on dilation. + if (window_util::HasDilation(window)) { + VLOG(10) << "Not folding pad into reduce-window as there is dilation."; + return Status::OK(); + } + VLOG(10) << "Considering folding Pad: " << pad->ToString() << "\ninto reduce-window: " << reduce_window->ToString() << (convert != nullptr @@ -2193,6 +2615,22 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( /*reduce_computation=*/function)); } +Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) { + // select(x, y, y) -> y. + if (select->operand(1) == select->operand(2)) { + return ReplaceInstruction(select, select->mutable_operand(1)); + } + // select(true, x, y) -> x. + if (IsAll(select->operand(0), true)) { + return ReplaceInstruction(select, select->mutable_operand(1)); + } + // select(false, x, y) -> y. + if (IsAll(select->operand(0), false)) { + return ReplaceInstruction(select, select->mutable_operand(2)); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { auto operand = sort->mutable_operand(0); int64 dimension_to_sort = sort->dimensions(0); @@ -2203,7 +2641,109 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) { } // If it is key/value sort, the output of sort is a tuple. return ReplaceWithNewInstruction( - sort, HloInstruction::CreateTuple({operand, sort->mutable_operand(1)})); + sort, HloInstruction::CreateTuple(sort->operands())); + } + if (!options_.enable_permutation_sort_replacement()) { + return Status::OK(); + } + // Check if we are sorting a permutation. In that case, we know that the keys + // will be sorted to the identity permutation, and we can represent the + // changes to the 'values' parameter as a scatter. 
+ if (sort->operand_count() == 2 && + operand->opcode() == HloOpcode::kGetTupleElement) { + const HloInstruction* other_sort = operand->operand(0); + // Check whether the 'values' parameter is the result of another sort with + // the same sort dimension. + if (other_sort->opcode() == HloOpcode::kSort && + other_sort->operand_count() >= 2 && + other_sort->dimensions(0) == dimension_to_sort && + other_sort->operand(operand->tuple_index())->opcode() == + HloOpcode::kIota) { + auto* iota = + Cast(other_sort->operand(operand->tuple_index())); + // The sort operand needs to be an integral iota, and the iota dimension + // needs to be the dimension that was sorted. + if (iota->iota_dimension() == dimension_to_sort && + ShapeUtil::ElementIsIntegral(iota->shape())) { + // We use the following construction method for a Scatter that applies + // the permutation from 'keys' to the 'values' parameter. + // - Take the "keys" parameter of the second sort and reshape it to have + // another "1" dimension at the end. + // - Concatenate it with iotas of the same extended shape with all + // different iota_dimensions except the dimension_to_sort in the order + // of iota_dimensions/dimension_to_sort, so e.g. with rank 3 and + // dimension_to_sort = 1, we would have concatenate of (iota with + // iota_dimension=0, keys, iota with iota_dimension = 2) + // - Use this as the indices parameter of scatter, and set updates + // of the scatter to be a reshaped 'values' parameter of sort (adding + // 'rank' many 1 dimensions at the end). 
+ int64 rank = ShapeUtil::Rank(operand->shape()); + Shape extended_shape = operand->shape(); + extended_shape.add_dimensions(1); + extended_shape.mutable_layout()->add_minor_to_major(rank); + auto reshaped_permutation = computation_->AddInstruction( + HloInstruction::CreateReshape(extended_shape, operand)); + std::vector concat_operands; + for (int64 i = 0; i < rank; ++i) { + if (i == dimension_to_sort) { + concat_operands.push_back(reshaped_permutation); + } else { + concat_operands.push_back(computation_->AddInstruction( + HloInstruction::CreateIota(extended_shape, i))); + } + } + Shape concat_shape = operand->shape(); + concat_shape.add_dimensions(rank); + concat_shape.mutable_layout()->add_minor_to_major(rank); + auto scatter_indices = + rank > 1 ? computation_->AddInstruction( + HloInstruction::CreateConcatenate( + concat_shape, concat_operands, rank)) + : reshaped_permutation; + + // We don't care about the operand, it will be completely overridden by + // the updates. + auto scatter_operand = computation_->AddInstruction( + HloInstruction::CreateIota(sort->operand(1)->shape(), 0)); + + // Construct the updates operand of scatter. + Shape update_shape = sort->operand(1)->shape(); + for (int64 i = 0; i < rank; ++i) { + update_shape.add_dimensions(1); + update_shape.mutable_layout()->add_minor_to_major(rank + i); + } + auto scatter_updates = + computation_->AddInstruction(HloInstruction::CreateReshape( + update_shape, sort->mutable_operand(1))); + + // Construct the updates computation, which simply replaces the operand + // values with the update values. 
+ HloComputation::Builder b("update_replace_computation"); + Shape scalar_shape = ShapeUtil::MakeShape(S32, {}); + b.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "scalar_rhs")); + auto update_replace_computation = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_rhs)); + + ScatterDimensionNumbers dim_numbers; + dim_numbers.set_index_vector_dim(rank); + for (int64 i = 0; i < rank; ++i) { + dim_numbers.add_update_window_dims(rank + i); + dim_numbers.add_scatter_dims_to_operand_dims(i); + } + auto scatter = + computation_->AddInstruction(HloInstruction::CreateScatter( + sort->operand(1)->shape(), scatter_operand, scatter_indices, + scatter_updates, update_replace_computation, dim_numbers)); + return ReplaceWithNewInstruction( + sort, HloInstruction::CreateTuple( + {computation_->AddInstruction(HloInstruction::CreateIota( + operand->shape(), dimension_to_sort)), + scatter})); + } + } } return Status::OK(); } @@ -2229,7 +2769,7 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return ReplaceInstruction(transpose, operand); } - if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) { + if (options_.is_layout_sensitive() && TransposeIsBitcast(transpose)) { ReplaceWithBitcast(transpose); return Status::OK(); } @@ -2378,13 +2918,13 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); - if (!enable_conv_simplification_) { + if (!options_.enable_conv_simplification()) { return false; } // TODO(b/31337498): For now, we cowardly refuse to do this optimization in // layout-insensitive mode, for fear of adding nontrivial reshapes. 
- if (!is_layout_sensitive_) { + if (!options_.is_layout_sensitive()) { return false; } @@ -2474,9 +3014,9 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( // We cannot insert bitcasts if the layouts will not be compatible. // TODO(b/33178038): Consider inserting a transpose if a bitcast would be // invalid. - if (!valid_bitcast_callback_(input_shape, new_input_shape) || - !valid_bitcast_callback_(filter_shape, new_filter_shape) || - !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { + if (!options_.valid_bitcast_callback()(input_shape, new_input_shape) || + !options_.valid_bitcast_callback()(filter_shape, new_filter_shape) || + !options_.valid_bitcast_callback()(dot_output_shape, convolution_shape)) { return false; } @@ -2582,9 +3122,7 @@ StatusOr AlgebraicSimplifier::Run(HloModule* module) { "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (AlgebraicSimplifierVisitor::Run( - comp, is_layout_sensitive_, valid_bitcast_callback_, - enable_dot_strength_reduction_, enable_conv_simplification_)) { + if (AlgebraicSimplifierVisitor::Run(comp, options_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 9f8d0ee88bdebcf17310cd0407b1b99e4b0a7b5f..d2775b9fafa7e4c625f5d181114e80e7369f9c78 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -23,8 +23,7 @@ limitations under the License. namespace xla { -// A pass which performs algebraic simplifications. 
-class AlgebraicSimplifier : public HloModulePass { +class AlgebraicSimplifierOptions { public: // Given shapes 'from_shape' and 'to_shape', determines if it is valid to // bitcast from 'from_shape' to 'to_shape' after considering platform @@ -34,18 +33,63 @@ class AlgebraicSimplifier : public HloModulePass { using ValidBitcastCallback = std::function; + explicit AlgebraicSimplifierOptions( + ValidBitcastCallback valid_bitcast_callback) + : valid_bitcast_callback_(std::move(valid_bitcast_callback)) {} + // If valid_bitcast_callback returns true, then the pass will replace reshapes + // and transposes with bitcasts. + const ValidBitcastCallback& valid_bitcast_callback() const { + return valid_bitcast_callback_; + } + + // If is_layout_sensitive is true, then the simplifier preserves layout during + // transformation. Otherwise, layout is ignored. + void set_is_layout_sensitive(bool is_layout_sensitive) { + is_layout_sensitive_ = is_layout_sensitive; + } + bool is_layout_sensitive() const { return is_layout_sensitive_; } + + // Enable dot simplification on platforms where it is profitable. + void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) { + enable_dot_strength_reduction_ = enable_dot_strength_reduction; + } + bool enable_dot_strength_reduction() const { + return enable_dot_strength_reduction_; + } + + // Enable convolution simplification on platforms where it is profitable. + void set_enable_conv_simplification(bool enable_conv_simplification) { + enable_conv_simplification_ = enable_conv_simplification; + } + bool enable_conv_simplification() const { + return enable_conv_simplification_; + } + + // If enable_permutation_sort_replacement is true, a sort op that is known to + // sort a permutation will be replaced with a scatter op. 
+ void set_enable_permutation_sort_replacement( + bool enable_permutation_sort_replacement) { + enable_permutation_sort_replacement_ = enable_permutation_sort_replacement; + } + bool enable_permutation_sort_replacement() const { + return enable_permutation_sort_replacement_; + } + + private: + ValidBitcastCallback valid_bitcast_callback_; + bool is_layout_sensitive_{false}; + bool enable_dot_strength_reduction_{true}; + bool enable_conv_simplification_{true}; + bool enable_permutation_sort_replacement_{false}; +}; + +// A pass which performs algebraic simplifications. +class AlgebraicSimplifier : public HloModulePass { + public: // If is_layout_sensitive is true, then the simplifier preserves layout during - // transformation. Otherwise, layout is ignored. If valid_bitcast_callback - // returns true, then the pass will replace reshapes and transposes with - // bitcasts. - AlgebraicSimplifier(bool is_layout_sensitive, - ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_strength_reduction = true, - bool enable_conv_simplification = true) - : is_layout_sensitive_(is_layout_sensitive), - valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_strength_reduction_(enable_dot_strength_reduction), - enable_conv_simplification_(enable_conv_simplification) {} + // transformation. Otherwise, layout is ignored. + explicit AlgebraicSimplifier(const AlgebraicSimplifierOptions& options) + : options_(options) {} ~AlgebraicSimplifier() override = default; absl::string_view name() const override { return "algsimp"; } @@ -54,14 +98,7 @@ class AlgebraicSimplifier : public HloModulePass { StatusOr Run(HloModule* module) override; private: - bool is_layout_sensitive_; - ValidBitcastCallback valid_bitcast_callback_; - - // Enable dot simplification on platforms where it is profitable. - bool enable_dot_strength_reduction_; - - // Enable convolution simplification on platforms where it is profitable. 
- bool enable_conv_simplification_; + AlgebraicSimplifierOptions options_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py b/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py new file mode 100644 index 0000000000000000000000000000000000000000..5da13da041b4ded813876af7ca379025187545ab --- /dev/null +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_proof_distributive_property.py @@ -0,0 +1,82 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Proof that transforming (A*C)+(B*C) <=> (A+B)*C is "safe" if C=2^k. + +Specifically, for all floating-point values A, B, and C, if + + - C is equal to +/- 2^k for some (possibly negative) integer k, and + - A, B, C, A*C, B*C, and A+B are not subnormal, zero, or inf, + +then there exists a rounding mode rm in [RTZ, RNE] such that + + (A*C) + (B*C) == (A+B) * C (computed with rounding mode rm). + +Informally, this means that the equivalence holds for powers of 2 C, modulo +flushing to zero or inf, and modulo rounding of intermediate results. + +Requires z3 python bindings; try `pip install z3-solver`. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import z3 + +# We do float16 because it lets the solver run much faster. These results +# should generalize to fp32 and fp64, and you can verify this by changing the +# value of FLOAT_TY (and then waiting a while). +FLOAT_TY = z3.Float16 + +a = z3.FP("a", FLOAT_TY()) +b = z3.FP("b", FLOAT_TY()) +c = z3.FP("c", FLOAT_TY()) + +s = z3.Solver() + +# C must be a power of 2, i.e. significand bits must all be 0. NOTE(review): sbits() counts the implicit leading bit, so Extract(sbits - 1, 0) also pins the lowest exponent bit to 0 -- confirm whether sbits() - 2 was intended. +s.add(z3.Extract(FLOAT_TY().sbits() - 1, 0, z3.fpToIEEEBV(c)) == 0) + +for rm in [z3.RTZ(), z3.RNE()]: + z3.set_default_rounding_mode(rm) + before = a * c + b * c + after = (a + b) * c + + # Check that before == after, allowing that 0 == -0. + s.add( + z3.Not( + z3.Or( + before == after, # + z3.And(z3.fpIsZero(before), z3.fpIsZero(after))))) + + for x in [ + (a * c), + (b * c), + (a + b), + ]: + s.add(z3.Not(z3.fpIsSubnormal(x))) + s.add(z3.Not(z3.fpIsZero(x))) + s.add(z3.Not(z3.fpIsInf(x))) + +if s.check() == z3.sat: + m = s.model() + print("Counterexample found!") + print(m) + print("a*c: ", z3.simplify(m[a] * m[c])) + print("b*c: ", z3.simplify(m[b] * m[c])) + print("a+b: ", z3.simplify(m[a] + m[b])) + print("a*c + b*c: ", z3.simplify(m[a] * m[c] + m[b] * m[c])) + print("(a+b) * c: ", z3.simplify((m[a] + m[b]) * m[c])) +else: + print("Proved!") diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 2047f894b465816eb97ba205e79033bd52bf7a0c..14ce519b6a0fd221070006d336d23bddeb6cd621 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -27,13 +27,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -43,21 +44,24 @@ namespace xla { namespace { using ::testing::ElementsAre; +namespace m = match; -namespace op = xla::testing::opcode_matchers; - -AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { +AlgebraicSimplifierOptions::ValidBitcastCallback bitcasting_callback() { return [](const Shape&, const Shape&) { return true; }; } -AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { +AlgebraicSimplifierOptions::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloVerifiedTestBase {}; +class AlgebraicSimplifierTest : public HloTestBase { + protected: + AlgebraicSimplifierOptions default_options_{non_bitcasting_callback()}; +}; // Test that A + 0 is simplified to A TEST_F(AlgebraicSimplifierTest, AddZero) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -67,18 +71,140 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { 
builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, FactorIntegerAddition) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = s32[8] parameter(0) + p1 = s32[8] parameter(1) + p2 = s32[8] parameter(2) + x = s32[8] multiply(p0, p2) + y = s32[8] multiply(p1, p2) + ROOT sum = s32[8] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), m::Parameter(2)))); +} + +// A*C + B*C => (A+B)*C if C is a floating-point power of 2. 
+TEST_F(AlgebraicSimplifierTest, FactorFpAddition) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + c = f32[] constant(0.125) + x = f32[] multiply(p0, c) + y = f32[] multiply(p1, c) + ROOT sum = f32[] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::ConstantScalar(0.125)))); +} + +// A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2. +TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + c = f32[] constant(0.125) + b = f32[4] broadcast(c), dimensions={} + x = f32[4] multiply(p0, b) + y = f32[4] multiply(p1, b) + ROOT sum = f32[4] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::Broadcast(m::ConstantScalar(0.125))))); +} + +// A*C + B*C => (A+B)*C simplification should not happen if C is not a +// floating-point power of 2. 
+TEST_F(AlgebraicSimplifierTest, FactorFpAdditionNotPowerOf2) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + c = f32[] constant(0.3) + x = f32[] multiply(p0, c) + y = f32[] multiply(p1, c) + ROOT sum = f32[] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +// A*C + B*C => (A+B)*C simplification should not happen if A, B, and C are +// complex numbers. +TEST_F(AlgebraicSimplifierTest, FactorFpAdditionComplex) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = c64[8] parameter(0) + p1 = c64[8] parameter(1) + p2 = c64[8] parameter(2) + x = c64[8] multiply(p0, p2) + y = c64[8] multiply(p1, p2) + ROOT sum = c64[8] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); +} + +// A*C + B*C => (A+B)*C simplification is OK if A, B, and C are bfloat16.
+TEST_F(AlgebraicSimplifierTest, FactorFpAdditionBfloat16) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = bf16[4] parameter(0) + p1 = bf16[4] parameter(1) + c = bf16[] constant(0.125) + b = bf16[4] broadcast(c), dimensions={} + x = bf16[4] multiply(p0, b) + y = bf16[4] multiply(p1, b) + ROOT sum = bf16[4] add(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)), + m::Broadcast(m::ConstantScalar(0.125))))); +} + // Test that A * 0 is simplified to 0 TEST_F(AlgebraicSimplifierTest, MulZero) { + auto m = CreateNewVerifiedModule(); Shape r0s32 = ShapeUtil::MakeShape(S32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -88,17 +214,81 @@ TEST_F(AlgebraicSimplifierTest, MulZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0s32, HloOpcode::kMultiply, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kMultiply); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), zero); } +// Test that select(true, a, b) is simplified to a +TEST_F(AlgebraicSimplifierTest, SelectTrue) { + Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0s32, "param0")); + HloInstruction* param1 = 
builder.AddInstruction( + HloInstruction::CreateParameter(1, r0s32, "param1")); + HloInstruction* one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction(HloInstruction::CreateTernary( + r0s32, HloOpcode::kSelect, one, param0, param1)); + + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kSelect); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_EQ(computation->root_instruction(), param0); +} + +// Test that select(false, a, b) is simplified to b +TEST_F(AlgebraicSimplifierTest, SelectFalse) { + Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0s32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0s32, "param1")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateTernary( + r0s32, HloOpcode::kSelect, zero, param0, param1)); + + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kSelect); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_EQ(computation->root_instruction(), param1); +} + +// Test that select(a, b, b) is simplified to b +TEST_F(AlgebraicSimplifierTest, SelectIdentical) { + Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0s32, 
"param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0s32, "param1")); + builder.AddInstruction(HloInstruction::CreateTernary( + r0s32, HloOpcode::kSelect, param0, param1, param1)); + + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kSelect); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_EQ(computation->root_instruction(), param1); +} + // Test that Reduce(Reduce(A)) -> Reduce(A) TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Create add computation. HloInstruction* zero = builder.AddInstruction( @@ -113,7 +303,7 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); HloInstruction* param = builder.AddInstruction( @@ -126,17 +316,17 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero, dims1, add_computation)); - module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Reduce(param, zero)); + m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + 
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + HloInstruction* root = m->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reduce(m::Parameter(0), m::Op().Is(zero)))); EXPECT_EQ(root->dimensions(), std::vector({0, 2, 3})); } // Test that Const + A is canonicalized to A + Const. TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -146,18 +336,18 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), m::Constant()))); } // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2. 
TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -172,17 +362,19 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2))); + EXPECT_THAT(root, GmockMatch(m::Add( + m::Op().Is(param0), + m::Add(m::Op().Is(constant1), m::Op().Is(constant2))))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -194,17 +386,17 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + 
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Create add computation. HloComputation* add_computation = nullptr; @@ -217,7 +409,7 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Shape r2f32 = ShapeUtil::MakeShape(F32, {32, 1}); HloInstruction* param0 = builder.AddInstruction( @@ -230,17 +422,18 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { HloInstruction::CreateBroadcast(r2f32, zero, {}))}, add_computation)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kMap); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Broadcast(zero))); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), + m::Broadcast(m::Op().Is(zero))))); } TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -252,64 +445,64 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, 
HloOpcode::kAdd, bcast, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({3.14f, 3.14f, 3.14f}))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(root, GmockMatch(m::Constant())); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); EXPECT_EQ(3.14f, root->operand(0)->literal().GetFirstElement()); } TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({3.14, 3.14, 4}))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - 
EXPECT_THAT(root, op::Constant()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(root, GmockMatch(m::Constant())); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); } TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f}))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(root, GmockMatch(m::Constant())); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); } // Test that A - 0 is simplified to A TEST_F(AlgebraicSimplifierTest, SubZero) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -319,18 +512,18 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); 
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that A - Const is canonicalized to A + (-Const). TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -340,18 +533,19 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kSubtract, param0, constant)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param0, op::Negate(constant))); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(0), + m::Negate(m::Op().Is(constant))))); } // Test that (A/B)/C is simplified to A/(B*C). 
TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -365,21 +559,24 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Divide(param0, param1), param2)); + GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)), + m::Parameter(2)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Multiply(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Multiply(m::Parameter(1), m::Parameter(2))))); } // Test that A/(B/C) is simplified to (A*C)/B. 
TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -393,21 +590,25 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Divide(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Divide(m::Parameter(1), m::Parameter(2))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Multiply(param0, param2), param1)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(2)), + m::Parameter(1)))); } // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). 
TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -425,23 +626,25 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), - op::Divide(op::Divide(param0, param1), op::Divide(param2, param3))); + GmockMatch(m::Divide(m::Divide(m::Parameter(0), m::Parameter(1)), + m::Divide(m::Parameter(2), m::Parameter(3))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), - op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2))); + GmockMatch(m::Divide(m::Multiply(m::Parameter(0), m::Parameter(3)), + m::Multiply(m::Parameter(1), m::Parameter(2))))); } // Test that A/exp(B) is simplified to A*exp(-B). 
TEST_F(AlgebraicSimplifierTest, DivOfExp) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -453,21 +656,22 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Exp(param1))); + GmockMatch(m::Divide(m::Parameter(0), m::Exp(m::Parameter(1))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Exp(op::Negate(param1)))); + GmockMatch(m::Multiply(m::Parameter(0), + m::Exp(m::Negate(m::Parameter(1)))))); } // Test that A/pow(B,C) is simplified to A*pow(B,-C). 
TEST_F(AlgebraicSimplifierTest, DivOfPower) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -481,22 +685,26 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Power(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Power(m::Parameter(1), m::Parameter(2))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Power(param1, op::Negate(param2)))); + GmockMatch(m::Multiply( + m::Parameter(0), + m::Power(m::Parameter(1), m::Negate(m::Parameter(2)))))); } // Test that broadcasting is done on the right step when simplifying A/pow(B,C) // to A*pow(B,-C). 
TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -510,21 +718,25 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(param0, op::Power(param1, param2))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Parameter(0), + m::Power(m::Parameter(1), m::Parameter(2))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); ASSERT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Power(param1, op::Negate(param2)))); + GmockMatch(m::Multiply( + m::Parameter(0), + m::Power(m::Parameter(1), m::Negate(m::Parameter(2)))))); } // A / Const => A * InvertedConst TEST_F(AlgebraicSimplifierTest, DivideByConstant) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {3}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -535,18 +747,18 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, constant)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier 
simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(param0, op::Constant())); + GmockMatch(m::Multiply(m::Parameter(0), m::Constant()))); } // pow(pow(A, X), Y) => pow(A, X*Y) TEST_F(AlgebraicSimplifierTest, PowerOfPower) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( @@ -560,17 +772,19 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, inner_power, exp2)); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Power(base, op::Multiply(exp1, exp2))); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Power(m::Op().Is(base), + m::Multiply(m::Op().Is(exp1), m::Op().Is(exp2))))); } // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex // numbers. 
TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { + auto m = CreateNewVerifiedModule(); Shape r1c64 = ShapeUtil::MakeShape(C64, {7}); HloComputation::Builder builder(TestName()); HloInstruction* base = builder.AddInstruction( @@ -584,14 +798,14 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, inner_power, exp2)); - module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); + m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_FALSE(simplifier.Run(m.get()).ValueOrDie()); } // Test that A/1 is simplified to A for a scalar. TEST_F(AlgebraicSimplifierTest, DivOneScalar) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -601,18 +815,18 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that A/1 is simplified to A for an array. 
TEST_F(AlgebraicSimplifierTest, DivOneArray) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -622,18 +836,18 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that complex(real(c), imag(c)) is simplified to c. 
TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2}); HloComputation::Builder builder(TestName()); @@ -646,18 +860,18 @@ TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { HloInstruction* cplx = builder.AddInstruction( HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, cplx); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that real(complex(r,i)) is simplified to r. 
TEST_F(AlgebraicSimplifierTest, RealOfComplex) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -670,18 +884,18 @@ TEST_F(AlgebraicSimplifierTest, RealOfComplex) { HloInstruction* real = builder.AddInstruction( HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, real); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } // Test that imag(complex(r,i)) is simplified to i. TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { + auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -694,18 +908,18 @@ TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { HloInstruction* imag = builder.AddInstruction( HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, imag); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param1); } // Test that get_element(make_tuple({A,B}),1) is simplified to B 
TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -721,18 +935,18 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, add); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Add(param1, param2)); + EXPECT_THAT(root, GmockMatch(m::Add(m::Parameter(1), m::Parameter(2)))); } // Test that exp(A)/exp(B) is simplified to exp(A-B) TEST_F(AlgebraicSimplifierTest, ExpDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -746,21 +960,23 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Divide(op::Exp(param0), op::Exp(param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Divide(m::Exp(m::Parameter(0)), m::Exp(m::Parameter(1))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier 
simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Subtract(param0, param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Exp(m::Subtract(m::Parameter(0), m::Parameter(1))))); } // Test that exp(A)*exp(B) is simplified to exp(A+B) TEST_F(AlgebraicSimplifierTest, ExpMul) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -774,21 +990,22 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Exp(param0), op::Exp(param1))); + GmockMatch(m::Multiply(m::Exp(m::Parameter(0)), + m::Exp(m::Parameter(1))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Add(param0, param1))); + GmockMatch(m::Exp(m::Add(m::Parameter(0), m::Parameter(1))))); } // Test that pow(exp(A), B) is simplified to exp(A*B) TEST_F(AlgebraicSimplifierTest, PowExp) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -800,21 +1017,22 @@ TEST_F(AlgebraicSimplifierTest, PowExp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); 
EXPECT_THAT(computation->root_instruction(), - op::Power(op::Exp(param0), param1)); + GmockMatch(m::Power(m::Exp(m::Parameter(0)), m::Parameter(1)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Exp(op::Multiply(param0, param1))); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Exp(m::Multiply(m::Parameter(0), m::Parameter(1))))); } // Test that ln(pow(A, B)) is simplified to ln(A)*B TEST_F(AlgebraicSimplifierTest, LnPow) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -826,21 +1044,22 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Log(op::Power(param0, param1))); + GmockMatch(m::Log(m::Power(m::Parameter(0), m::Parameter(1))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Log(param0), param1)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Multiply(m::Log(m::Parameter(0)), m::Parameter(1)))); } // Test that ln(exp(A)) is simplified to A TEST_F(AlgebraicSimplifierTest, LnExp) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); 
HloInstruction* param0 = builder.AddInstruction( @@ -850,19 +1069,20 @@ TEST_F(AlgebraicSimplifierTest, LnExp) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Log(m::Exp(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } // Test that ln(exp(A)/exp(B)) is simplified to A-B TEST_F(AlgebraicSimplifierTest, LnExpDiv) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -878,21 +1098,23 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Log(op::Divide(op::Exp(param0), op::Exp(param1)))); + GmockMatch(m::Log(m::Divide(m::Exp(m::Parameter(0)), + m::Exp(m::Parameter(1)))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Subtract(m::Parameter(0), m::Parameter(1)))); } // Test 
that pow(A, 0) where A is a scalar is simplified to the scalar // constant 1. TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -902,21 +1124,22 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); EXPECT_EQ(root->literal().GetFirstElement(), 1); } // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1). 
TEST_F(AlgebraicSimplifierTest, Pow0Vector) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {42}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -926,16 +1149,16 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(zero)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast()); + EXPECT_THAT(root, GmockMatch(m::Broadcast())); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32)) << ShapeUtil::HumanString(root->shape()); EXPECT_EQ(root->dimensions().size(), 0); @@ -945,6 +1168,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { // Test that pow(A, 1) is simplified to A. 
TEST_F(AlgebraicSimplifierTest, Pow1) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -954,19 +1178,20 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(one)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } // Test that pow(A, 2) is simplified to A*A. 
TEST_F(AlgebraicSimplifierTest, Pow2) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -976,19 +1201,21 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(two)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0)))); } // Test that pow(A, -1) is simplified to 1/A. 
TEST_F(AlgebraicSimplifierTest, PowNegative1) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -998,22 +1225,23 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(negative_one)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); + EXPECT_THAT(root, GmockMatch(m::Divide(m::Broadcast(), m::Parameter(0)))); EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast); EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement(), 1); } TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs")); @@ -1046,17 +1274,17 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(builder.Build()); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - 
EXPECT_THAT(module().entry_computation()->root_instruction(), - op::Convolution(lhs, rhs)); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + m->AddEntryComputation(builder.Build()); + HloPassFix simplifier(default_options_); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Convolution(m::Op().Is(lhs), m::Op().Is(rhs)))); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1081,24 +1309,24 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } builder.AddInstruction(HloInstruction::CreateReduceWindow( ShapeUtil::MakeShape(F32, {5, 2}), param, builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), window, add_computation)); - module().AddEntryComputation(builder.Build()); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_THAT(module().entry_computation()->root_instruction(), - op::ReduceWindow(param, op::Constant())); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + m->AddEntryComputation(builder.Build()); + HloPassFix simplifier(default_options_); + EXPECT_THAT(m->entry_computation()->root_instruction(), + 
GmockMatch(m::ReduceWindow(m::Parameter(0), m::Constant()))); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1115,17 +1343,17 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))), padding)); - module().AddEntryComputation(builder.Build()); - EXPECT_THAT(module().entry_computation()->root_instruction(), - op::Pad(param, op::Constant())); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), - op::Broadcast(op::Constant())); + m->AddEntryComputation(builder.Build()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Constant()))); + HloPassFix simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::Constant()))); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); auto builder = HloComputation::Builder(TestName()); @@ -1139,39 +1367,40 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { ShapeUtil::MakeShape(F32, {3, 2}), broadcast)); auto computation = builder.Build(); - module().AddEntryComputation(std::move(computation)); + m->AddEntryComputation(std::move(computation)); - EXPECT_THAT(module().entry_computation()->root_instruction(), - op::Reshape(op::Broadcast(op::Reshape(op)))); + EXPECT_THAT(m->entry_computation()->root_instruction(), + 
GmockMatch(m::Reshape(m::Broadcast(m::Reshape(m::Op().Is(op)))))); - HloPassFix simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + HloPassFix simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(module().entry_computation()->root_instruction(), op); + EXPECT_THAT(m->entry_computation()->root_instruction(), op); } // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE. TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Convert(m::Op().Is(input)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), input); } // Test that copies are removed. 
TEST_F(AlgebraicSimplifierTest, RemoveCopy) { + auto m = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1179,18 +1408,19 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); } TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1201,24 +1431,30 @@ TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) { ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), HloOpcode::kCopy, param)); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 2, 0, 3}); - auto computation = module().AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + auto computation = m->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier1(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); - ASSERT_FALSE(simplifier1.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier 
simplifier1(options); + ASSERT_FALSE(simplifier1.Run(m.get()).ValueOrDie()); // Verify that the copy is not replaced. - EXPECT_THAT(computation->root_instruction(), op::Copy(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier2(/*is_layout_sensitive=*/true, - bitcasting_callback()); - ASSERT_TRUE(simplifier2.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options2(bitcasting_callback()); + options2.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier2(options2); + ASSERT_TRUE(simplifier2.Run(m.get()).ValueOrDie()); // Verify that the copy is replaced. - EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } // Test that unary concatenates are removed. TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( @@ -1226,19 +1462,20 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { builder.AddInstruction( HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Concatenate(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); } // Test that empty operands of concatenates are removed. 
TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { + auto m = CreateNewVerifiedModule(); const int kParamLength = 100; Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); HloComputation::Builder builder(TestName()); @@ -1255,22 +1492,24 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT( - computation->root_instruction(), - op::Concatenate(empty_literal, param0, param0, empty_slice, param1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Concatenate( + m::Op().Is(empty_literal), m::Parameter(0), m::Parameter(0), + m::Op().Is(empty_slice), m::Parameter(1)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Concatenate(param0, param0, param1)); + GmockMatch(m::Concatenate(m::Parameter(0), m::Parameter(0), + m::Parameter(1)))); } // Test that reduce of concat is simplified. 
TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { + auto m = CreateNewVerifiedModule(); const int kParamLength = 100; Shape r3f32 = ShapeUtil::MakeShape(F32, {kParamLength, kParamLength, kParamLength}); @@ -1296,7 +1535,7 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); Shape reduce_shape = ShapeUtil::MakeShape(F32, {kParamLength}); @@ -1306,20 +1545,21 @@ TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) { builder.AddInstruction(HloInstruction::CreateReduce( reduce_shape, Concatenate, zero, {1, 2}, add_computation)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), - op::Map(op::Map(op::Reduce(param0, zero), op::Reduce(param1, zero)), - op::Reduce(param2, zero))); + GmockMatch(m::Map(m::Map(m::Reduce(m::Parameter(0), m::Op().Is(zero)), + m::Reduce(m::Parameter(1), m::Op().Is(zero))), + m::Reduce(m::Parameter(2), m::Op().Is(zero))))); } // Test a concatenate with only empty operands is removed. 
TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { + auto m = CreateNewVerifiedModule(); const int kParamLength = 100; Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); HloComputation::Builder builder(TestName()); @@ -1334,20 +1574,21 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, empty_slice}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Concatenate(empty_literal, empty_slice)); + GmockMatch(m::Concatenate(m::Op().Is(empty_literal), + m::Op().Is(empty_slice)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), empty_literal); } // Test that concat with a scalar broadcast becomes a pad. 
TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { + auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); @@ -1360,17 +1601,88 @@ TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { builder.AddInstruction(HloInstruction::CreateConcatenate( ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Parameter(1)))); +} + +TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) { + auto m = CreateNewVerifiedModule(); + Shape r2f32 = ShapeUtil::MakeShape(F32, {100, 99}); + Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 80}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r2f32, "param1")); + + HloInstruction* slice0 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{0, 0}, + /*limit_indices=*/{50, 10}, /*strides=*/{1, 1})); + + // Cannot merge 'slice0' and 'slice1' because of different start indices in + // dimension 0. + HloInstruction* slice1 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 10}, + /*limit_indices=*/{100, 20}, /*strides=*/{1, 1})); + + // Cannot merge 'slice1' and 'slice2' because of stride in dimension 2. 
+ HloInstruction* slice2 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 20}, + /*limit_indices=*/{100, 40}, /*strides=*/{1, 2})); + + // Cannot merge 'slice2' and 'slice3' because of stride in dimension 2. + HloInstruction* slice3 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 40}, + /*limit_indices=*/{100, 50}, /*strides=*/{1, 1})); + + // Can merge 'slice3' and 'slice4'. + HloInstruction* slice4 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 50}, + /*limit_indices=*/{100, 60}, /*strides=*/{1, 1})); + + // Can merge 'slice4' and 'slice5'. + HloInstruction* slice5 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 60}, + /*limit_indices=*/{100, 70}, /*strides=*/{1, 1})); + + // Cannot merge 'slice5' and 'slice6' because of overlap. + HloInstruction* slice6 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 69}, + /*limit_indices=*/{100, 79}, /*strides=*/{1, 1})); + + // Cannot merge 'slice6' and 'slice7' because of slicing from a different + // parameter. 
+ HloInstruction* slice7 = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 79}, + /*limit_indices=*/{100, 89}, /*strides=*/{1, 1})); + + builder.AddInstruction(HloInstruction::CreateConcatenate( + concat_shape, + {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7}, 1)); + auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + auto s = m::Slice(m::Parameter(0)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Concatenate(s, s, s, s, s, m::Slice(m::Parameter(1))))); + // The operand 3 should be a merge of 'slice3', 'slice4' and 'slice5', so its + // shape should have dimensions {50, 30}. + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->operand(3)->shape(), + ShapeUtil::MakeShape(F32, {50, 30}))); + EXPECT_EQ(computation->root_instruction()->operand(3)->slice_starts(1), 40); } // Test that a simplification which changes layouts is not performed if layout // sensitive is true. TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1378,25 +1690,29 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); // Set to different layouts. 
*param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Copy has not been removed. - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); } // Test that a simplification which preserves layouts is performed if layout // sensitive is true. TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1404,17 +1720,19 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); // Set to same layouts. 
*param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Copy has been removed. EXPECT_THAT(computation->root_instruction(), param0); @@ -1423,6 +1741,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { // Test that a reshape which could be replaced with a bitcast is not if // add_bitcasts is false. TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1435,20 +1754,24 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { *reshape->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); // Reshape is not replaced with a bitcast. 
- EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); } // Test transforming reshapes and transposes of rng. TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); @@ -1465,21 +1788,22 @@ TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { ShapeUtil::MakeShape(F32, {4}), transpose)) ->shape(); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier( + (AlgebraicSimplifierOptions(bitcasting_callback()))); + EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - // Verify that that reshape(transpose(rng)) is replace by a single rng of the + // Verify that reshape(transpose(rng)) is replace by a single rng of the // same shape as the reshape. - EXPECT_THAT(computation->root_instruction(), op::Rng()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Rng())); EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(), reshape_shape)); } // Test transforming reshapes to bitcasts under various conditions. 
TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1511,25 +1835,29 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { builder.AddInstruction(HloInstruction::CreateTuple( {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Tuple(transformable_reshape, dimensions_wrong_reshape, - layout_wrong_reshape)); + GmockMatch(m::Tuple(m::Op().Is(transformable_reshape), + m::Op().Is(dimensions_wrong_reshape), + m::Op().Is(layout_wrong_reshape)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); - simplifier.Run(&module()).ValueOrDie(); + AlgebraicSimplifierOptions options(bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + simplifier.Run(m.get()).ValueOrDie(); // Verify that only the first reshape is replaced. EXPECT_THAT( computation->root_instruction(), - op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); + GmockMatch(m::Tuple(m::Bitcast(), m::Op().Is(dimensions_wrong_reshape), + m::Op().Is(layout_wrong_reshape)))); } // Regression test for a bug where if we failed to sink a reshape, we'd set the // 'changed' bit in AlgebraicSimplifier to false. TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // This add (param0 + 0) can be simplified. 
@@ -1544,15 +1872,16 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - module().AddEntryComputation(builder.Build()); - EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier( + (AlgebraicSimplifierOptions(bitcasting_callback()))); + m->AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } // Regression test for a bug where if we failed to sink a reshape, we'd set the // 'changed' bit in AlgebraicSimplifier to false. TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // This add (param0 + 0) can be simplified. @@ -1568,13 +1897,14 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0, 1})); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - module().AddEntryComputation(builder.Build()); - EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier( + (AlgebraicSimplifierOptions(bitcasting_callback()))); + m->AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1588,19 +1918,23 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = 
m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. - EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1614,19 +1948,23 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({3, 1, 2, 0}); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options(bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); // Verify that the reshape is replaced. 
- EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Bitcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1639,19 +1977,20 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Reshape(param0))); + GmockMatch(m::Reshape(m::Reshape(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, CopiesMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1666,18 +2005,22 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}), HloOpcode::kCopy, copy1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Copy(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - 
non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, TransposesMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1690,21 +2033,23 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Op().Is(transpose1)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Transpose(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Transpose(m::Parameter(0)))); EXPECT_EQ(std::vector({2, 1, 0}), computation->root_instruction()->dimensions()); } // Test merging reshape and broadcast. 
TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {5}), "param0")); @@ -1713,20 +2058,21 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Broadcast(op::Reshape(param0))); + GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); } // Test merging broadcast and reshape. 
TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 3}), "param0")); @@ -1735,19 +2081,20 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param0))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1}), "param")); @@ -1756,20 +2103,20 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - 
EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {4}), "param")); @@ -1778,21 +2125,22 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); EXPECT_THAT(computation->root_instruction()->dimensions(), ::testing::ElementsAre(3)); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1}), "param")); @@ -1801,16 +2149,16 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { builder.AddInstruction(HloInstruction::CreateReshape( 
ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Parameter(0)))); const std::vector broadcast_dims = computation->root_instruction()->dimensions(); EXPECT_EQ(1, broadcast_dims.size()); @@ -1818,6 +2166,7 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {4}), "param")); @@ -1826,115 +2175,119 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 8}), broadcast)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); 
EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Broadcast(param))); + GmockMatch(m::Reshape(m::Broadcast(m::Parameter(0))))); } TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction(HloInstruction::CreateIota( ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), 2)); Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}); builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); } TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0)); auto result_shape = iota->shape(); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + 
AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); auto root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement()); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1)); builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), iota)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2)); builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - 
EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); EXPECT_EQ(Cast(computation->root_instruction()) ->iota_dimension(), 3); } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2)); builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Iota())); const int64 iota_dim = Cast(computation->root_instruction()) ->iota_dimension(); @@ -1942,21 +2295,23 @@ TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { } TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto iota = 
builder.AddInstruction( HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2)); builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + EXPECT_FALSE(simplifier.Run(m.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Iota()))); } TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { @@ -1976,14 +2331,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -2009,11 +2364,10 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {11, 5}), param, zero, 
padding)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); auto has_negative_padding = [](const HloInstruction* pad) { for (auto& padding_dimension : pad->padding_config().dimensions()) { @@ -2025,16 +2379,54 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { return false; }; - EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Pad(m::Parameter(0), m::Op().Is(zero))))); EXPECT_FALSE( has_negative_padding(computation->root_instruction()->operand(0))); } +TEST_F(AlgebraicSimplifierTest, TrivialInteriorPadding) { + // Verify that a pad instruction with interior padding on one-sized + // dimensions, removes the interior padding. 
+ HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {2, 1}), "param")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + PaddingConfig padding; + for (int i = 0; i < 2; ++i) { + auto dimension = padding.add_dimensions(); + dimension->set_edge_padding_low(3); + dimension->set_edge_padding_high(3); + dimension->set_interior_padding(i * 3); + } + HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {8, 7}), param, zero, padding)); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(default_options_); + + ASSERT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); + ASSERT_TRUE(HasInteriorPadding(pad->padding_config())); + + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Pad(m::Parameter(0), m::Op().Is(zero)))); + EXPECT_FALSE( + HasInteriorPadding(computation->root_instruction()->padding_config())); +} + TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { HloComputation::Builder builder(TestName()); HloInstruction* param = @@ -2043,14 +2435,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Reshape(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - 
ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -2066,14 +2458,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0}, /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(param)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Parameter(0)))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -2095,22 +2487,75 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { ShapeUtil::MakeShape(F32, {dim0 - 5, dim1 - 9}), original_slice, /*start_indices=*/{2, 3}, /*limit_indices=*/{dim0 - 3, dim1 - 6}, /*strides=*/{1, 1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Slice(op::Slice(param))); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Slice(m::Parameter(0))))); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Slice(param)); + EXPECT_THAT(computation->root_instruction(), + 
GmockMatch(m::Slice(m::Parameter(0)))); EXPECT_EQ(computation->root_instruction()->slice_starts(0), 3); EXPECT_EQ(computation->root_instruction()->slice_starts(1), 5); EXPECT_EQ(computation->root_instruction()->slice_limits(0), dim0 - 2); EXPECT_EQ(computation->root_instruction()->slice_limits(1), dim1 - 4); } +TEST_F(AlgebraicSimplifierTest, SliceOfReshapeToReshapeOfSlice) { + HloComputation::Builder builder(TestName()); + const int64 dim0 = 11; + const int64 dim1 = 12; + const int64 dim2 = 13; + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {dim0 * dim1, dim2}), "param")); + HloInstruction* original_reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {dim0, dim1, dim2}), param)); + + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {dim0 - 2, dim1, dim2}), original_reshape, + /*start_indices=*/{0, 0, 0}, + /*limit_indices=*/{dim0 - 2, dim1, dim2}, /*strides=*/{1, 1, 1})); + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Reshape(m::Parameter(0))))); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Reshape(m::Slice(m::Parameter(0))))); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfReshapeUnchanged) { + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 144, 25, 1, 512}), "param")); + HloInstruction* original_reshape = + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {3600, 512}), param)); + + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {960, 512}), original_reshape, + 
/*start_indices=*/{0, 0}, + /*limit_indices=*/{960, 512}, /*strides=*/{1, 1})); + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Slice(m::Reshape(m::Parameter(0))))); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { auto builder = HloComputation::Builder(TestName()); @@ -2118,14 +2563,86 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) { auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), keys); } +TEST_F(AlgebraicSimplifierTest, ReplacePermutationSortWithScatter) { + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} + gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + ROOT sort2 = (s32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(gte, values), dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_enable_permutation_sort_replacement(true); + AlgebraicSimplifier simplifier(options); + 
EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + GmockMatch(m::Tuple( + m::Iota(), + m::Scatter(m::Iota(), m::Concatenate(m::Iota(), m::Reshape()), + m::Reshape())))); +} + +TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortIfNonIntegral) { + // Same as ReplacePermutationSortWithScatter except that the iota has F32 + // type. + const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = f32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(keys, values), dimensions={1} + gte = f32[64,8732]{1,0} get-tuple-element(sort), index=1 + ROOT sort2 = (f32[64,8732]{1,0}, f32[64,8732]{1,0}) sort(gte, values), dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_enable_permutation_sort_replacement(true); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + +TEST_F(AlgebraicSimplifierTest, DontReplacePermutationSortWrongDimensions) { + // Same as ReplacePermutationSortWithScatter except that the sort dimensions + // don't match. 
+ const char* hlo_string = R"( + HloModule permutation_sort + + ENTRY sort_computation { + keys = f32[64,8732]{1,0} parameter(0) + values = s32[64,8732]{1,0} iota(), iota_dimension=1 + sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1} + gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1 + ROOT sort2 = (s32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(gte, values), dimensions={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(non_bitcasting_callback()); + options.set_enable_permutation_sort_replacement(true); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { auto builder = HloComputation::Builder(TestName()); @@ -2133,16 +2650,188 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0}); auto keys = builder.AddInstruction( HloInstruction::CreateParameter(0, keys_shape, "keys")); - auto values = builder.AddInstruction( - HloInstruction::CreateParameter(1, values_shape, "values")); + auto values0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, values_shape, "values0")); + auto values1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, values_shape, "values1")); builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); - auto module = CreateNewModule(); + ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}), 0, + keys, {values0, values1})); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); - 
EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Tuple(m::Op().Is(keys), m::Op().Is(values0), + m::Op().Is(values1)))); +} + +// Test that A && True is simplified to A +TEST_F(AlgebraicSimplifierTest, AndTrue) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + param0, const_true)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that True && A is simplified to A +TEST_F(AlgebraicSimplifierTest, AndTrue2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + const_true, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(default_options_); + 
ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that A && False is simplified to False +TEST_F(AlgebraicSimplifierTest, AndFalse) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + param0, const_false)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_false); +} + +// Test that False && A is simplified to False +TEST_F(AlgebraicSimplifierTest, AndFalse2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kAnd, + const_false, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAnd); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_false); +} + +// Test that A || True is simplified to True 
+TEST_F(AlgebraicSimplifierTest, OrTrue) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, param0, const_true)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_true); +} + +// Test that True || A is simplified to True +TEST_F(AlgebraicSimplifierTest, OrTrue2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_true = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, const_true, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, const_true); +} + +// Test that A || False is simplified to A +TEST_F(AlgebraicSimplifierTest, OrFalse) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + 
HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, + param0, const_false)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that False || A is simplified to A +TEST_F(AlgebraicSimplifierTest, OrFalse2) { + auto m = CreateNewVerifiedModule(); + Shape r0pred = ShapeUtil::MakeShape(PRED, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0pred, "param0")); + HloInstruction* const_false = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateBinary(r0pred, HloOpcode::kOr, + const_false, param0)); + + auto computation = m->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kOr); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); } // Used for TEST_Ps that test merging (or not) of a kPad instruction into a @@ -2266,18 +2955,18 @@ TEST_P(ConvInputPaddingTest, DoTest) { .ValueOrDie(), lhs_pad, filter, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier 
simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); if (testcase.expected_conv_window.empty()) { - ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } else { - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); - ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + ASSERT_THAT(conv, + GmockMatch(m::Convolution(m::Parameter(), m::Parameter()))); EXPECT_EQ(window_util::ToString(conv->window()), absl::StrCat("size=3x3 ", testcase.expected_conv_window)); } @@ -2384,18 +3073,18 @@ TEST_P(ConvFilterPaddingTest, DoIt) { input, rhs_pad, /*feature_group_count=*/1, window, dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); if (testcase.expected_conv_window.empty()) { - ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); } else { - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); auto* conv = module->entry_computation()->root_instruction(); SCOPED_TRACE(module->ToString()); - ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + ASSERT_THAT(conv, + GmockMatch(m::Convolution(m::Parameter(), m::Parameter()))); EXPECT_EQ(window_util::ToString(conv->window()), absl::StrFormat("size=%dx%d %s", conv->operand(1)->shape().dimensions(2), @@ -2533,11 +3222,12 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); // TODO(b/80488902): verify 
this module. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto* computation = module->AddEntryComputation(b.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, - bitcasting_callback()); + AlgebraicSimplifierOptions simplifier_options(bitcasting_callback()); + simplifier_options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(simplifier_options); if (!simplifier.Run(module.get()).ValueOrDie()) { return "NO_CHANGE"; } @@ -2653,24 +3343,22 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, slice); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. 
- ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); - - root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(scalar_param)); - EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape)); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Op().Is(scalar_param)) + .WithShapeEqualTo(&slice_shape))); } // Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a @@ -2692,26 +3380,24 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction* reshape = builder.AddInstruction( HloInstruction::CreateReshape(reshape_shape, transpose)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reshape); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); - - root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(forty_two)); - EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Broadcast(m::Op().Is(forty_two)) + .WithShapeEqualTo(&reshape_shape))); } // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. 
@@ -2766,8 +3452,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. @@ -2775,7 +3460,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Verify the result root = computation->root_instruction(); - EXPECT_THAT(root, op::ReduceWindow(operand, op::Constant())); + EXPECT_THAT(root, + GmockMatch(m::ReduceWindow(m::Op().Is(operand), m::Constant()))); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) << ShapeUtil::HumanString(root->shape()) << " vs " << ShapeUtil::HumanString(reduce_window_shape); @@ -2793,7 +3479,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2852,8 +3538,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); + AlgebraicSimplifier simplifier(default_options_); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. 
@@ -2861,7 +3546,8 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { // Verify the result root = computation->root_instruction(); - EXPECT_THAT(root, op::ReduceWindow(op::Convert(parameter), op::Constant())); + EXPECT_THAT(root, GmockMatch(m::ReduceWindow(m::Convert(m::Parameter(0)), + m::Constant()))); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) << ShapeUtil::HumanString(root->shape()) << " vs " << ShapeUtil::HumanString(reduce_window_shape); @@ -2883,12 +3569,11 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { builder.AddInstruction( HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); @@ -2899,6 +3584,7 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { // Dots add computations to the parent module. Test that, when the HloModule's // computations are updated, then iterator invalidation doesn't occur // when running on subsequent computations. 
+ auto m = CreateNewVerifiedModule(); Shape r1f32 = ShapeUtil::MakeShape(F32, {1}); HloComputation::Builder builder(TestName() + ".Dot"); HloInstruction* x = @@ -2920,15 +3606,15 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { call_builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); - module().AddEmbeddedComputation(std::move(dot_computation)); - module().AddEntryComputation(call_builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + m->AddEmbeddedComputation(std::move(dot_computation)); + m->AddEntryComputation(call_builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); } // Test that a constant with tuple shape becomes a tuple of constants. TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; @@ -2937,19 +3623,19 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::Tuple(op::Constant(), op::Constant())); + GmockMatch(m::Tuple(m::Constant(), m::Constant()))); } // A dynamic-slice is trivial if its start indices are all zeroes and the size // of its 
input equals the size of its output. In this case, the dynamic slice // is equal to its input. TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); @@ -2961,17 +3647,17 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), /*slice_sizes=*/{10, 100, 1000})); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Parameter()); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Parameter())); } // A dynamic-update-slice is trivial if its start indices are all zeroes and the // size of its "update" equals the size of its output. In this case, the // dynamic-update-slice is equal to its update. 
TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); @@ -2994,16 +3680,16 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { builder.AddInstruction(HloInstruction::CreateParameter( 3, ShapeUtil::MakeShape(U32, {3}), "update_indices")))); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Parameter(), op::Parameter())); + GmockMatch(m::DynamicSlice(m::Parameter(), m::Parameter()))); } // Test that two consecutive broadcasts can be merged to one. 
TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* input_array = builder.AddInstruction( @@ -3014,19 +3700,19 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { builder.AddInstruction( HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Constant()))); EXPECT_THAT(root->dimensions(), ElementsAre(2)); } // Test that two consecutive broadcasts can be merged to one. 
TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3}); Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); @@ -3040,19 +3726,19 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { builder.AddInstruction( HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Broadcast(op::Parameter(0))); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Parameter(0)))); EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); } // Test that a broadcast of an iota can be merged to one iota. 
TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* iota = @@ -3060,19 +3746,19 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); EXPECT_EQ(Cast(root)->iota_dimension(), 2); } // Test that a broadcast of an iota can be merged to one iota. 
TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); HloInstruction* iota = @@ -3081,17 +3767,184 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { builder.AddInstruction( HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3})); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_THAT(root, op::Iota()); + EXPECT_THAT(root, GmockMatch(m::Iota())); EXPECT_EQ(Cast(root)->iota_dimension(), 2); } +TEST_F(AlgebraicSimplifierTest, SliceOfPadLow) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[3,4] parameter(0) + constant = f32[] constant(0.0) + pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[2:3],[0:1]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfPadHigh) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[3,4] parameter(0) + constant = f32[] constant(0.0) + pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 + ROOT 
slice = f32[1,1] slice(f32[8,10] pad), slice={[6:7],[9:10]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Reshape(m::Constant()))); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfPadMidNonScalar) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[3,4] parameter(0) + constant = f32[] constant(0.0) + pad = f32[8,10] pad(f32[3,4] param, f32[] constant), padding=3_2x1_5 + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[5:6],[9:10]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfPadMidScalar) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param = f32[1,1] parameter(0) + constant = f32[] constant(0.0) + pad = f32[8,10] pad(f32[1,1] param, f32[] constant), padding=3_4x4_5 + ROOT slice = f32[1,1] slice(f32[8,10] pad), slice={[3:4],[4:5]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Parameter())); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfConcatScalarInput) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = f32[2] parameter(0) + param.1 = f32[1] parameter(1) + param.2 = f32[3] parameter(2) + concat = f32[6] 
concatenate(param.0, param.1, param.2), dimensions={0} + ROOT slice = f32[1] slice(concat), slice={[2:3]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Parameter(1))); +} + +TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = f32[2] parameter(0) + param.1 = f32[1] parameter(1) + param.2 = f32[3] parameter(2) + concat = f32[6] concatenate(param.0, param.1, param.2), dimensions={0} + ROOT slice = f32[1] slice(concat), slice={[4:5]} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Slice(m::Parameter(2)))); + EXPECT_EQ(root->slice_starts(0), 1); + EXPECT_EQ(root->slice_limits(0), 2); +} + +TEST_F(AlgebraicSimplifierTest, NegateNegate) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + param.0 = f32[2] parameter(0) + neg.0 = f32[2] negate(param.0) + ROOT neg.1 = f32[2] negate(neg.0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Parameter(0))); +} + +TEST_F(AlgebraicSimplifierTest, NotNot) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { 
+ param.0 = pred[2] parameter(0) + not.0 = pred[2] not(param.0) + ROOT not.1 = pred[2] not(not.0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options(bitcasting_callback()); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Parameter(0))); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector input_spatials; std::vector symmetric_pad_spatials; @@ -3121,6 +3974,7 @@ class PadReduceWindowEffectiveBroadcastTest PadReduceWindowEffectiveBroadcastCase> {}; TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { + auto m = CreateNewVerifiedModule(); const auto& param = GetParam(); // a and b are parallel bounds we can either turn into a B F S0 S1 or @@ -3169,7 +4023,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module().AddEmbeddedComputation(builder.Build()); + add_computation = m->AddEmbeddedComputation(builder.Build()); } Window window = window_util::MakeWindow( @@ -3183,20 +4037,19 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { builder.AddInstruction(HloInstruction::CreateReduceWindow( output_shape, pad, zero, window, add_computation)); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( 
ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape)); if (param.should_become_broadcast) { - EXPECT_THAT(computation->root_instruction(), op::Broadcast(::testing::_)); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Broadcast())); } else { EXPECT_THAT(computation->root_instruction(), - op::ReduceWindow(::testing::_, zero)); + GmockMatch(m::ReduceWindow(m::Op(), m::Op().Is(zero)))); } } @@ -3235,6 +4088,7 @@ class DotStrengthReductionTest public ::testing::WithParamInterface< ::testing::tuple> {}; TEST_P(DotStrengthReductionTest, DotStrengthReduction) { + auto module = CreateNewVerifiedModule(); int m, k, n; bool transpose_lhs, transpose_rhs; PrimitiveType element_type; @@ -3264,10 +4118,9 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { dot_dnums.add_rhs_contracting_dimensions(0); builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(&module())); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; const bool computation_should_be_modified = dot_should_be_transformed || (transpose_lhs && transpose_rhs); @@ -3295,7 +4148,7 @@ struct DotOfConcatTestSpec { }; class DotOfConcatSimplificationTest - : public HloVerifiedTestBase, + : public AlgebraicSimplifierTest, public ::testing::WithParamInterface {}; // Test that we transform @@ -3303,6 +4156,7 @@ class DotOfConcatSimplificationTest // to // add(dot(const_0, A), dot(const_1, B), dot(const_2, C)) TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { + auto m = CreateNewVerifiedModule(); 
HloComputation::Builder builder(TestName()); DotOfConcatTestSpec spec = GetParam(); @@ -3341,20 +4195,20 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); - auto match_dot_0 = op::Dot(op::Slice(op::Constant()), op::Parameter(0)); - auto match_dot_1 = op::Dot(op::Slice(op::Constant()), op::Parameter(1)); - auto match_dot_2 = op::Dot(op::Slice(op::Constant()), op::Parameter(2)); - EXPECT_THAT(computation->root_instruction(), - op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2)); + auto match_dot_0 = m::Dot(m::Slice(m::Constant()), m::Parameter(0)); + auto match_dot_1 = m::Dot(m::Slice(m::Constant()), m::Parameter(1)); + auto match_dot_2 = m::Dot(m::Slice(m::Constant()), m::Parameter(2)); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2))); } // Test that we transform @@ -3362,6 +4216,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { // to // add(dot(A, const_0), dot(B, const_1), dot(C, const_2)) TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfConcatTestSpec spec = GetParam(); @@ -3405,21 +4260,21 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, 
DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); - auto match_dot_0 = op::Dot(op::Parameter(0), op::Slice(op::Constant())); - auto match_dot_1 = op::Dot(op::Parameter(1), op::Slice(op::Constant())); - auto match_dot_2 = op::Dot(op::Parameter(2), op::Slice(op::Constant())); - auto match_dot_3 = op::Dot(op::Parameter(3), op::Slice(op::Constant())); - EXPECT_THAT(computation->root_instruction(), - op::Add(op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2), - match_dot_3)); + auto match_dot_0 = m::Dot(m::Parameter(0), m::Slice(m::Constant())); + auto match_dot_1 = m::Dot(m::Parameter(1), m::Slice(m::Constant())); + auto match_dot_2 = m::Dot(m::Parameter(2), m::Slice(m::Constant())); + auto match_dot_3 = m::Dot(m::Parameter(3), m::Slice(m::Constant())); + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Add(m::Add(m::Add(match_dot_0, match_dot_1), match_dot_2), + match_dot_3))); } DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { @@ -3433,6 +4288,7 @@ DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { // Test that DynamicUpdateSlice update param with any dimension equal to zero // gets removed. 
TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); const Shape dslice_shape = ShapeUtil::MakeShape(F32, {10}); HloInstruction* const operand = builder.AddInstruction( @@ -3445,11 +4301,10 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( dslice_shape, operand, update, start_indices)); const HloComputation* const computation = - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), operand); } @@ -3468,7 +4323,7 @@ struct DotOfGatherTestSpec { }; class DotOfGatherSimplificationTest - : public HloVerifiedTestBase, + : public AlgebraicSimplifierTest, public ::testing::WithParamInterface {}; // input: dot(DS(ctA), ctB)) @@ -3477,6 +4332,7 @@ class DotOfGatherSimplificationTest // output: DS(dot(ctA, ctB)) // => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}. 
TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfGatherTestSpec spec = GetParam(); @@ -3523,10 +4379,9 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); @@ -3536,8 +4391,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { HloOpcode::kDynamicSlice); } else { EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), - op::Concatenate())); + GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), + m::Concatenate()))); } } @@ -3547,6 +4402,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { // output: DS(dot(ctA, ctB)) // => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}. 
TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); DotOfGatherTestSpec spec = GetParam(); @@ -3593,10 +4449,9 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2))); - auto computation = module().AddEntryComputation(builder.Build()); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + auto computation = m->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(default_options_); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(m.get())); ASSERT_TRUE(run_successful); EXPECT_TRUE( ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); @@ -3606,8 +4461,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { HloOpcode::kDynamicSlice); } else { EXPECT_THAT(computation->root_instruction(), - op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), - op::Concatenate())); + GmockMatch(m::DynamicSlice(m::Dot(m::Constant(), m::Constant()), + m::Concatenate()))); } } diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 1ed6142dcecdc830cb7b8386e0cc20a2ea54aa7f..ef5e211646e7b0b66b8e6c09948be58063422943 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -176,13 +176,13 @@ StatusOr> AllocationTracker::DeconstructTuple( } StatusOr> AllocationTracker::Resolve( - const GlobalDataHandle& data) { + const GlobalDataHandle& data) const { tensorflow::mutex_lock lock(mutex_); return AllocationTracker::ResolveInternal(data); } StatusOr AllocationTracker::ResolveForReplica( - const GlobalDataHandle& data, int replica_id) { + const GlobalDataHandle& data, int replica_id) const { 
tensorflow::mutex_lock lock(mutex_); TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, ResolveInternal(data)); @@ -196,7 +196,7 @@ StatusOr AllocationTracker::ResolveForReplica( } StatusOr> AllocationTracker::ResolveInternal( - const GlobalDataHandle& data) { + const GlobalDataHandle& data) const { VLOG(2) << "resolve:" << data.handle(); auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index 43feccee3c67152c6f61098bb98d546379848b8c..98d1a302a9f66f4a00e05d62837a79133e222687 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -65,13 +65,13 @@ class AllocationTracker { // replica, or provide an error status to say whether any of those buffers // were not found (or found, but found deallocated). StatusOr> Resolve( - const GlobalDataHandle& data); + const GlobalDataHandle& data) const; // Resolves a handle from an XLA client and replica id to a shaped buffer, or // provide an error status to say whether it was not found (or found, but // found deallocated). StatusOr ResolveForReplica(const GlobalDataHandle& data, - int replica_id); + int replica_id) const; private: // Data structure encapsulating single memory allocation on the device. @@ -87,7 +87,7 @@ class AllocationTracker { // Internal helper which resolves the given GlobalDataHandle to a // list of ScopedShapedBuffers. StatusOr> ResolveInternal( - const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + const GlobalDataHandle& data) const EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Internal helper which registers a vector of shaped buffers, one per // replica. ShapedBufferTy is either ScopedShapedBuffer or ShapedBuffer. If @@ -113,7 +113,7 @@ class AllocationTracker { // maintained per device ordinal. 
using AllocationMap = absl::flat_hash_map; - tensorflow::mutex mutex_; + mutable tensorflow::mutex mutex_; // Backend to use with this tracker. The backend supplies the memory allocator // to use when deallocating memory. diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc new file mode 100644 index 0000000000000000000000000000000000000000..362bc44a1cf377b51c5519c6ab5e0d9628e80e58 --- /dev/null +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc @@ -0,0 +1,285 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/ar_crs_combiner.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +namespace { + +namespace m = match; + +// If the argument instruction is a CRS in the sequence +// AR -> Convert -> Add -> CRS +// then return the AR in the sequence. +// TODO(b/117554291): Rewrite this to recognize more general patterns, +// not just the specific one of AR -> Add -> Convert -> CRS. 
+absl::optional MatchesArCrsPattern( + HloInstruction* instruction) { + HloInstruction *ar, *convert, *add, *crs; + if (Match(instruction, + m::CrossReplicaSum( + &crs, m::Add(&add, m::Op(), + m::Convert(&convert, + m::CrossReplicaSum(&ar, m::Op()))))) && + ar->users().size() == 1 && ar->shape().element_type() == BF16 && + convert->shape().element_type() == F32 && !crs->all_reduce_id()) { + return ar; + } + return absl::optional(); +} + +} // namespace + +absl::optional ArCrsCombiner::WhileFromBodyParameter( + HloInstruction* instruction) { + CHECK_EQ(HloOpcode::kParameter, instruction->opcode()); + HloComputation* computation = instruction->parent(); + auto caller_instructions = call_graph_->GetComputationCallers(computation); + if (caller_instructions.size() == 1) { + auto caller_instruction = caller_instructions[0]; + if (caller_instruction->opcode() == HloOpcode::kWhile) { + return caller_instruction; + } + } + return absl::optional(); +} + +std::vector ArCrsCombiner::GetAllTuples( + HloInstruction* instruction) { + if (instruction->opcode() == HloOpcode::kTuple) { + return {instruction}; + } + if (instruction->opcode() == HloOpcode::kDomain) { + return GetAllTuples(instruction->operands()[0]); + } + if (instruction->opcode() == HloOpcode::kParameter) { + auto maybe_while = WhileFromBodyParameter(instruction); + if (!maybe_while) { + return {}; + } + auto while_instr = *maybe_while; + auto init_tuples = GetAllTuples(while_instr->while_init()); + auto body_tuples = + GetAllTuples(while_instr->while_body()->root_instruction()); + if (init_tuples.empty() || body_tuples.empty()) { + return {}; + } + init_tuples.insert(init_tuples.end(), body_tuples.begin(), + body_tuples.end()); + return init_tuples; + } + if (instruction->opcode() == HloOpcode::kGetTupleElement) { + std::vector result_tuples; + for (auto tuple : GetAllTuples(instruction->operands()[0])) { + auto tmp_tuples = + GetAllTuples(tuple->mutable_operand(instruction->tuple_index())); + if 
(tmp_tuples.empty()) { + return {}; + } + result_tuples.insert(result_tuples.end(), tmp_tuples.begin(), + tmp_tuples.end()); + } + return result_tuples; + } + return {}; +} + +bool ArCrsCombiner::TupleElementsComputeSameValue( + HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2, + absl::flat_hash_map* visited_pairs) { + auto tuples = GetAllTuples(tuple_shaped_instruction); + if (tuples.empty()) { + return false; + } + for (auto tuple : tuples) { + CHECK_EQ(tuple->opcode(), HloOpcode::kTuple); + if (!InstructionsComputeSameValue(tuple->mutable_operand(i1), + tuple->mutable_operand(i2), + visited_pairs)) { + return false; + } + } + return true; +} + +/* static */ +bool ArCrsCombiner::TestInstructionsComputeSameValue(HloInstruction* i1, + HloInstruction* i2) { + ArCrsCombiner combiner(/*num_spatial_partitions=*/2); + auto module = i1->parent()->parent(); + CHECK_EQ(module, i2->parent()->parent()); + combiner.call_graph_ = CallGraph::Build(module); + absl::flat_hash_map visited_pairs; + return combiner.InstructionsComputeSameValue(i1, i2, &visited_pairs); +} + +bool ArCrsCombiner::InstructionsComputeSameValue( + HloInstruction* i1, HloInstruction* i2, + absl::flat_hash_map* visited_pairs) { + if (i1 == i2) { + return true; + } + auto uid1 = i1->unique_id(); + auto uid2 = i2->unique_id(); + auto min_uid = std::min(uid1, uid2); + auto max_uid = std::max(uid1, uid2); + auto it = visited_pairs->find(min_uid); + if (it != visited_pairs->end() && max_uid == it->second) { + return true; + } + auto opcode1 = i1->opcode(); + auto operands1 = i1->operands(); + if (opcode1 != i2->opcode() || operands1.size() != i2->operands().size()) { + return false; + } + visited_pairs->emplace(min_uid, max_uid); + for (int i = 0; i < operands1.size(); ++i) { + auto operand1 = operands1[i]; + auto operand2 = i2->operands()[i]; + if (!InstructionsComputeSameValue(operand1, operand2, visited_pairs)) { + return false; + } + } + if (opcode1 == HloOpcode::kParameter) { + // In the general 
case, we don't try to prove equality of parameters. + // We only try in the context of get-tuple-element + // (see TupleElementsComputeSameValue). + return false; + } + if (opcode1 == HloOpcode::kGetTupleElement) { + return i1->tuple_index() == i2->tuple_index() || + TupleElementsComputeSameValue(operands1[0], i1->tuple_index(), + i2->tuple_index(), visited_pairs); + } + // Don't check that the operands are identical, because Identical can + // return false for instructions that compute the same value but are not + // identical, which we don't want. We have checked the arguments with + // InstructionsComputeSameValue earlier. + auto eq_instructions = [](const HloInstruction* i1, + const HloInstruction* i2) -> bool { return true; }; + auto eq_computations = [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }; + return i1->Identical(*i2, eq_instructions, eq_computations, + /*layout_sensitive=*/false); +} + +void ArCrsCombiner::GroupAllReducesById(HloModule* module) { + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instruction : computation->instructions()) { + auto ar = MatchesArCrsPattern(instruction); + if (ar) { + all_reduce_map_[*((*ar)->all_reduce_id())].push_back(*ar); + } + } + } +} + +void ArCrsCombiner::KeepProvablyEqualInstructionGroups() { + for (auto it : all_reduce_map_) { + auto instruction_vec = it.second; + CHECK_EQ(instruction_vec.size(), num_spatial_partitions_); + + auto instr_0 = instruction_vec[0]; + auto add_0 = instr_0->users()[0]->users()[0]; + CHECK_EQ(HloOpcode::kAdd, add_0->opcode()); + + for (int i = 1; i < instruction_vec.size(); ++i) { + auto instr_i = instruction_vec[i]; + auto add_i = instr_i->users()[0]->users()[0]; + CHECK_EQ(HloOpcode::kAdd, add_i->opcode()); + absl::flat_hash_map visited_pairs; + if (!InstructionsComputeSameValue(add_0, add_i, &visited_pairs)) { + all_reduce_map_.erase(it.first); + } + } + } +} + +StatusOr ArCrsCombiner::RewriteGraph() { + 
if (all_reduce_map_.empty()) { + return false; + } + + auto computation_is_addition = [](HloComputation* c) { + return c->instruction_count() == 3 && + Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter())); + }; + + for (auto it : all_reduce_map_) { + auto instruction_vec = it.second; + for (auto all_reduce : instruction_vec) { + auto parent_computation = all_reduce->parent(); + auto convert = all_reduce->users()[0]; + auto add = convert->users()[0]; + auto crs = add->users()[0]; + + if (!computation_is_addition(all_reduce->called_computations()[0]) || + !computation_is_addition(crs->called_computations()[0])) { + continue; + } + HloInstruction* other_summand = (add->operands()[0] == convert) + ? add->operands()[1] + : add->operands()[0]; + // To move the AR past the addition, we need to divide other_summand by + // the number of spatial partitions. + CHECK_EQ(all_reduce->user_count(), 1); + TF_CHECK_OK( + all_reduce->ReplaceAllUsesWith(all_reduce->mutable_operand(0))); + auto shape = other_summand->shape(); + Literal lit(shape); + lit.PopulateWithValue(num_spatial_partitions_); + auto divisor = parent_computation->AddInstruction( + HloInstruction::CreateConstant(lit.Clone())); + auto division = + parent_computation->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDivide, other_summand, divisor)); + TF_CHECK_OK(other_summand->ReplaceUseWith(add, division)); + // The AllReduce and the CRS are combined to an all-core AllReduce. 
+ crs->set_all_reduce_id(all_reduce->all_reduce_id()); + TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce)); + } + } + + return true; +} + +StatusOr ArCrsCombiner::Run(HloModule* module) { + call_graph_ = CallGraph::Build(module); + + GroupAllReducesById(module); + + KeepProvablyEqualInstructionGroups(); + + return RewriteGraph(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.h b/tensorflow/compiler/xla/service/ar_crs_combiner.h new file mode 100644 index 0000000000000000000000000000000000000000..f6a7ef76ec3b76972d1b2c7fb548cecfb9423160 --- /dev/null +++ b/tensorflow/compiler/xla/service/ar_crs_combiner.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Combine an AllReduce and a CrossReplicaSum when they are close to each other +// in the graph, to use an efficient CrossReplicaSum implementation that +// fully utilizes the interconnect bandwidth. +class ArCrsCombiner : public HloModulePass { + public: + explicit ArCrsCombiner(int num_spatial_partitions) + : num_spatial_partitions_(num_spatial_partitions) {} + absl::string_view name() const override { return "ar-crs-combiner"; } + StatusOr Run(HloModule* module) override; + + // Helper method to allow testing of InstructionsComputeSameValue. + static bool TestInstructionsComputeSameValue(HloInstruction* i1, + HloInstruction* i2); + + private: + // If the passed instruction is a while parameter, and the while body is only + // called by a single while instruction, return the while instruction. + absl::optional WhileFromBodyParameter( + HloInstruction* instruction); + + // Returns a vector of tuple instructions. + // If all instructions that flow to "instruction" are tuples, return them. + // Otherwise, return an empty vector. + std::vector GetAllTuples(HloInstruction* instruction); + + // Checks whether two different elements in the same tuple compute the same + // value. + bool TupleElementsComputeSameValue( + HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2, + absl::flat_hash_map* visited_pairs); + + // Returns whether the instructions i1 and i2 can be shown to evaluate to the + // same value. 
Handling WHILE requires recursion, which may cause us to visit + // the same instruction again. To avoid infinite loops, we pass a cache of + // visited instruction pairs. + bool InstructionsComputeSameValue( + HloInstruction* i1, HloInstruction* i2, + absl::flat_hash_map* visited_pairs); + + // Populates all_reduce_map_. + void GroupAllReducesById(HloModule* module); + + // Looks at each AllReduce group in all_reduce_map_, and keeps only the + // groups for which it's safe to move the AllReduce later in the HLO graph. + void KeepProvablyEqualInstructionGroups(); + + // Performs the graph rewrite that eliminates the early AllReduce and turns + // the later CRS into an AllReduce. + StatusOr RewriteGraph(); + + int num_spatial_partitions_; + + // Map from all-reduce ids to the all reduce instructions. + absl::flat_hash_map> all_reduce_map_; + + std::unique_ptr call_graph_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_ diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..10171835d83c75fef091a34b8fe102d263211307 --- /dev/null +++ b/tensorflow/compiler/xla/service/ar_crs_combiner_test.cc @@ -0,0 +1,496 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/ar_crs_combiner.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class ArCrsCombinerTest : public HloTestBase {}; + +TEST_F(ArCrsCombinerTest, SameValueTestBasecase) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue( + i1, module->entry_computation()->parameter_instruction(0))); + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestBasecase2) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (x: f32[]) -> (f32[], f32[]) { + %x = f32[] parameter(0) + ROOT %tuple = (f32[], f32[]) tuple(%x, %x) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestBasecase3) { + const char* module_str = R"( 
+HloModule foobar + +ENTRY %entrycomp (x: f32[], y: f32[]) -> (f32[], f32[]) { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %tuple = (f32[], f32[]) tuple(%x, %y) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestNumOperands) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> ((f32[2,2]), (f32[2,2], f32[2,2])) { + %p = f32[2,2] parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %tuple1 = (f32[2,2]) tuple(%constant.f32) + %tuple2 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + ROOT %tuple = ((f32[2,2]), (f32[2,2], f32[2,2])) tuple(%tuple1, %tuple2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesMatch) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) { + %p = f32[2] parameter(0) + %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]} + %slice.2 = f32[1] slice(f32[2] %p), slice={[0:1]} + ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + 
+TEST_F(ArCrsCombinerTest, SameValueTestSliceIndicesDontMatch) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2]) -> (f32[1], f32[1]) { + %p = f32[2] parameter(0) + %slice.1 = f32[1] slice(f32[2] %p), slice={[0:1]} + %slice.2 = f32[1] slice(f32[2] %p), slice={[1:2]} + ROOT %tuple = (f32[1], f32[1]) tuple(%slice.1, %slice.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestTupleElementSameIndex) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=0 + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex1) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 + 
%get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestTupleElementDifferentIndex2) { + const char* module_str = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{2, 3}, {4, 5}}) + %tuple.1 = (f32[2,2], f32[2,2]) tuple(%constant.f32.1, %constant.f32.2) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%tuple.1), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%tuple.1), index=1 + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%get-tuple-element.1, %get-tuple-element.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_tuple = module->entry_computation()->root_instruction(); + auto i1 = root_tuple->operands()[0]; + auto i2 = root_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestWhile1) { + const char* module_str = R"( +HloModule foobar + +%condition (x: (f32[2,2], f32[2,2])) -> pred[] { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.0 = s32[] constant(0) + %constant.1 = s32[] constant(1) + ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) +} + +%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %get-tuple-element.1 = f32[2,2] 
get-tuple-element(%x), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 + %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) + %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2) +} + +ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { + %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_while = module->entry_computation()->root_instruction(); + auto body_tuple = root_while->while_body()->root_instruction(); + auto i1 = body_tuple->operands()[0]; + auto i2 = body_tuple->operands()[1]; + EXPECT_TRUE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestWhile2) { + const char* module_str = R"( +HloModule foobar + +%condition (x: (f32[2,2], f32[2,2])) -> pred[] { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.0 = s32[] constant(0) + %constant.1 = s32[] constant(1) + ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) +} + +%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 + %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32) + %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2) +} + +ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {7, 8}}) + %init.tuple = (f32[2,2], f32[2,2]) 
tuple(%constant.f32.1, %constant.f32.2) + ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_while = module->entry_computation()->root_instruction(); + auto body_tuple = root_while->while_body()->root_instruction(); + auto i1 = body_tuple->operands()[0]; + auto i2 = body_tuple->operands()[1]; + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, SameValueTestWhile3) { + const char* module_str = R"( +HloModule foobar + +%condition (x: (f32[2,2], f32[2,2])) -> pred[] { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.0 = s32[] constant(0) + %constant.1 = s32[] constant(1) + ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %constant.0) +} + +%body (x: (f32[2,2], f32[2,2])) -> (f32[2,2], f32[2,2]) { + %x = (f32[2,2], f32[2,2]) parameter(0) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {1, 2}}) + %get-tuple-element.1 = f32[2,2] get-tuple-element(%x), index=0 + %get-tuple-element.2 = f32[2,2] get-tuple-element(%x), index=1 + %add.1 = f32[2,2] add(%get-tuple-element.1, %constant.f32.1) + %add.2 = f32[2,2] add(%get-tuple-element.2, %constant.f32.2) + ROOT %tuple = (f32[2,2], f32[2,2]) tuple(%add.1, %add.2) +} + +ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) { + %constant.f32 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + %init.tuple = (f32[2,2], f32[2,2]) tuple(%constant.f32, %constant.f32) + ROOT %while = (f32[2,2], f32[2,2]) while(%init.tuple), condition=%condition, body=%body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + auto root_while = module->entry_computation()->root_instruction(); + auto body_tuple = root_while->while_body()->root_instruction(); + auto i1 = body_tuple->operands()[0]->operands()[0]; // 
%get-tuple-element.1 + auto i2 = body_tuple->operands()[1]->operands()[0]; // %get-tuple-element.2 + EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2)); +} + +TEST_F(ArCrsCombinerTest, RewritePatternArConvertAddCrs) { + const char* module_str = R"( +HloModule foobar + +%binary_add (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) + %constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + + %cross-replica-sum.ar.1 = bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=0} + %convert.1 = f32[2,2] + convert(%cross-replica-sum.ar.1), + sharding={maximal device=0} + %add.1 = f32[2,2] + add(%constant.f32, %convert.1), + sharding={maximal device=0} + %cross-replica-sum.1 = f32[2,2] + cross-replica-sum(%add.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %cross-replica-sum.ar.2 = bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=1} + %convert.2 = f32[2,2] + convert(%cross-replica-sum.ar.2), + sharding={maximal device=1} + %add.2 = f32[2,2] + add(%constant.f32, %convert.2), + sharding={maximal device=1} + %cross-replica-sum.2 = f32[2,2] + cross-replica-sum(%add.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[2,2], f32[2,2]) + tuple(%cross-replica-sum.1, %cross-replica-sum.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + 
ParseAndReturnVerifiedModule(module_str)); + auto crs_before = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_before = crs_before->replica_groups(); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple( + op::CrossReplicaSum(op::Add( + op::Divide(op::Constant(), op::Constant()), op::Convert())), + op::CrossReplicaSum(op::Add( + op::Divide(op::Constant(), op::Constant()), op::Convert())))); + auto crs_after = + module->entry_computation()->root_instruction()->operands()[0]; + auto replica_groups_after = crs_after->replica_groups(); + ASSERT_EQ(replica_groups_before.size(), replica_groups_after.size()); + for (int i = 0; i < replica_groups_before.size(); ++i) { + // Somewhat verbose way to compare the replica_ids, because EqualsProto + // is not available in the open-source build. + auto group_before = replica_groups_before[i]; + std::vector ids_before(group_before.replica_ids().begin(), + group_before.replica_ids().end()); + auto group_after = replica_groups_after[i]; + std::vector ids_after(group_after.replica_ids().begin(), + group_after.replica_ids().end()); + EXPECT_EQ(ids_before, ids_after); + } +} + +TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) { + const char* module_str = R"( +HloModule foobar + +%binary_add (a: bf16[], b: bf16[]) -> bf16[] { + %a = bf16[] parameter(0) + %b = bf16[] parameter(1) + ROOT %add = bf16[] add(%a, %b) +} + +%sum.f32 (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) { + %p = f32[2,2] parameter(0) + %constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}}) + %constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + %constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}}) + + %cross-replica-sum.ar.1 = 
bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=0} + %convert.1 = f32[2,2] + convert(%cross-replica-sum.ar.1), + sharding={maximal device=0} + %add.1 = f32[2,2] + add(%constant.f32.1, %convert.1), + sharding={maximal device=0} + %cross-replica-sum.1 = f32[2,2] + cross-replica-sum(%add.1), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=0} + + %cross-replica-sum.ar.2 = bf16[2,2] + cross-replica-sum(%constant.bf16), + replica_groups={{0},{1}}, + all_reduce_id=1, + to_apply=%binary_add, + sharding={maximal device=1} + %convert.2 = f32[2,2] + convert(%cross-replica-sum.ar.2), + sharding={maximal device=1} + %add.2 = f32[2,2] + add(%constant.f32.2, %convert.2), + sharding={maximal device=1} + %cross-replica-sum.2 = f32[2,2] + cross-replica-sum(%add.2), + replica_groups={{0,1}}, + to_apply=%sum.f32, + sharding={maximal device=1} + + ROOT %tuple = (f32[2,2], f32[2,2]) + tuple(%cross-replica-sum.1, %cross-replica-sum.2), + sharding={{maximal device=0}, {maximal device=1}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + ArCrsCombiner combiner(2); + auto changed = combiner.Run(module.get()).ValueOrDie(); + EXPECT_FALSE(changed); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index 38f1a5d3a645f98220ec445bb9bbdf2b9b842109..52ec1a794c5e9f4452a4bf2b648f453d8acfe976 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -17,14 +17,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" namespace xla { namespace { namespace op = xla::testing::opcode_matchers; -class BatchDotSimplificationTest : public HloVerifiedTestBase {}; +class BatchDotSimplificationTest : public HloTestBase {}; TEST_F(BatchDotSimplificationTest, ElideSingleDegenerateBatchDotDim_VectorVector) { @@ -38,11 +37,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -61,11 +61,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -84,11 +85,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, 
op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -107,11 +109,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -130,11 +133,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), @@ -153,11 +157,12 @@ main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); BatchDotSimplification pass; - ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + ASSERT_TRUE(pass.Run(m.get()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Reshape(op::Dot( op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index f70f6ddfec69c0113a1afe2073a2392098f49456..0e6ca1871b379a2f55b92207133822fc6258b007 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ 
b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -107,19 +107,37 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { } std::unique_ptr Mean( - int64 element_count, HloInstruction* operand, + HloInstruction* element_count, HloInstruction* operand, const std::function)>& add_instruction) { - HloInstruction* elem_count_recip = - add_instruction(HloInstruction::CreateBroadcast( - operand->shape(), - add_instruction(HloInstruction::CreateConvert( - ShapeUtil::MakeShape(operand->shape().element_type(), {}), - add_instruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(1.0 / element_count))))), - {})); - return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply, - operand, elem_count_recip); + auto broadcast = add_instruction( + HloInstruction::CreateBroadcast(operand->shape(), element_count, {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kDivide, + operand, broadcast); + } + + std::unique_ptr DynamicElementCountPerFeature( + HloInstruction* operand, int64 feature_index, + const std::function)>& + add_instruction) { + auto elements_per_feature_u32 = add_instruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + + for (int64 i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + if (i == feature_index) { + continue; + } + auto dynamic_dimension_size = + add_instruction(HloInstruction::CreateGetDimensionSize( + ShapeUtil::MakeShape(U32, {}), operand, i)); + elements_per_feature_u32 = add_instruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(U32, {}), HloOpcode::kMultiply, + dynamic_dimension_size, elements_per_feature_u32)); + } + + return HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + elements_per_feature_u32); } // Replaces the existing HLO instruction old_instruction, with @@ -195,9 +213,6 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( const Shape operand_shape = operand->shape(); PrimitiveType ptype = 
operand_shape.element_type(); int64 feature_index = batch_norm->feature_index(); - const int64 feature_count = operand_shape.dimensions(feature_index); - const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); - int64 elements_per_feature_int64 = size_in_elements / feature_count; HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); @@ -220,6 +235,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( } } + auto elements_per_feature = + add(DynamicElementCountPerFeature(operand, feature_index, add)); + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); @@ -243,13 +261,13 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( add_reduce_computation)); // E[X]. - auto mean = add(Mean(elements_per_feature_int64, sum, add)); + auto mean = add(Mean(elements_per_feature, sum, add)); auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = add(Mean(elements_per_feature_int64, squared_sum, add)); + auto square_mean = add(Mean(elements_per_feature, squared_sum, add)); // E^2[X]. 
auto mean_square = @@ -458,9 +476,8 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( int64 feature_index = batch_norm->feature_index(); - const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape); - const int64 feature_count = activation_shape.dimensions(feature_index); - const int64 elements_per_feature_int64 = size_in_elements / feature_count; + auto elements_per_feature = + add(DynamicElementCountPerFeature(activation, feature_index, add)); auto zero_literal = LiteralUtil::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); @@ -553,15 +570,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted, rsqrt_var_add_epsilon_broadcasted); - scale_times_rsqrt_var_add_epsilon = add( - Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add)); + scale_times_rsqrt_var_add_epsilon = + add(Mean(elements_per_feature, scale_times_rsqrt_var_add_epsilon, add)); - auto elements_per_feature_literal = - LiteralUtil::CreateR0(elements_per_feature_int64); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal.Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, add(HloInstruction::CreateBroadcast( activation_shape, elements_per_feature, {}))); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index f7ac8f5482908af104554a1cf812370b9098cda7..8e8fbbd935b154e5a77d68e60d861601d740bf03 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -29,14 +29,28 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { -using BatchNormExpanderTest = HloVerifiedTestBase; +class BatchNormExpanderTest : public HloTestBase { + protected: + // BatchNorm should have a dynamic sized dividor for mean operations. + int64 CountGetDimensionSize(const HloModule& module) { + int64 count = 0; + for (HloComputation* comp : module.computations()) { + for (HloInstruction* inst : comp->instructions()) { + if (inst->opcode() == HloOpcode::kGetDimensionSize) { + count++; + } + } + } + return count; + } +}; // Test that we expand BatchNormTraining. TEST_F(BatchNormExpanderTest, BatchNormTraining) { @@ -59,15 +73,16 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) { param0, param1, param2, /*epsilon=*/0.001, /*feature_index=*/3)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); + EXPECT_EQ(CountGetDimensionSize(*module), 3); // Make sure this operation is expanded. 
EXPECT_EQ(root->opcode(), HloOpcode::kTuple); } @@ -101,15 +116,16 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { param1, param2, param3, param4, /*epsilon=*/0.001, /*feature_index=*/3)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); + EXPECT_EQ(CountGetDimensionSize(*module), 3); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); } @@ -126,13 +142,13 @@ ENTRY entry { epsilon=0.001, feature_index=1, sharding={maximal device=1} })"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(m.get()).ValueOrDie()); - for (auto* instruction : module().entry_computation()->instructions()) { + for (auto* instruction : m->entry_computation()->instructions()) { if (instruction->opcode() == HloOpcode::kParameter) { continue; } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index d63287539dfde5bb4890ab8303ef2205133d8125..e9d30fc03c1c3194de577e6683b36a95641694d9 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -151,15 +151,10 @@ Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( Status 
BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { // Do not fold BF16 conversions for instructions related to tuples, entry and - // exit of a computation, fusion, convert, and control flow. + // exit of a computation, fusion, convert, side-effecting instructions and + // control flow. if (hlo->opcode() == HloOpcode::kTuple || // hlo->opcode() == HloOpcode::kGetTupleElement || // - hlo->opcode() == HloOpcode::kInfeed || // - hlo->opcode() == HloOpcode::kOutfeed || // - hlo->opcode() == HloOpcode::kSend || // - hlo->opcode() == HloOpcode::kSendDone || // - hlo->opcode() == HloOpcode::kRecv || // - hlo->opcode() == HloOpcode::kRecvDone || // hlo->opcode() == HloOpcode::kConstant || // hlo->opcode() == HloOpcode::kParameter || // hlo->opcode() == HloOpcode::kFusion || // @@ -167,7 +162,8 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kCall || // hlo->opcode() == HloOpcode::kCustomCall || // hlo->opcode() == HloOpcode::kWhile || // - hlo->opcode() == HloOpcode::kConditional) { + hlo->opcode() == HloOpcode::kConditional || // + hlo->HasSideEffectNoRecurse()) { return Status::OK(); } if (hlo == computation_->root_instruction() && @@ -182,6 +178,10 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( HloInstruction* crs) { + if (crs->IsCrossModuleAllReduce()) { + // Cross-module all-reduce has side effect. + return Status::OK(); + } // First use DefaultAction() to handle the operands. It can't handle // tuple-shaped output. 
TF_RETURN_IF_ERROR(DefaultAction(crs)); diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 5f93740887aa7e61458990992fe0573883ff056d..4ce351acc2c359773e618da70360c96faf5ca379 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -65,11 +65,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16ConversionFoldingTest : public HloVerifiedTestBase { +class BFloat16ConversionFoldingTest : public HloTestBase { protected: BFloat16ConversionFoldingTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} bool FoldConversions(HloModule* module) { TestBFloat16Support bfloat16_support_; @@ -103,10 +103,10 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c)); builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module)); + EXPECT_TRUE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -138,10 +138,10 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { HloInstruction* convert2 
= builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module)); + EXPECT_FALSE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(mul0->shape().element_type(), F32); @@ -173,10 +173,10 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) { HloInstruction* convert2 = builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module)); + EXPECT_FALSE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(sub0->shape().element_type(), F32); @@ -203,10 +203,10 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { HloInstruction* convert1 = builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module)); + EXPECT_FALSE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), convert1); EXPECT_EQ(gte->shape().element_type(), F32); @@ -216,7 +216,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto builder = HloComputation::Builder(TestName()); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder sum_builder("add"); auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -252,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto 
computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module)); + EXPECT_TRUE(FoldConversions(module.get())); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_EQ(tuple->operand(0), gte_a); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index d5b1148058898596bfdb837826a590bbc74e202a..b8a8f844eff17a95d4073f53495e0027c481f558 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -231,6 +231,10 @@ Status BFloat16NormalizationVisitor::HandleMultipleOutputs( for (auto* user : materialized_users) { TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, tuple)); } + bool is_root = computation_->root_instruction() == hlo; + if (is_root) { + computation_->set_root_instruction(tuple); + } *tuple->mutable_shape() = original_shape; return Status::OK(); } @@ -342,11 +346,9 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { // Do not change instructions related to entry and exit of a computation, - // tuples, fusion, convert, and control flow. + // tuples, fusion, convert, side-effecting instructions, and control flow. 
if (hlo->opcode() == HloOpcode::kTuple || // hlo->opcode() == HloOpcode::kGetTupleElement || // - hlo->opcode() == HloOpcode::kInfeed || // - hlo->opcode() == HloOpcode::kOutfeed || // hlo->opcode() == HloOpcode::kConstant || // hlo->opcode() == HloOpcode::kParameter || // hlo->opcode() == HloOpcode::kFusion || // @@ -354,7 +356,8 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kCall || // hlo->opcode() == HloOpcode::kCustomCall || // hlo->opcode() == HloOpcode::kWhile || // - hlo->opcode() == HloOpcode::kConditional) { + hlo->opcode() == HloOpcode::kConditional || // + hlo->HasSideEffectNoRecurse()) { return Status::OK(); } // TODO(b/112040122): Correctly normalize variadic reduce. diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index cef0eba14e9dd463d6c32b047211bf25a84478f6..9f97d18c565c7915b9f9346f0c6330cdc3c707e9 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -23,7 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -68,11 +68,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16NormalizationTest : public HloVerifiedTestBase { +class BFloat16NormalizationTest : public HloTestBase { protected: BFloat16NormalizationTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} bool Normalize(HloModule* module) { TestBFloat16Support bfloat16_support_; @@ -106,10 +106,10 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) { HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(Normalize(module)); + EXPECT_FALSE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -134,10 +134,10 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) { HloInstruction* mul1 = builder.AddInstruction( HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), mul1); @@ -164,10 +164,10 @@ TEST_F(BFloat16NormalizationTest, 
ResolveUnsupportedMixedPrecisionSubtraction) { HloInstruction* sub1 = builder.AddInstruction( HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), sub1); @@ -191,7 +191,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd, reduce_comp_param0, reduce_comp_param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto reduce_computation = module->AddEmbeddedComputation(reduce_comp_builder.Build()); @@ -205,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), reduce); EXPECT_EQ(reduce->called_computations().size(), 1); @@ -233,7 +233,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder sum_builder("sum"); auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -263,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ 
-272,7 +272,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {1024}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024}); @@ -284,13 +284,13 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { HloInstruction::CreateParameter(1, s32_shape, "value")); HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, value)); + ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, {value})); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -298,6 +298,30 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32); } +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) { + auto module = CreateNewVerifiedModule(); + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {1024}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024}); + + HloInstruction* key = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "key")); + HloInstruction* value = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "value")); + + HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape({bf16_shape, bf16_shape}), 0, key, {value})); + + auto computation = module->AddEntryComputation(builder.Build()); + + 
EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(sort->operand(0)->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32); + EXPECT_NE(computation->root_instruction(), sort); + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kTuple); +} + // Tests that the normalization should not cause unsupported mixed precision due // to resolving unsupported BF16 operand. TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { @@ -318,10 +342,10 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module)); + EXPECT_TRUE(Normalize(module.get())); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(dot->shape().element_type(), F32); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 002be9c97098ef1f73446c458dae24bbc826a626..63d4572f2028c462df1cac9d5e4ee616e407f37b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -236,6 +236,10 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, // the end of the BFloat16Propagation pass. continue; } + if (use.instruction->HasSideEffectNoRecurse()) { + // Keep side-effecting instruction's operands unchanged. 
+ return false; + } // Any visited user that can accept BF16 has already been updated if // necessary, e.g., the output has been changed to BF16 if it propagates // precision, or a called computation's parameters have been changed to @@ -329,22 +333,6 @@ void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, return; } - // Do not change precision for instructions related to entry and exit of a - // computation, and control flow, because this pass might break the interfaces - // or assumptions for them. - if (hlo->opcode() == HloOpcode::kInfeed || // - hlo->opcode() == HloOpcode::kOutfeed || // - hlo->opcode() == HloOpcode::kSend || // - hlo->opcode() == HloOpcode::kSendDone || // - hlo->opcode() == HloOpcode::kRecv || // - hlo->opcode() == HloOpcode::kRecvDone || // - hlo->opcode() == HloOpcode::kCustomCall || // - hlo->opcode() == HloOpcode::kCall || // - hlo->opcode() == HloOpcode::kConditional || // - (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) { - return; - } - // Prevent root instructions from having their output modified by recording // all F32 output values as needing to stay as F32. CHECK(hlo->parent() != nullptr); @@ -366,6 +354,17 @@ void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, return; } + // Do not change precision for instructions related to entry and exit of a + // computation, side-effecting instructions, and control flow, because this + // pass might break the interfaces or assumptions for them. 
+ if (hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kConditional || // + hlo->HasSideEffectNoRecurse() || // + (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) { + return; + } + if (!ContainsKey(consider_using_bfloat16_, hlo)) { return; } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index e032b5c624c0151fd63c870e0f21ec97656d625f..5be7141aae423adb4fe2f39262e463ff25ae8234 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -55,11 +55,11 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16PropagationTest : public HloVerifiedTestBase { +class BFloat16PropagationTest : public HloTestBase { protected: BFloat16PropagationTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} // Runs the propagation pass on the given module, and returns whether the // module is changed after this pass. 
@@ -121,10 +121,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(xpose)); @@ -136,6 +136,96 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { EXPECT_FALSE(OutputsBF16(c)); } +TEST_F(BFloat16PropagationTest, PropagateThroughMaxPoolReduceWindow) { + auto module = CreateNewVerifiedModule(); + + auto sub_builder = HloComputation::Builder("max"); + HloInstruction* p0 = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "a")); + HloInstruction* p1 = sub_builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "b")); + sub_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, p0, p1)); + auto max_computation = module->AddEmbeddedComputation(sub_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* c = + builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_high(1); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + 
*window.add_dimensions() = dim; + HloInstruction* rw = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + shape, add, + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), + window, max_computation)); + HloInstruction* xpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), c, {1, 0})); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, rw)); + HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), root); + EXPECT_TRUE(OutputsBF16(add)); + EXPECT_TRUE(OutputsBF16(xpose)); + EXPECT_TRUE(OutputsBF16(rw)); +} + +// Tests that side-effecting all-reduce should not be changed. +TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) { + auto module = CreateNewVerifiedModule(); + + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + auto rb = HloComputation::Builder(TestName()); + rb.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, + rb.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")), + rb.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")))); + auto reduction = module->AddEmbeddedComputation(rb.Build()); + HloInstruction* all_reduce = + builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( + ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction, + /*replica_groups=*/{}, /*barrier=*/"", /*all_reduce_id=*/1)); + HloInstruction* gte0 = builder.AddInstruction( + 
HloInstruction::CreateGetTupleElement(shape, all_reduce, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, all_reduce, 1)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1)); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_EQ(computation->root_instruction(), root); +} + // Tests that if a constant is converted to BF16 then its literal must also be // converted. TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { @@ -152,10 +242,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b))); HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(dot->operand(0))); @@ -208,10 +298,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { HloInstruction* output_tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot, add2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), output_tuple); EXPECT_TRUE(OutputsBF16(xpose)); @@ -247,10 +337,10 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { HloInstruction* dot = builder.AddInstruction( CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); - auto module = CreateNewModule(); + auto module = 
CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(add1)); @@ -276,10 +366,10 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, dot})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module)); + EXPECT_FALSE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_FALSE(OutputsBF16(add)); @@ -287,7 +377,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { // Tests that BF16 is propagated properly through fused computations. TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -322,7 +412,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), fusion1); EXPECT_TRUE(OutputsBF16(add)); @@ -335,7 +425,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { // Tests that changes to BF16 that cannot be propagated outside a fusion are // discarded. 
TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -359,7 +449,7 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module)); + EXPECT_FALSE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), fusion); } @@ -374,7 +464,7 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { // (BF16, BF16) fusion_computation(F32 a, F32 b) // = tuple(BF16 convert(a), BF16 add(F32 a, F32 b)) TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -405,7 +495,7 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(gte0)); @@ -424,7 +514,7 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { // on_true and on_false must match, so that as long as one of them is F32, the // other must be F32 as well. 
TEST_F(BFloat16PropagationTest, SelectOverTuples) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -455,7 +545,7 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add0)); @@ -468,7 +558,7 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { // Tests that BF16 is propagated properly through a while computation with // non-tuple input/output. TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -511,7 +601,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE( @@ -527,7 +617,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { // made to the while body and thus the fusion node inside it. 
TEST_F(BFloat16PropagationTest, ConditionPreventsPropagationForFusionInsideWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -576,7 +666,7 @@ TEST_F(BFloat16PropagationTest, auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module)); + EXPECT_FALSE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add)); EXPECT_FALSE(OutputsBF16(body_fusion)); @@ -588,7 +678,7 @@ TEST_F(BFloat16PropagationTest, // Tests that BF16 is propagated properly through while computations with // tuple-shaped input/output. TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -656,7 +746,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(lhs)); @@ -675,7 +765,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { // Tests that BF16 is not propagated through multiple whiles that invoke the // same computation as long as one while prevents the propagation. 
TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); @@ -786,7 +876,7 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_FALSE(OutputsBF16(body_dot)); EXPECT_FALSE(OutputsBF16(body_rhs)); EXPECT_FALSE(OutputsBF16(body_lhs)); @@ -825,10 +915,10 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary( bf16_shape, HloOpcode::kAdd, convert0, convert1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), add2); EXPECT_EQ(add2->operand(0), add0); @@ -861,10 +951,10 @@ TEST_F(BFloat16PropagationTest, TupleDomain) { HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), root); // test BF16 propagated through domain @@ -907,10 +997,10 @@ TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto 
computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module)); + EXPECT_TRUE(PropagatePrecision(module.get())); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(a_trans)); diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 5b48f10505e78c035608d4c575501e4623218987..2b9502f63a821f3675ddfb506f41bb2390cf4136 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -107,6 +108,21 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kSelect: case HloOpcode::kTupleSelect: return operand_index == 1 || operand_index == 2; + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: { + HloComputation* reduce_comp = hlo.called_computations()[0]; + for (HloInstruction* inst : reduce_comp->instructions()) { + if (inst->opcode() == HloOpcode::kParameter) { + continue; + } + for (int64 i = 0; i < inst->operand_count(); ++i) { + if (!EffectiveOperandPrecisionIsOutputPrecision(*inst, i)) { + return false; + } + } + } + return true; + } default: break; } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 2c2d1626c2c0d5d4b13e401dad9fd6c51514fc13..8d7c62447852fd946440c41389300a92377c471f 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -239,7 +239,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice( void 
BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size) { - VLOG(4) << "Trying to add " << buffer << " to " << this; + VLOG(4) << "Trying to add " << buffer << " to allocation #" << index(); CHECK(assigned_buffers_.count(&buffer) == 0) << "LogicalBuffer " << buffer << " already assigned to allocation " << index_; @@ -378,6 +378,20 @@ const BufferAllocation& BufferAssignment::GetAllocation( return allocations_[index]; } +const BufferAllocation* BufferAssignment::GetInstructionAllocation( + const HloInstruction* hlo, const ShapeIndex& shape_index) const { + const PointsToSet& points_to_set = points_to_analysis().GetPointsToSet(hlo); + const LogicalBuffer* buffer = points_to_set.element(shape_index)[0]; + + if (!HasAllocation(*buffer)) { + return nullptr; + } + + const BufferAllocation& instruction_allocation = + GetAssignedAllocation(*buffer); + return &instruction_allocation; +} + BufferAllocation* BufferAssignment::GetMutableAllocation( BufferAllocation::Index index) { return const_cast(&GetAllocation(index)); @@ -514,6 +528,9 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation, TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer)); allocation->AddAssignment(buffer, offset, size); + if (liveness().MaybeLiveOut(buffer)) { + allocation->set_maybe_live_out(true); + } allocation_index_for_buffer_[&buffer] = allocation->index(); } @@ -624,7 +641,7 @@ Status BufferAssignment::ComputeSummaryStats() { bool schedule_complete = true; for (const auto& computation : module_->computations()) { if (!computation->IsFusionComputation()) { - const std::vector* sequence = + const HloInstructionSequence* sequence = liveness_->hlo_ordering().SequentialOrder(*computation); if (sequence == nullptr) { schedule_complete = false; @@ -728,14 +745,89 @@ StatusOr> BufferAssigner::Run( LogicalBuffer::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, bool allow_input_output_aliasing, bool allocate_buffers_for_constants, 
- BufferLiveness::Colorer colorer) { - BufferAssigner assigner(allow_input_output_aliasing, - allocate_buffers_for_constants, std::move(colorer)); + BufferLiveness::Colorer colorer, ReuseAllocationFunction reuse_checker) { + BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer), + std::move(reuse_checker)); return assigner.CreateAssignment(module, std::move(hlo_ordering), std::move(buffer_size), std::move(color_alignment)); } +namespace { + +// a and b are in different subcomputations. Check for the case +// where a is inside the while body, and b is outside, part of the same while's +// init-operand or while-result. +bool MayInterfereAcrossSubcomputations(BufferAssignment* assignment, + const LogicalBuffer& a_buffer, + const LogicalBuffer& b_buffer) { + auto call_graph = assignment->liveness().hlo_ordering().call_graph(); + const HloInstruction* a_ancestor; + const HloInstruction* b_ancestor; + std::tie(a_ancestor, b_ancestor) = + call_graph.NearestAncestorsInSameComputation(a_buffer.instruction(), + b_buffer.instruction()); + if (a_ancestor == nullptr) { + // No common ancestor. + return true; + } + if (a_ancestor->opcode() == HloOpcode::kWhile && + call_graph.InstructionIsNestedIn(a_buffer.instruction(), + a_ancestor->while_body())) { + const PointsToSet& init_set = + assignment->liveness().points_to_analysis().GetPointsToSet( + a_ancestor->operand(0)); + if (init_set.ContainsBuffer(b_buffer)) { + VLOG(4) << "Can't interfere: " << a_buffer << " and " << b_buffer + << " (part of while-operand)"; + return false; + } + const PointsToSet& while_set = + assignment->liveness().points_to_analysis().GetPointsToSet(a_ancestor); + if (while_set.ContainsBuffer(b_buffer)) { + VLOG(4) << "Can't interfere: " << a_buffer << " and " << b_buffer + << " (part of while)"; + return false; + } + } + return true; +} + +// Return true, if a and b can't possibly interfere (and therefore further +// checking for interference can be skipped). 
This function checks for special +// cases where copy insertion guarantees no interference, but the regular buffer +// liveness is too conservative: +// +// Operations inside a while-body can't interfere with operations outside the +// while op if their last use is at the while-loop itself as part of the +// while-init op, or the while-result. For ops that are live across a +// while-loop, copy insertion will already insert the necessary copies to avoid +// such interference. +// +// This allows sharing buffers in cases like this: +// init = {...} +// while (init): +// p = param(0) +// gte = get-tuple-element(p), index=i +// t1 = op1 (gte) +// t2 = op2 (t1) +// ROOT tuple = {..., t2, ...} +// +// where t1 and t2 can share the same buffer. +bool MaySkipInterferenceCheck(BufferAssignment* assignment, + const LogicalBuffer& a_buffer, + const LogicalBuffer& b_buffer) { + if (a_buffer.instruction()->parent() == b_buffer.instruction()->parent()) { + // Ops within the same computation are not handled here. Assume that they + // may interfere. 
+ return false; + } + return !MayInterfereAcrossSubcomputations(assignment, a_buffer, b_buffer) || + !MayInterfereAcrossSubcomputations(assignment, b_buffer, a_buffer); +} + +} // namespace + bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, const LogicalBuffer& buffer, BufferAssignment* assignment) { @@ -763,6 +855,12 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, return false; } + if (reuse_checker_ != nullptr && + !reuse_checker_(*assignment, *allocation, buffer)) { + VLOG(4) << "Can't assign: reuse_checker_(allocation, buffer) == false"; + return false; + } + if (!allocation->is_reusable()) { VLOG(4) << "Can't assign: allocation is not reusable"; return false; @@ -770,6 +868,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, for (const auto& buffer_offset_size : allocation->assigned_buffers()) { const LogicalBuffer& assigned_buffer = *buffer_offset_size.first; + if (MaySkipInterferenceCheck(assignment, buffer, assigned_buffer)) { + continue; + } if (assignment->liveness().MayInterfere(assigned_buffer, buffer)) { VLOG(4) << "Can't assign: assignee " << assigned_buffer << " may interfere with " << buffer; @@ -784,21 +885,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } } - if (allow_input_output_aliasing_ && allocation->maybe_live_out()) { - const HloComputation* entry_computation = - assignment->module_->entry_computation(); - for (auto param : entry_computation->parameter_instructions()) { - for (auto& param_buffer : - assignment->points_to_analysis().GetBuffersDefinedByInstruction( - param)) { - if (assignment->liveness().MayInterfere(*param_buffer, buffer)) { - VLOG(4) << "Can't assign: Parameter interference with result"; - return false; - } - } - } - } - // If the buffer is live out of the computation then it should only be // assigned a buffer which exactly fits the result to avoid wasting memory // (result buffers can have arbitrary lifetimes). 
@@ -1093,7 +1179,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HloComputation* computation = pair.first; const flat_hash_set& buffers_to_assign = pair.second; - const std::vector* instruction_sequence = + const HloInstructionSequence* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); schedule.set_sequence(computation, *instruction_sequence); @@ -1128,7 +1214,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HloComputation* computation = pair.first; const flat_hash_set& buffers_to_assign = pair.second; - const std::vector* instruction_sequence = + const HloInstructionSequence* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); auto color_map = SplitBuffersByColor(buffers_to_assign); @@ -1143,7 +1229,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(get_heap_algorithm(alignment), *computation, - HloInstructionSequence(*instruction_sequence), + *instruction_sequence, assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, @@ -1347,33 +1433,40 @@ BufferAssigner::MergeColocatedBufferSets( computation == module->entry_computation(); }; + std::vector set_can_be_merged(colocated_buffer_sets.size(), true); + + // Do not merge if one of the sets includes live outs, entry parameters or + // constants. + // + // Buffer liveness does not report the correct live range for entry + // parameter and live out buffers so we have to special case them here. On + // backends that support constant buffer allocations, constant buffers are + // assigned globals in readonly storage so we can't merge colocated buffer + // sets containing constants with colocated buffer sets containing writing + // instructions or other constants. 
+ // + // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to + // the caller of the executable so we can't write to entry parameters + // either, and the argument for not merging constants also applies to entry + // parameters. + for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) { + for (auto& buffer : colocated_buffer_sets[i]) { + if (buffer_liveness.MaybeLiveOut(*buffer) || + is_entry_parameter(*buffer) || + buffer->instruction()->opcode() == HloOpcode::kConstant) { + set_can_be_merged[i] = false; + break; + } + } + } + // Returns true if the two colocated buffer sets (specified by their indices // into the colocated_buffer_sets) can be merged into a single set. auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness, &buffer_size, - &is_entry_parameter](int64 i, int64 j) { - // Do not merge if one of the sets includes live outs, entry parameters or - // constants. - // - // Buffer liveness does not report the correct live range for entry - // parameter and live out buffers so we have to special case them here. On - // backends that support constant buffer allocations, constant buffers are - // assigned globals in readonly storage so we can't merge colocated buffer - // sets containing constants with colocated buffer sets containing writing - // instructions or other constants. - // - // Moreover (on the CPU/GPU backends) the entry parameter buffers belong to - // the caller of the executable so we can't write to entry parameters - // either, and the argument for not merging constants also applies to entry - // parameters. 
- for (int64 key : {i, j}) { - for (auto& buffer : colocated_buffer_sets[key]) { - if (buffer_liveness.MaybeLiveOut(*buffer) || - is_entry_parameter(*buffer) || - buffer->instruction()->opcode() == HloOpcode::kConstant) { - return true; - } - } + &set_can_be_merged](int64 i, int64 j) { + if (!set_can_be_merged[i] || !set_can_be_merged[j]) { + return true; } // Colocated sets satisfy the invariant that all buffers within a set have @@ -1434,13 +1527,30 @@ BufferAssigner::MergeColocatedBufferSets( // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated // in the same allocation (currently just supports kWhile, kCall, and -// kConditional). +// kConditional and input output aliasing). void BufferAssigner::BuildColocatedBufferSets( const HloModule* module, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, std::vector* colocated_buffer_sets) { const TuplePointsToAnalysis& points_to_analysis = buffer_liveness.points_to_analysis(); + + // Set up colocated buffer set for input and output. 
+ VLOG(4) << "Input/Output Alias Config: "; + VLOG(4) << module->input_output_alias_config(); + module->input_output_alias_config().ForEachAlias( + [&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + std::vector colocated_set; + AddBufferToColocatedSet(module->entry_computation()->root_instruction(), + output_index, points_to_analysis, + &colocated_set); + AddBufferToColocatedSet( + module->entry_computation()->parameter_instruction(param_number), + param_index, points_to_analysis, &colocated_set); + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + }); + for (const HloComputation* computation : module->MakeComputationPostOrder()) { if (computation->IsFusionComputation()) { continue; @@ -1574,6 +1684,13 @@ void BufferAssigner::BuildColocatedBufferSets( return; } + int64 i = 0; + for (const auto& colocated_set : *colocated_buffer_sets) { + VLOG(4) << "Colocated set " << i++ << ":"; + for (const auto& buffer : colocated_set) { + VLOG(4) << " " << buffer->ToString(); + } + } // Try to find more coalescing opportunities among the colocated buffer sets. // // TODO(b/32491382): We should be able to remove this by using the diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 899cd36e1f98c9e7b8ba7e42c06ced5c3e8afcc8..0a9fdede803e84ca42472259084615c031b206eb 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -362,6 +362,11 @@ class BufferAssignment { // with the given index. const BufferAllocation& GetAllocation(BufferAllocation::Index index) const; + // Returns the allocation with the given instruction and shape index. nullptr + // if no allocation exists. 
+ const BufferAllocation* GetInstructionAllocation( + const HloInstruction* hlo, const ShapeIndex& shape_index) const; + // Builds and returns a vector containing the slices which might contain the // subvalue at the given index of given instruction. std::set GetAllSlices( @@ -520,6 +525,11 @@ class BufferAssignment { // A class which constructs a buffer assignment. class BufferAssigner { public: + // Returns false if a buffer cannot be assigned to given allocation. + using ReuseAllocationFunction = std::function; + // Build and return a BufferAssignment for the given module. The given // HloOrdering is used to determine buffer liveness. buffer_size and // color_alignment are functions which returns the size and alignment of a @@ -531,15 +541,16 @@ class BufferAssigner { LogicalBuffer::AlignmentFunction color_alignment, bool allow_input_output_aliasing = false, bool allocate_buffers_for_constants = false, - BufferLiveness::Colorer colorer = BufferLiveness::DefaultColorer()); + BufferLiveness::Colorer colorer = BufferLiveness::DefaultColorer(), + ReuseAllocationFunction reuse_checker = nullptr); private: - BufferAssigner(bool allow_input_output_aliasing, - bool allocate_buffers_for_constants, - BufferLiveness::Colorer colorer) - : allow_input_output_aliasing_(allow_input_output_aliasing), - allocate_buffers_for_constants_(allocate_buffers_for_constants), - colorer_(colorer) {} + BufferAssigner(bool allocate_buffers_for_constants, + BufferLiveness::Colorer colorer, + ReuseAllocationFunction reuse_checker) + : allocate_buffers_for_constants_(allocate_buffers_for_constants), + colorer_(colorer), + reuse_checker_(reuse_checker) {} virtual ~BufferAssigner() = default; // Create a buffer assignment. @@ -627,16 +638,15 @@ class BufferAssigner { LogicalBuffer::Color::Hasher> SplitBuffersByColor(const absl::flat_hash_set& buffers); - // If true, buffer assignments assumes that input parameter buffers and output - // buffers can be shared if their sizes match. 
- bool allow_input_output_aliasing_; - // If true, allocate buffers for constant instructions. bool allocate_buffers_for_constants_; // Functor used to assign colors to newly allocated logical buffers. BufferLiveness::Colorer colorer_; + // Functor to check if a buffer can reuse an allocation. + ReuseAllocationFunction reuse_checker_; + TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner); }; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 795beb9ff5ceb2998a85fbd03d8bb1d3b2febc12..8f482e6ba8c3e71c9980be5e6947ea61f3b4ef29 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -81,7 +81,7 @@ const std::vector GetInstructions(HloInstruction* root) { return main_list.GetInstructions(); } -class BufferAssignmentTest : public HloVerifiedTestBase { +class BufferAssignmentTest : public HloTestBase { protected: ~BufferAssignmentTest() override {} @@ -107,6 +107,24 @@ class BufferAssignmentTest : public HloVerifiedTestBase { .ConsumeValueOrDie(); } + std::unique_ptr RunBufferAssignmentNoBuffersReuseForAdd( + HloModule* module, int64 alignment = 1) { + auto reuse_checker = [](const BufferAssignment& assignment, + const BufferAllocation& alloc, + const LogicalBuffer& buffer) { + return (buffer.instruction()->opcode() != HloOpcode::kAdd); + }; + return BufferAssigner::Run( + module, absl::make_unique(module), + backend().compiler()->BufferSizeBytesFunction(), + 
[alignment](LogicalBuffer::Color) { return alignment; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/false, + /*colorer=*/BufferLiveness::DefaultColorer(), + /*reuse_checker=*/reuse_checker) + .ConsumeValueOrDie(); + } + std::unique_ptr RunColoredBufferAssignment( HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) { return BufferAssigner::Run( @@ -119,8 +137,7 @@ class BufferAssignmentTest : public HloVerifiedTestBase { } std::unique_ptr RunBufferAssignmentWithInstructionSequence( - HloModule* module, - absl::Span instruction_sequence, + HloModule* module, absl::Span instruction_sequence, int64 alignment = 1) { HloSchedule schedule(module); schedule.set_sequence(module->entry_computation(), instruction_sequence); @@ -316,16 +333,16 @@ TEST_F(BufferAssignmentTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); { - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); } { - auto buffers = RunBufferAssignmentNoBuffersForConstants(module); + auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); } } @@ -340,17 +357,17 @@ TEST_F(BufferAssignmentTest, BufferForConst) { LiteralUtil::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); { - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); 
EXPECT_TRUE(buffers->HasTopLevelAllocation(const1)); GetAssignedOutputAllocation(*buffers, add); } { - auto buffers = RunBufferAssignmentNoBuffersForConstants(module); + auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); EXPECT_FALSE(buffers->HasTopLevelAllocation(const1)); GetAssignedOutputAllocation(*buffers, add); @@ -369,10 +386,10 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) { HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({negate, param0, constant})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation() // reports for the instruction directly. EXPECT_EQ(buffers->HasTopLevelAllocation(tuple), @@ -392,10 +409,10 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) { LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // The copy node now has an output buffer. 
GetAssignedOutputAllocation(*buffers, copy); } @@ -421,10 +438,10 @@ TEST_F(BufferAssignmentTest, Basic) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // Distinct input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -447,6 +464,56 @@ TEST_F(BufferAssignmentTest, Basic) { GetAssignedOutputAllocation(*buffers, sub); } +TEST_F(BufferAssignmentTest, AddCannotReuse) { + // Pass in a special rule to indicate that "add" cannot reuse any buffer. + // + // paramscalar ------- (mul) -- (add) -- (sub) + // / / / + // param0[100] -------/ / / + // / / + // param1[100] --------------/--------/ + auto builder = HloComputation::Builder(TestName()); + auto paramscalar = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec100_, "p1")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32vec100_, "p2")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); + auto sub = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kSubtract, add, param1)); + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignmentNoBuffersReuseForAdd(module.get()); + + // Distinct input buffers were assigned 
for parameters. + BufferAllocation paramscalar_buffer = + GetAssignedInputAllocation(*buffers, paramscalar); + BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0); + BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1); + EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index()); + EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index()); + EXPECT_NE(param0_buffer.index(), param1_buffer.index()); + + // The mul node has a valid buffer assigned, doesn't share with input. + const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); + EXPECT_NE(mul_buffer.index(), param0_buffer.index()); + + // The add node cannot reuse the mul node's buffer since we told buffer + // assignment so. + const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add); + EXPECT_NE(add_buffer.index(), mul_buffer.index()); + + // The sub node has a valid output buffer assigned. + GetAssignedOutputAllocation(*buffers, sub); +} + TEST_F(BufferAssignmentTest, BasicUniquelyColored) { // paramscalar ------- (mul) -- (add) -- (sub) // / / / @@ -470,7 +537,7 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto colorer = [](const BufferLiveness& buffer_liveness) { @@ -485,7 +552,7 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { return Status::OK(); }; - auto buffers = RunColoredBufferAssignment(module, colorer); + auto buffers = RunColoredBufferAssignment(module.get(), colorer); // Distinct input buffers were assigned for parameters. 
BufferAllocation paramscalar_buffer = @@ -531,7 +598,7 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto colorer = [](const BufferLiveness& buffer_liveness) { @@ -554,7 +621,7 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { return Status::OK(); }; - auto buffers = RunColoredBufferAssignment(module, colorer); + auto buffers = RunColoredBufferAssignment(module.get(), colorer); // Distinct input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -603,10 +670,10 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); // Input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -638,7 +705,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { // param0[100x10] ---> (map x+1) // // Builds the map function. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto map_computation = module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1")); auto inner_last = map_computation->root_instruction(); @@ -657,7 +724,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size"; // Assigns buffers and fetches sizes. 
- auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); int64 size0 = ValidateBuffers(level0, *buffers); int64 size1 = ValidateBuffers(level1, *buffers); @@ -693,7 +760,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { // out-of-order reductions could overwrite an element before a use.) // // param0[100] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3) - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto reduce_computation = module->AddEmbeddedComputation(BuildReduceComputation("f32+f32")); @@ -716,7 +783,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); const std::vector instrs = GetInstructions(exp3); ValidateBuffers(instrs, *buffers); @@ -744,7 +811,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { // const4[f32[4]] --- tuple --- while[condition, body] // // Builds the nested condition and body. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto condition_computation = module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4")); auto body_computation = @@ -772,7 +839,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { EXPECT_EQ(8, levelb.size()) << "Invalid nested body size"; // Assigns buffers and fetches sizes. 
- auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); int64 size0 = ValidateBuffers(level0, *buffers); int64 sizec = ValidateBuffers(levelc, *buffers); int64 sizeb = ValidateBuffers(levelb, *buffers); @@ -810,7 +877,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { } TEST_F(BufferAssignmentTest, ExampleConditional) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto true_computation = module->AddEmbeddedComputation( BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil")); auto false_computation = module->AddEmbeddedComputation( @@ -837,7 +904,7 @@ TEST_F(BufferAssignmentTest, ExampleConditional) { EXPECT_EQ(2, true_instrs.size()); EXPECT_EQ(2, false_instrs.size()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); ValidateBuffers(conditional_instrs, *buffers); ValidateBuffers(true_instrs, *buffers); ValidateBuffers(false_instrs, *buffers); @@ -873,9 +940,9 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { auto neg = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // tanh and exp2 can reuse exp1's buffer EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1)); @@ -902,9 +969,9 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) { auto broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // negate and broadcast should share a buffer. 
EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); @@ -935,9 +1002,9 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) { HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The instructions should not share buffers. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -972,9 +1039,9 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) { HloInstruction::CreateBroadcast(f32a100x10_, slice, {1})); builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The instructions should not share buffers. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1007,9 +1074,9 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) { auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The broadcast output buffer cannot be shared. 
EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1039,9 +1106,9 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) { auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {10, 10}), slice, {0})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // negate and broadcast should share a buffer. EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); @@ -1077,9 +1144,9 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) { ShapeUtil::MakeShape(F32, {10, 4}), slice, {0})); builder.AddInstruction(HloInstruction::CreateTuple({broadcast})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // The broadcast output buffer cannot be shared. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1092,7 +1159,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { // Verify that buffers for embedded computations are properly marked as // thread-local and that embedded parameters are not marked as // is_entry_computation_parameter. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto vec_shape = ShapeUtil::MakeShape(F32, {42}); auto scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -1123,7 +1190,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { HloInstruction::CreateMap(vec_shape, {call}, map_computation)); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Allocations for the map computation should be thread-local and not // live-out. 
@@ -1170,9 +1237,9 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) { ShapeUtil::MakeShape(S32, {42})}), "param0")); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // There should be four allocations: one for vector of pointers, and one for // each tuple element. @@ -1206,9 +1273,9 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Only some of the elements of the input param are liveout. EXPECT_FALSE( @@ -1250,9 +1317,9 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) { builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); EXPECT_EQ(3, assignment->Allocations().size()); } @@ -1264,9 +1331,9 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) { ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}), ShapeUtil::MakeShape(S32, {101})}), /*operands=*/{}, /*custom_call_target=*/"foo_function")); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); EXPECT_EQ(3, assignment->Allocations().size()); EXPECT_TRUE( @@ -1279,7 +1346,7 @@ 
TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) { TEST_F(BufferAssignmentTest, TupleCallAsOutput) { // Test a computation which returns a tuple call value. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto elem_shape = f32vec4_; auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape}); @@ -1297,7 +1364,7 @@ TEST_F(BufferAssignmentTest, TupleCallAsOutput) { HloInstruction::CreateCall(tuple_shape, {param}, sub_computation)); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); EXPECT_EQ(2, assignment->Allocations().size()); // Buffers for call are colocated with the sub-computation. @@ -1320,7 +1387,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { // B: call(C, param) // C: call(D, param) // D: param - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto elem_shape = f32vec4_; auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape}); @@ -1359,7 +1426,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { module->AddEntryComputation(std::move(a_computation)); module->AddEmbeddedComputation(std::move(b_computation)); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Buffers for call are colocated with the sub-computations. EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}), @@ -1393,9 +1460,9 @@ TEST_F(BufferAssignmentTest, BitcastAsOutput) { auto bitcast = builder.AddInstruction( HloInstruction::CreateUnary(param->shape(), HloOpcode::kBitcast, param)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Bitcast should get the same allocation as the param. 
EXPECT_EQ(1, assignment->Allocations().size()); @@ -1420,9 +1487,9 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { HloInstruction::CreateTernary(tuple_shape, HloOpcode::kTupleSelect, pred_param, tuple_param0, tuple_param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // Select shallow copies one of its operands so it defines its own top-level // buffer and receives its own allocation. @@ -1458,9 +1525,9 @@ TEST_F(BufferAssignmentTest, TupleBufferNotReused) { auto copy = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape, HloOpcode::kCopy, tuple_element)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module); + auto assignment = RunBufferAssignment(module.get()); // There should be no buffer reuse. The copy should not reuse the tuple // buffer. @@ -1500,9 +1567,9 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0)); // Run buffer assignment with alignment=1. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module, /*alignment=*/1); + auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1); // There are 5 allocations: 3 parameters, 1 output, and 1 temp. EXPECT_EQ(5, assignment->Allocations().size()); @@ -1521,7 +1588,7 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { EXPECT_EQ(80, slice_bc.allocation()->size()); // Re-run buffer assignment with alignment=64. 
- assignment = RunBufferAssignment(module, /*alignment=*/64); + assignment = RunBufferAssignment(module.get(), /*alignment=*/64); EXPECT_EQ(5, assignment->Allocations().size()); slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie(); slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie(); @@ -1564,10 +1631,10 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kSubtract, add, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); const std::vector& peak_buffers = @@ -1605,11 +1672,11 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto buffers = RunBufferAssignmentWithInstructionSequence( - module, {param, log, rev, neg, concat, root}); + module.get(), {param, log, rev, neg, concat, root}); // The temporary buffer should hold the 4 interior instructions. 
const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat); @@ -1630,7 +1697,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { } TEST_F(BufferAssignmentTest, PeakBuffersWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape shape = ShapeUtil::MakeShape(F32, {123, 123}); HloComputation* condition; { @@ -1665,7 +1732,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) { ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0})); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module); + auto buffers = RunBufferAssignment(module.get()); const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast); const std::vector& peak_buffers = buffer.PeakMemoryLogicalBuffers(); @@ -1715,13 +1782,13 @@ ENTRY main { } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text)); HloInstruction* constant_1 = - module().entry_computation()->GetInstructionWithName("constant.1.1"); + m->entry_computation()->GetInstructionWithName("constant.1.1"); HloInstruction* constant_2 = - module().entry_computation()->GetInstructionWithName("constant.1.2"); + m->entry_computation()->GetInstructionWithName("constant.1.2"); - auto buffers = RunBufferAssignment(&module()); + auto buffers = RunBufferAssignment(m.get()); { const BufferAllocation& allocation_for_const_1 = @@ -1750,7 +1817,7 @@ ENTRY main { } } -class WhileBufferAssignmentTest : public HloVerifiedTestBase { +class WhileBufferAssignmentTest : public HloTestBase { protected: std::unique_ptr BuildWhileConditionComputation( const string& name) { @@ -1785,7 +1852,7 @@ class WhileBufferAssignmentTest : public HloVerifiedTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { HloSchedule schedule = - ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleModule(module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( module, 
absl::make_unique(schedule), ByteSizeOf, @@ -1810,7 +1877,7 @@ static void RunCopyInsertion(HloModule* module) { } TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1849,8 +1916,8 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); // Verify 'input0' and read-only use while0{0} alias. EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(), @@ -1906,20 +1973,19 @@ ENTRY %test_module { ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} })"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. - int64 instruction_count = module().instruction_count(); + int64 instruction_count = m->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(&module()).status()); - ASSERT_EQ(instruction_count, module().instruction_count()); + ASSERT_IS_OK(copy_insertion.Run(m.get()).status()); + ASSERT_EQ(instruction_count, m->instruction_count()); // Get the instructions in the module. 
- const HloInstruction* bcast = - module().entry_computation()->root_instruction(); + const HloInstruction* bcast = m->entry_computation()->root_instruction(); const HloInstruction* param = - module().entry_computation()->parameter_instruction(0); + m->entry_computation()->parameter_instruction(0); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); const HloInstruction* while1 = bcast->operand(0); ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); @@ -1927,7 +1993,7 @@ ENTRY %test_module { ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); // Run buffer assignment. - auto assignment = RunBufferAssignment(&module()); + auto assignment = RunBufferAssignment(m.get()); TF_ASSERT_OK_AND_ASSIGN(auto slice_param, assignment->GetUniqueSlice(param, {})); TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, @@ -1974,20 +2040,19 @@ ENTRY %test_module { ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} })"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. - int64 instruction_count = module().instruction_count(); + int64 instruction_count = m->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(&module()).status()); - ASSERT_EQ(instruction_count, module().instruction_count()); + ASSERT_IS_OK(copy_insertion.Run(m.get()).status()); + ASSERT_EQ(instruction_count, m->instruction_count()); // Get the instructions in the module. 
- const HloInstruction* bcast = - module().entry_computation()->root_instruction(); + const HloInstruction* bcast = m->entry_computation()->root_instruction(); const HloInstruction* constant = - module().entry_computation()->GetInstructionWithName("constant.42"); + m->entry_computation()->GetInstructionWithName("constant.42"); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); const HloInstruction* while1 = bcast->operand(0); ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); @@ -1995,7 +2060,7 @@ ENTRY %test_module { ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); // Run buffer assignment. - auto assignment = RunBufferAssignment(&module()); + auto assignment = RunBufferAssignment(m.get()); TF_ASSERT_OK_AND_ASSIGN(auto slice_constant, assignment->GetUniqueSlice(constant, {})); TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, @@ -2053,7 +2118,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { }; // Build the entry computation as described in the comment above. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto token = builder.AddInstruction(HloInstruction::CreateToken()); @@ -2088,7 +2153,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // any copies inserted for BufferAssignment to run. int64 instruction_count = module->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(module).status()); + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); ASSERT_EQ(instruction_count, module->instruction_count()); // Create a sequential order among all the instructions in the entry @@ -2096,7 +2161,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // nodes are traversed during BufferAssignment. 
TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/sizeof(void*)); })); @@ -2107,12 +2172,12 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { TF_ASSERT_OK_AND_ASSIGN( auto assignment, - BufferAssigner::Run(module, - absl::make_unique(schedule), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run( + module.get(), absl::make_unique(schedule), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // The result tuple elements must be assigned with different buffers. TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); @@ -2134,7 +2199,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { } TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -2166,8 +2231,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); // while0 and while1 buffers should be completely aligned. 
EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(), @@ -2179,7 +2244,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { } TEST_F(BufferAssignmentTest, TwoCalls) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {}); HloComputation* sub_computation; { @@ -2209,13 +2274,13 @@ TEST_F(BufferAssignmentTest, TwoCalls) { { FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); } - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment)); } @@ -2240,13 +2305,14 @@ ENTRY Main { )"; HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - ParseAndVerifyModule(hlo_text, config); + config.set_debug_options(GetDebugOptionsFromFlags()); + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(hlo_text, config)); - auto buffers = RunBufferAssignment(&module()); + auto buffers = RunBufferAssignment(m.get()); - HloComputation* main = module().entry_computation(); - HloComputation* callee = module().GetComputationWithName("Callee"); + HloComputation* main = m->entry_computation(); + HloComputation* callee = m->GetComputationWithName("Callee"); EXPECT_NE(callee, nullptr); HloInstruction* param0 = callee->parameter_instruction(0); @@ -2270,7 +2336,7 @@ ENTRY Main { } TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( @@ -2317,40 +2383,41 @@ 
TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { { FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); } - RunCopyInsertion(module); + RunCopyInsertion(module.get()); HloSchedule schedule = - ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleModule(module.get(), ByteSizeOf).ConsumeValueOrDie(); // To trigger b/38494731, we want a specific Hlo schedule for the // root computation, so we overwrite that entry with a manually // crafted sequence. - schedule.set_sequence(module->entry_computation(), - {input1, weights1, one, output1, while1->operand(0), - while1, input0, weights0, zero, output0, - while0->operand(0), while0, gte0, gte1, root_add}); + schedule.set_sequence( + module->entry_computation(), + {input1, weights1, one, output1, while1->mutable_operand(0), while1, + input0, weights0, zero, output0, while0->mutable_operand(0), while0, + gte0, gte1, root_add}); // If this ASSERT fails, we constructed a bogus sequence above and this test // itself is buggy. 
TF_ASSERT_OK(schedule.Verify()); auto assignment = - BufferAssigner::Run(module, - absl::make_unique(schedule), - ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true) + BufferAssigner::Run( + module.get(), absl::make_unique(schedule), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); } TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -2394,8 +2461,8 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module); - auto assignment = RunBufferAssignment(module); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); // Get BufferAllocation for root instruction. 
auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out) .ConsumeValueOrDie() @@ -2406,5 +2473,58 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { EXPECT_FALSE(root_alloc->is_entry_computation_parameter()); } +TEST_F(WhileBufferAssignmentTest, WhileWithDynamicUpdateSliceShare) { + const char* const hlo_string = R"( +HloModule test + +while_body { + state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={} + get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1 + get-tuple-element.3 = s32[] get-tuple-element(state), index=0 + constant.2 = s32[] constant(128) + add.5 = s32[] add(get-tuple-element.3, constant.2) + constant.3 = s32[3]{0} constant({0, 0, 0}) + dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3) + dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3) + ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9) +} + +while_condition { + state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0) + get-tuple-element = s32[] get-tuple-element(state), index=0 + get-tuple-element.1 = s32[] constant(3) + ROOT less-than.339.338 = pred[] less-than(get-tuple-element, get-tuple-element.1) +} + +ENTRY entry_computation { + constant.7 = s32[] constant(0) + copy.1 = s32[] copy(constant.7) + constant.6 = f32[] constant(0) + broadcast.6 = f32[1280,1,128]{2,1,0} broadcast(constant.6), dimensions={} + tuple.1 = (s32[], f32[1280,1,128]{2,1,0}) tuple(copy.1, broadcast.6) + while.0 = (s32[], f32[1280,1,128]{2,1,0}) while(tuple.1), condition=while_condition, body=while_body + ROOT get-tuple-element.2 = s32[] get-tuple-element(while.0), index=0 +} + +)"; + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + 
auto module = module_or_status.ConsumeValueOrDie(); + + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); + // Get BufferAllocation for root instruction. + auto dus9 = FindInstruction(module.get(), "dynamic-update-slice.9"); + auto dus9_alloc_slice = + assignment->GetUniqueTopLevelSlice(dus9).ConsumeValueOrDie(); + auto dus5 = FindInstruction(module.get(), "dynamic-update-slice.5"); + auto dus5_alloc_slice = + assignment->GetUniqueTopLevelSlice(dus5).ConsumeValueOrDie(); + // Test that the two dynamic-update-slice ops share the same allocation slice. + EXPECT_EQ(dus9_alloc_slice.allocation(), dus5_alloc_slice.allocation()); + EXPECT_EQ(dus9_alloc_slice, dus5_alloc_slice); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 17e50905059ad2c92784d14132c1cb1f46f35ade..40825a78716b1c0b9fb0121787977d275891c0f8 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -117,7 +117,7 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { auto log = builder.AddInstruction( HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -164,7 +164,7 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry = module->AddEntryComputation(builder.Build()); HloSchedule schedule(module.get()); @@ -213,7 +213,7 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { auto reverse = builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0})); - auto module = CreateNewModule(); + auto module = 
CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -247,7 +247,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -289,7 +289,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloSchedule schedule(module.get()); @@ -336,7 +336,7 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1)); auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build(add)); HloSchedule schedule(module.get()); @@ -373,7 +373,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { auto outer_tuple = builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -393,7 +393,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { TEST_F(BufferLivenessTest, EmbeddedComputation) { // Test MaybeLiveOut and MayInterfere for embedded computation. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto embedded_builder = HloComputation::Builder(TestName() + "_embedded"); auto embedded_param = embedded_builder.AddInstruction( @@ -450,7 +450,7 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( inner_tuple0.shape(), tuple_constant, 0)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto liveness = @@ -514,7 +514,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { auto tuple_root = builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); @@ -576,7 +576,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { auto tuple_root = builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); @@ -611,8 +611,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { protected: // Builds and runs a computation (see test case computation graphs below). - std::unique_ptr BuildModule(const bool update_uses_tuple_element1, - const bool fuse_gte0) { + std::unique_ptr BuildModule( + const bool update_uses_tuple_element1, const bool fuse_gte0) { auto builder = HloComputation::Builder(TestName()); // Create param0 Tuple. Shape data_shape = ShapeUtil::MakeShape(F32, {8}); @@ -646,7 +646,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto* computation = module->entry_computation(); // Create fusion instruction based on number of tuple element 1 users. @@ -802,7 +802,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { auto tuple_root = builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. diff --git a/tensorflow/compiler/xla/service/buffer_value.h b/tensorflow/compiler/xla/service/buffer_value.h index 69b36463560a1fad4f62687e9014fb3fbe5bbd13..11d8abc5badf7b1a05239ed74a05be0c899e37a1 100644 --- a/tensorflow/compiler/xla/service/buffer_value.h +++ b/tensorflow/compiler/xla/service/buffer_value.h @@ -141,6 +141,9 @@ class BufferValue { // operator< is required for std::set. bool operator<(const BufferValue& other) const { return id_ < other.id_; } + bool operator==(const BufferValue& other) const { return id_ == other.id_; } + bool operator!=(const BufferValue& other) const { return id_ != other.id_; } + virtual string ToString() const = 0; // TODO(lauj) rename LogicalBufferProto to BufferValueProto. 
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index bdd5069632e84fe6c67ca129f726432479ac1b35..7987343bfaf1069fd550909d127e4b11f2124701 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -325,6 +325,15 @@ bool CallGraph::IsFlattened() const { return true; } +std::vector CallGraph::GetComputationCallers( + HloComputation* c) { + std::vector callers; + for (auto callsite : GetNode(c).caller_callsites()) { + callers.push_back(callsite.instruction()); + } + return callers; +} + std::pair CallGraph::NearestAncestorsInSameComputation(HloInstruction* a, HloInstruction* b) const { diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index cb56f4789d06ac33acdaadc8b619b9e37f683d58..05c7c998738f861ee804d1ec87bfa5fb17ddfb74 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -236,6 +236,10 @@ class CallGraph { // FlattenCallGraph. bool IsFlattened() const; + // Returns a vector of instructions calling the passed computation. + // (Often a vector of size 1.) + std::vector GetComputationCallers(HloComputation* c); + string ToString() const; private: diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 34f3f914d593bc603c4964663f9cafb70a136fd3..a3ac2568b0f3eec8556a42dbe3c2c64bd8564468 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,7 +31,7 @@ namespace { using ::testing::UnorderedElementsAre; -class CallGraphTest : public HloVerifiedTestBase { +class CallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation( @@ -93,10 +93,10 @@ class CallGraphTest : public HloVerifiedTestBase { TEST_F(CallGraphTest, SingletonComputation) { // Test the call graph of a module with a single computation. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(1, call_graph->nodes().size()); EXPECT_TRUE(call_graph->IsFlattened()); @@ -112,13 +112,13 @@ TEST_F(CallGraphTest, SingletonComputation) { TEST_F(CallGraphTest, UnreachableComputation) { // Test the call graph of a module with an entry computation and an // unreachable computation. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -134,13 +134,13 @@ TEST_F(CallGraphTest, UnreachableComputation) { TEST_F(CallGraphTest, ParallelComputation) { // Test a call graph of a module with an entry computation which calls another // computation in a parallel context via kMap. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* map_computation = module->AddEmbeddedComputation(MakeScalarComputation()); HloComputation* entry_computation = module->AddEntryComputation( MakeMappingComputation(map_computation, /*callsites=*/5)); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -163,13 +163,13 @@ TEST_F(CallGraphTest, ParallelComputation) { TEST_F(CallGraphTest, SequentialComputations) { // Test a call graph of a module with an entry computation which calls another // computation in a sequential context via kCall. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* called_computation = module->AddEmbeddedComputation(MakeScalarComputation()); HloComputation* entry_computation = module->AddEntryComputation( MakeCallingComputation(called_computation, /*callsites=*/3)); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); // The called computation is only called from one other computation, but there @@ -196,7 +196,7 @@ TEST_F(CallGraphTest, SequentialComputations) { TEST_F(CallGraphTest, ContextBothComputations) { // Test a call graph of a module with an entry computation which calls another // computation in both a parallel and sequential context. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* subcomputation = module->AddEmbeddedComputation(MakeScalarComputation()); @@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -239,7 +239,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { TEST_F(CallGraphTest, ComputationWithConditional) { // Test a call graph of a module with a conditional. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* true_computation = module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil)); HloComputation* false_computation = @@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(3, call_graph->nodes().size()); @@ -298,7 +298,7 @@ TEST_F(CallGraphTest, ComplexGraph) { // c // // Calls are made via kCall, kWhile, and kMap instructions. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation = module->AddEmbeddedComputation(MakeConditionComputation()); HloComputation* c_computation = @@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(5, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -418,7 +418,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { // c // // Calls are made via kCall, kWhile, and kMap instructions. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation = module->AddEmbeddedComputation(MakeConditionComputation()); HloComputation* c_computation = @@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(5, call_graph->nodes().size()); // Verify NearestAncestorsInSameComputation for various instructions in the @@ -479,10 +479,10 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { TEST_F(CallGraphTest, VisitSingletonComputation) { // Test the call graph visitor with a call graph with a single node. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); std::vector visited; TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) { @@ -494,12 +494,12 @@ TEST_F(CallGraphTest, VisitSingletonComputation) { TEST_F(CallGraphTest, VisitUnreachableComputation) { // Test the call graph visitor with a call graph with an unreachable node. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); // Test visitation of only reachable nodes. { @@ -531,9 +531,9 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) { TEST_F(CallGraphTest, VisitWithError) { // Test that the call graph visitor properly propagates errors. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); Status status = call_graph->VisitNodes( [](const CallGraphNode&) { return InternalError("Visitation failed"); }); diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index e6b566543594a86eb5369ee9b7440f62618f6c5a..0b6e323f75c7a5dae127e20d2a4b92a83a72df3b 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -40,7 +40,7 @@ namespace { // Tests for call inlining that are most tractable at the HLO level (vs // ComputationBuilder API in call_test.cc). 
-using CallInlinerTest = HloVerifiedTestBase; +using CallInlinerTest = HloTestBase; TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // "inner" computation just has a control dependency from the "zero" value to @@ -51,7 +51,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { HloInstruction* one = inner.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); TF_ASSERT_OK(zero->AddControlDependencyTo(one)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* inner_computation = module->AddEmbeddedComputation(inner.Build()); @@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); ASSERT_TRUE(mutated); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), @@ -79,7 +79,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // returns false should be identical to just returning false). TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { const Shape pred = ShapeUtil::MakeShape(PRED, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); // Create a lambda that calls a function that returns the false predicate. 
// Note we also use this lambda twice by reference, just to make the test a @@ -107,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); ASSERT_TRUE(mutated); EXPECT_THAT( computation->root_instruction()->while_condition()->root_instruction(), @@ -120,7 +120,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { // whole pass. TEST_F(CallInlinerTest, InlineWithoutRunningPass) { const Shape pred = ShapeUtil::MakeShape(PRED, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder just_false(TestName() + ".false"); auto* true_constant = just_false.AddInstruction( @@ -144,7 +144,7 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) { TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { const Shape f32 = ShapeUtil::MakeShape(F32, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder outfeeder(TestName() + ".outfeeder"); auto value = outfeeder.AddInstruction( @@ -163,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); ASSERT_TRUE(mutated); } diff --git a/tensorflow/compiler/xla/service/compilation_cache.cc b/tensorflow/compiler/xla/service/compilation_cache.cc new file mode 100644 index 0000000000000000000000000000000000000000..2662fe46705f4936ce0d654df0943e7d30890ebe --- /dev/null +++ b/tensorflow/compiler/xla/service/compilation_cache.cc @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/compilation_cache.h" + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +int64 GetUniqueId() { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static int64 counter = 0; + tensorflow::mutex_lock loc(mu); + const int64 id = counter++; + return id; +} + +} // namespace + +ExecutionHandle CompilationCache::Insert( + std::unique_ptr executable) { + tensorflow::mutex_lock lock(mutex_); + + CacheKey key = GetUniqueId(); + VLOG(2) << "inserting cache key: " << key; + CHECK_EQ(cache_.count(key), 0); + cache_.emplace(key, std::move(executable)); + + ExecutionHandle handle; + handle.set_handle(key); + return handle; +} + +StatusOr> CompilationCache::LookUp( + const ExecutionHandle& handle) const { + tensorflow::mutex_lock lock(mutex_); + + CacheKey key = handle.handle(); + VLOG(2) << "looking up cache key: " << key; + if (cache_.count(key) == 0) { + VLOG(2) << "cache key not found: " << key; + return InvalidArgumentStrCat("can not find executable with handle ", key); + } else { + auto& result = cache_.at(key); + VLOG(2) << "hit executable: " << result->module().name(); + return 
result; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h new file mode 100644 index 0000000000000000000000000000000000000000..5f94def509d4d4a8950272cb498af5056a698ce0 --- /dev/null +++ b/tensorflow/compiler/xla/service/compilation_cache.h @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace xla { + +// A cache which stores Executables indexed by computation handle and version. +// +// TODO(b/119042872): Provide mechanism for removing computations from the +// compilation cache. +class CompilationCache { + public: + CompilationCache() {} + + ExecutionHandle Insert(std::unique_ptr executable); + + // Lookup the Executable for the specified handle in the cache. 
Return a + // shared_ptr to the Executable if it exists in the cache. + StatusOr> LookUp( + const ExecutionHandle& handle) const; + + protected: + mutable tensorflow::mutex mutex_; + + using CacheKey = int64; + + absl::flat_hash_map> cache_ + GUARDED_BY(mutex_); + + private: + TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 96bd2616f5607de888a096f8392ceb68490276e3..1965925fa7f6d50b1d7af918bc3468d4b4d5d0a2 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -67,7 +67,7 @@ CompileOnlyService::CompileAheadOfTime( std::unique_ptr* metadata) { std::vector> hlo_modules; for (const AotXlaComputationInstance& instance : computations) { - TF_RET_CHECK(instance.computation.has_program_shape()); + TF_RET_CHECK(instance.computation.has_host_program_shape()); const DebugOptions& debug_options = options.debug_options(); @@ -86,13 +86,15 @@ CompileOnlyService::CompileAheadOfTime( Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot)); } - const auto& program_shape = instance.computation.program_shape(); ExecutionOptions execution_options; *execution_options.mutable_debug_options() = debug_options; + *execution_options.mutable_shape_with_output_layout() = + instance.result_layout->ToProto(); TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(program_shape, 
instance.argument_layouts, - &execution_options)); + CreateModuleConfig( + ProgramShape(instance.computation.host_program_shape()), + instance.argument_layouts, &execution_options)); TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_module, @@ -101,8 +103,10 @@ CompileOnlyService::CompileAheadOfTime( hlo_modules.push_back(std::move(hlo_module)); } - return compiler_->CompileAheadOfTime(std::move(hlo_modules), options, - metadata); + return compiler_->CompileAheadOfTime( + absl::make_unique(hlo_modules[0]->name(), + absl::MakeSpan(hlo_modules)), + options, metadata); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 687ecafe0c308ecc22857fae650c6998677f605d..8f08c244908efb823b3870c19bdc3491fa87d44f 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -45,7 +45,7 @@ Compiler::ComputeDefaultBackendConfig(const HloInstruction& hlo, // Define a default version where metadata is not used. 
StatusOr>> Compiler::CompileAheadOfTime( - std::vector> modules, + std::unique_ptr module_group, const AotCompilationOptions& options, std::unique_ptr* metadata) { if (metadata != nullptr) { @@ -53,7 +53,7 @@ Compiler::CompileAheadOfTime( "Populating AotCompilationMetadata is not implemented on this " "compiler."); } - return CompileAheadOfTime(std::move(modules), options); + return CompileAheadOfTime(std::move(module_group), options); } /* static */ std::map* @@ -110,6 +110,6 @@ Compiler::GetPlatformCompilers() { } AotCompilationOptions::AotCompilationOptions() - : debug_options_(legacy_flags::GetDebugOptionsFromFlags()) {} + : debug_options_(GetDebugOptionsFromFlags()) {} } // namespace xla diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 1fdda31c34a17a16f75e1efada542c2c2ea15038..d4db95da8eb901af8a6675f2991def73ccfe8ee6 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -135,22 +136,35 @@ class Compiler { std::unique_ptr module, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) = 0; + // Optimizes a HLO module group, a set of module which runs concurrently on + // multiple devices potentially communicating data between the modules. 
+ virtual Status RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, + absl::Span executors, + DeviceMemoryAllocator* device_allocator) = 0; + // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. No HLO passes are // applied to module. Generally a module should be passed through RunHloPasses // prior to calling this method because some HLO passes are required for - // correctness. Takes ownership of the HLO module and is free to transform it. + // correctness. Takes ownership of the HLO module. // // The compiler may optionally specialize to the individual device // (not just type of device) indicated by the executor. // // device_allocator is optional; see RunHloPasses. - // - // Use the overload below to compile computations that run in parallel. virtual StatusOr> RunBackend( std::unique_ptr module, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) = 0; + // Compiles a set of HLO modules that can run in parallel, potentially + // communicating data between the modules. + virtual StatusOr>> + RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) = 0; + // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules, and returns a corresponding // sequence of executable objects. @@ -160,7 +174,7 @@ class Compiler { // TODO(b/68666782): Remove this method after adding support for multiple // modules to RunHloPasses and RunBackends. virtual StatusOr>> Compile( - std::vector> modules, + std::unique_ptr module_group, std::vector> stream_exec, DeviceMemoryAllocator* device_allocator) = 0; @@ -184,16 +198,16 @@ class Compiler { ComputeDefaultBackendConfig(const HloInstruction& hlo, se::StreamExecutor* executor) const; - // Compiles the HLO module for ahead-of-time execution. This is intended for - // use in static compilation. 
+ // Compiles the HLO module group for ahead-of-time execution. This is + // intended for use in static compilation. virtual StatusOr>> - CompileAheadOfTime(std::vector> modules, + CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options) = 0; // Similar to CompileAheadOfTime above but AotCompilationMetadata // has an argument that can be populated during compilation. virtual StatusOr>> - CompileAheadOfTime(std::vector> modules, + CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options, std::unique_ptr* metadata); diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index af8f7f1027a40703137d6880a9865449c560a47b..efc893818d03a20d6bd65b7dc1da72ea5da5ceb0 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -56,4 +56,14 @@ string ComputationLayout::ToString() const { result_layout_.ToString()); } +ProgramShape ComputationLayout::ComputeProgramShape() const { + ProgramShape program_shape; + for (int64 i = 0; i < parameter_layouts_.size(); ++i) { + *program_shape.add_parameters() = parameter_layouts_[i].shape(); + *program_shape.add_parameter_names() = absl::StrCat("p", i); + } + *program_shape.mutable_result() = result_layout_.shape(); + return program_shape; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h index 6975f387b4864bf28ea0ad23d7d4602b5b346e08..a2fb656677f354fbf85ff613d826cd6be86ba3bf 100644 --- a/tensorflow/compiler/xla/service/computation_layout.h +++ b/tensorflow/compiler/xla/service/computation_layout.h @@ -83,6 +83,10 @@ class ComputationLayout { // Returns a string representation of this object. string ToString() const; + // Create a ProgramShape proto based on the parameter and result shapes held + // within this object. 
+ ProgramShape ComputeProgramShape() const; + private: std::vector parameter_layouts_; ShapeLayout result_layout_; diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index c899ffb9dc562426ef14c0d414469c04debeec70..844b42a38d7539cccd5c4e30071c0ea6693e3bba 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -105,8 +105,6 @@ class ComputationPlacer { // Map from platform kind to computation placer singleton. static std::map* GetPlatformComputationPlacers(); - se::Platform::Id platform_id_; - TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer); }; diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index c43a31b167d47af3c92ed35fa52594fa5da1e4af..289eb6d90239a72ecc0f3312a7e0e8453f946858 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -37,7 +37,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class ConditionalSimplifierTest : public HloVerifiedTestBase { +class ConditionalSimplifierTest : public HloTestBase { public: // Makes a computation that contains a conditional with constant predicate. 
HloComputation* MakeConditional(HloModule* module); @@ -96,25 +96,28 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { } TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) { - HloComputation* computation = MakeConditional(&module()); - ASSERT_TRUE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); + ASSERT_TRUE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Parameter(), op::Constant())); } TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) { - HloComputation* computation = MakeConditional(&module()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* true_op = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); TF_ASSERT_OK( true_op->AddControlDependencyTo(computation->root_instruction())); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { - HloComputation* computation = MakeConditional(&module()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); @@ -125,11 +128,12 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))), token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateSendDone(send)); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { - HloComputation* computation = MakeConditional(&module()); + auto m = 
CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); @@ -138,18 +142,19 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv( ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv)); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { - HloComputation* computation = MakeConditional(&module()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = MakeConditional(m.get()); auto* conditional = computation->root_instruction(); ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); auto* false_computation = conditional->false_computation(); auto token = false_computation->AddInstruction(HloInstruction::CreateToken()); false_computation->AddInstruction(HloInstruction::CreateInfeed( ShapeUtil::MakeShape(F32, {1}), token, "config")); - EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ConditionalSimplifier().Run(m.get()).ValueOrDie()); } } // namespace diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 0ac4a65ec6ae55fabd2b48ea2982b94f9551c8d2..95c7724c3c93507ae61a984301ecfc0111bef192 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -51,7 +51,8 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { Status HandleConvolution(HloInstruction* convolution) override; // Runs the visitor on a computation. 
- static bool Run(HloComputation* computation); + static bool Run(HloComputation* computation, + bool canonicalize_depthwise_filter); // Returns whether any convolution ops were rewritten. const bool changed() const { return changed_; } @@ -59,18 +60,24 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault { ~ConvolutionVisitor() override = default; private: - explicit ConvolutionVisitor(HloComputation* computation) - : computation_(computation) {} + explicit ConvolutionVisitor(HloComputation* computation, + bool canonicalize_depthwise_filter = false) + : computation_(computation), + filter_expansion_(!canonicalize_depthwise_filter) {} // Current HloComputation instance the ConvolutionVisitor is traversing. HloComputation* computation_; // Whether rewrite has occurred. bool changed_ = false; + + // Whether filter expansion is required. + bool filter_expansion_; }; -bool ConvolutionVisitor::Run(HloComputation* computation) { - ConvolutionVisitor visitor(computation); +bool ConvolutionVisitor::Run(HloComputation* computation, + bool canonicalize_depthwise_filter) { + ConvolutionVisitor visitor(computation, canonicalize_depthwise_filter); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -135,16 +142,16 @@ std::vector GetMaskIds(int64 group_size, int64 group_count) { // Finally we use the Eq op of these two broadcasted constants and get the // desired mask. 
HloInstruction* GetExpandedFilterMask( - const Shape& filter_shape, int64 input_feature_dim, - int64 output_feature_dim, int64 group_count, + const Shape& filter_shape, int64 kernel_input_feature_dim, + int64 kernel_output_feature_dim, int64 group_count, const std::function)>& add_instruction) { Shape expanded_filter_shape = - ExpandedFilterShape(filter_shape, group_count, input_feature_dim); + ExpandedFilterShape(filter_shape, group_count, kernel_input_feature_dim); Shape mask_shape = ShapeUtil::MakeShape( S32, AsInt64Slice(expanded_filter_shape.dimensions())); - int64 output_feature = filter_shape.dimensions(output_feature_dim); - int64 group_size = filter_shape.dimensions(input_feature_dim); + int64 output_feature = filter_shape.dimensions(kernel_output_feature_dim); + int64 group_size = filter_shape.dimensions(kernel_input_feature_dim); // Create a 'input_feature' sized linspace and 'output_feature' sized linspace // that will be broadcasted into perpendicular dimensions and compared. 
@@ -152,15 +159,14 @@ HloInstruction* GetExpandedFilterMask( GetMaskIds(group_size, group_count); const std::vector output_feature_filter_mask = GetMaskIds(output_feature / group_count, group_count); - auto mask1 = add_instruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1(input_feature_filter_mask))); - auto broadcasted_mask1 = add_instruction( - HloInstruction::CreateBroadcast(mask_shape, mask1, {input_feature_dim})); + auto broadcasted_mask1 = add_instruction(HloInstruction::CreateBroadcast( + mask_shape, mask1, {kernel_input_feature_dim})); auto mask2 = add_instruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1(output_feature_filter_mask))); - auto broadcasted_mask2 = add_instruction( - HloInstruction::CreateBroadcast(mask_shape, mask2, {output_feature_dim})); + auto broadcasted_mask2 = add_instruction(HloInstruction::CreateBroadcast( + mask_shape, mask2, {kernel_output_feature_dim})); // Compare the broadcasted output feature linspace to the input feature // linspace to create a diagonal predicate. 
@@ -182,51 +188,203 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { }; auto dim_numbers = convolution->convolution_dimension_numbers(); - int64 input_feature_dim = dim_numbers.kernel_input_feature_dimension(); - int64 group_size = filter->shape().dimensions(input_feature_dim); - int64 output_feature_dim = dim_numbers.kernel_output_feature_dimension(); - auto expanded_filter_shape = - ExpandedFilterShape(filter->shape(), group_count, input_feature_dim); - HloInstruction* filter_mask = GetExpandedFilterMask( - filter->shape(), input_feature_dim, output_feature_dim, group_count, add); + int64 kernel_input_feature_dim = dim_numbers.kernel_input_feature_dimension(); + int64 group_size = filter->shape().dimensions(kernel_input_feature_dim); + int64 kernel_output_feature_dim = + dim_numbers.kernel_output_feature_dimension(); + auto expanded_filter_shape = ExpandedFilterShape(filter->shape(), group_count, + kernel_input_feature_dim); + HloInstruction* filter_mask = + GetExpandedFilterMask(filter->shape(), kernel_input_feature_dim, + kernel_output_feature_dim, group_count, add); HloInstruction* expanded_filter; - // We want to repeat 'filter' in the 'input_feature_dim' dimension - // 'group_count' times. + if (group_size == 1) { + bool depthwise_separable = + (group_count == filter->shape().dimensions(kernel_output_feature_dim)); + // If the code generator handles depthwise separable convolutions + // inherently, then no filter expansion is needed. + if (!filter_expansion_ && depthwise_separable) { + return Status::OK(); + } + // We want to repeat 'filter' in the 'input_feature_dim' dimension + // 'group_count' times. 
Shape reshaped_filter_shape = - ShapeUtil::DeleteDimension(input_feature_dim, filter->shape()); + ShapeUtil::DeleteDimension(kernel_input_feature_dim, filter->shape()); auto reshaped_filter = add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); std::vector broadcast_dims; for (int64 i = 0; i < filter->shape().dimensions_size(); ++i) { - if (i == input_feature_dim) { + if (i == kernel_input_feature_dim) { continue; } broadcast_dims.push_back(i); } expanded_filter = add(HloInstruction::CreateBroadcast( expanded_filter_shape, reshaped_filter, broadcast_dims)); + + auto zero = add(HloInstruction::CreateConstant( + LiteralUtil::Zero(expanded_filter_shape.element_type()))); + auto zero_filter = + add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); + auto new_filter = add(HloInstruction::CreateTernary( + expanded_filter_shape, HloOpcode::kSelect, filter_mask, expanded_filter, + zero_filter)); + + auto new_convolution = HloInstruction::CreateConvolve( + convolution->shape(), convolution->mutable_operand(0), new_filter, + /*feature_group_count=*/1, convolution->window(), dim_numbers, + convolution->precision_config()); + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(new_convolution))); } else { - // We could possibly also use reshape, broadcast, reshape instead of concat - // here, but it would require more complex code, and for depthwise - // convolution we would never end up in this branch. - std::vector concat_operands(group_count, filter); - expanded_filter = add(HloInstruction::CreateConcatenate( - expanded_filter_shape, concat_operands, input_feature_dim)); + int64 activation_input_feature_dim = dim_numbers.input_feature_dimension(); + + int64 output_feature = + filter->shape().dimensions(kernel_output_feature_dim); + + // If group_count == output_feature, then we map those grouped convolutions + // onto depthwise convolution. 
This is done by adding an additional spatial + // dimension to the activations, kernel, and the output. + // E.g., we would turn + // [2, 12]{B, IF} conv [3, 4]{IF, OF} into + // [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the + // additional spatial dimension. The generated convolution output will be + // [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}. + + if (group_count == output_feature && !filter_expansion_) { + auto filter = convolution->mutable_operand(1); + auto activation = convolution->mutable_operand(0); + + // Add spatial dimension to the activation, and reshape. + Shape reshaped_activation_shape = activation->shape(); + ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape); + + int64 new_spatial_dim = reshaped_activation_shape.dimensions().size() - 1; + + reshaped_activation_shape.set_dimensions(activation_input_feature_dim, + group_count); + activation = add( + HloInstruction::CreateReshape(reshaped_activation_shape, activation)); + + // Add spatial dimension to the filter, and reshape. + Shape reshaped_filter_shape = filter->shape(); + ShapeUtil::AppendMajorDimension(1, &reshaped_filter_shape); + + filter = + add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); + + Shape new_output_shape = convolution->shape(); + ShapeUtil::AppendMajorDimension(1, &new_output_shape); + + // Edit convolution dimension numbers. Note that kernel_input_feature_dim + // now becomes a spatial dimension, and the newly added dimension of size + // 1 is the new kernel_input_feature_dim. + dim_numbers.add_input_spatial_dimensions(new_spatial_dim); + dim_numbers.add_kernel_spatial_dimensions(kernel_input_feature_dim); + dim_numbers.set_kernel_input_feature_dimension(new_spatial_dim); + dim_numbers.add_output_spatial_dimensions(new_spatial_dim); + + // Add window for the new spatial dimension. 
+ Window new_window = convolution->window(); + auto* dim = new_window.add_dimensions(); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + dim->set_stride(1); + dim->set_size(group_size); + + auto new_convolution = add(HloInstruction::CreateConvolve( + new_output_shape, activation, filter, group_count, new_window, + dim_numbers, convolution->precision_config())); + + // Delete the extra spatial dimension, and reshape. + Shape reshaped_convolution_shape = + ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape()); + auto reshaped_convolution = HloInstruction::CreateReshape( + reshaped_convolution_shape, new_convolution); + + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(reshaped_convolution))); + + } else { + // The filter expansion mechanism adds zeroes in the kernel. + // For an OF = 12, IF = 6, and kernel IF = 2, the expanded filter mask + // would look like (IF on the Y-axis, OF on the X-axis) + // 1 1 1 1 0 0 0 0 0 0 0 0 + // 1 1 1 1 0 0 0 0 0 0 0 0 + // 0 0 0 0 1 1 1 1 0 0 0 0 + // 0 0 0 0 1 1 1 1 0 0 0 0 + // 0 0 0 0 0 0 0 0 1 1 1 1 + // 0 0 0 0 0 0 0 0 1 1 1 1 + // + // Instead of convolving the above with the input, we instead slice the + // kernel into three kernels, each containing islands of 1s from the + // filter above. We also slice the activations in the IF dimension with + // each slice of size = group_size. For each slice, we perform + // convolutions, and concatenate the generated outputs in the output OF + // dimension. 
+ + std::vector sliced_convolutions; + auto activation = convolution->mutable_operand(0); + std::vector slice_strides(filter->shape().dimensions_size(), 1); + std::vector filter_slice_starts(filter->shape().dimensions_size(), + 0); + std::vector filter_slice_limits( + filter->shape().dimensions().begin(), + filter->shape().dimensions().end()); + std::vector activation_slice_starts( + activation->shape().dimensions_size(), 0); + std::vector activation_slice_limits( + activation->shape().dimensions().begin(), + activation->shape().dimensions().end()); + + int64 output_feature = + filter->shape().dimensions(kernel_output_feature_dim); + auto output_feature_dim = dim_numbers.output_feature_dimension(); + int64 filter_slice_width = output_feature / group_count; + + int64 activation_input_feature_dim = + dim_numbers.input_feature_dimension(); + + for (int64 i = 0; i < group_count; i++) { + filter_slice_starts[kernel_output_feature_dim] = i * filter_slice_width; + filter_slice_limits[kernel_output_feature_dim] = + (i + 1) * filter_slice_width; + auto filter_sliced_shape = filter->shape(); + filter_sliced_shape.set_dimensions(kernel_output_feature_dim, + filter_slice_width); + auto filter_slice = add(HloInstruction::CreateSlice( + filter_sliced_shape, filter, filter_slice_starts, + filter_slice_limits, slice_strides)); + + activation_slice_starts[activation_input_feature_dim] = i * group_size; + activation_slice_limits[activation_input_feature_dim] = + (i + 1) * group_size; + auto activation_sliced_shape = activation->shape(); + activation_sliced_shape.set_dimensions(activation_input_feature_dim, + group_size); + auto activation_slice = add(HloInstruction::CreateSlice( + activation_sliced_shape, activation, activation_slice_starts, + activation_slice_limits, slice_strides)); + + auto conv_slice_shape = convolution->shape(); + conv_slice_shape.set_dimensions(output_feature_dim, filter_slice_width); + + auto new_convolution = add(HloInstruction::CreateConvolve( + 
conv_slice_shape, activation_slice, filter_slice, + /*feature_group_count=*/1, convolution->window(), dim_numbers, + convolution->precision_config())); + + sliced_convolutions.push_back(new_convolution); + } + + auto new_conv = HloInstruction::CreateConcatenate( + convolution->shape(), sliced_convolutions, output_feature_dim); + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(new_conv))); + } } - auto zero = add(HloInstruction::CreateConstant( - LiteralUtil::Zero(expanded_filter_shape.element_type()))); - auto zero_filter = - add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); - auto new_filter = add( - HloInstruction::CreateTernary(expanded_filter_shape, HloOpcode::kSelect, - filter_mask, expanded_filter, zero_filter)); - auto new_convolution = HloInstruction::CreateConvolve( - convolution->shape(), convolution->mutable_operand(0), new_filter, - /*feature_group_count=*/1, convolution->window(), dim_numbers, - convolution->precision_config()); - TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( - convolution, std::move(new_convolution))); + return Status::OK(); } @@ -237,7 +395,7 @@ StatusOr ConvolutionFeatureGroupConverter::Run(HloModule* module) { module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (ConvolutionVisitor::Run(comp)) { + if (ConvolutionVisitor::Run(comp, filter_expansion_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h index ce0138e56fbd51daaf5d3ac329ccbe31a9fdbde7..cb6bc04c00a2ff10f970da2a07fb540a561dad5a 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h @@ -27,7 +27,8 @@ namespace xla { // convolutions with feature_group_count = 1. 
class ConvolutionFeatureGroupConverter : public HloModulePass { public: - ConvolutionFeatureGroupConverter() {} + ConvolutionFeatureGroupConverter(bool canonicalize_depthwise_filter = false) + : filter_expansion_(canonicalize_depthwise_filter) {} absl::string_view name() const override { return "convolution-feature-group-converter"; @@ -36,6 +37,9 @@ class ConvolutionFeatureGroupConverter : public HloModulePass { // Run convolution rewriting on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + + // Tells whether filter expansion is required. + bool filter_expansion_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc index 28373ebf636c7b6b3059dcf6cd931901ebc87fc2..e6bf2143a21bd5001d3530fe8727c88504be1d43 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc @@ -82,18 +82,14 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2 ConvolutionFeatureGroupConverter converter; ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); - // Make sure the convolution is converted to one with feature_group_count = 1. - EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); - EXPECT_EQ(root->feature_group_count(), 1); - // Verify that the filter operand has been replaced. - EXPECT_THAT(root->operand(1), - op::Select(op::Eq(op::Broadcast(op::Constant()), - op::Broadcast(op::Constant())), - // We expect to see Concatenate here instead of - // Broadcast, because feature_group_count < input - // feature dimension. - op::Concatenate(op::Parameter(), op::Parameter()), - op::Broadcast(op::Constant()))); + // Make sure the convolution is replaced with a concatenate. 
+ EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate); + // And the operands of the concatenate are convolutions, each with a feature + // group count = 1. + EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kConvolution); + EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kConvolution); + EXPECT_EQ(root->operand(0)->feature_group_count(), 1); + EXPECT_EQ(root->operand(1)->feature_group_count(), 1); } } // namespace diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index f35324aa35370b592871749cba9fc2f66bea9219..df6059663876dfde71f4c75d3931b3d2de72c1df 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -40,10 +40,12 @@ namespace { using absl::StrAppend; -bool IsEntryParameterValue(const HloValue& value) { +bool IsReadonlyEntryParameterValue(const HloValue& value) { const HloComputation* computation = value.defining_instruction()->parent(); return value.defining_instruction()->opcode() == HloOpcode::kParameter && - computation == computation->parent()->entry_computation(); + computation == computation->parent()->entry_computation() && + !computation->parent()->input_output_alias_config().ParameterHasAlias( + value.defining_instruction()->parameter_number(), value.index()); } bool IsConstantValue(const HloValue& value) { @@ -51,7 +53,7 @@ bool IsConstantValue(const HloValue& value) { } bool ValueIsReadOnly(const HloValue& value) { - return IsConstantValue(value) || IsEntryParameterValue(value); + return IsConstantValue(value) || IsReadonlyEntryParameterValue(value); } // Data structure describing the action which should be taken on parts of a @@ -79,8 +81,7 @@ SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node, bool ShouldCopyRootValue(const HloValue& value, const SpecialCaseCopyPolicy& policy) { if (policy.copy_parameters_and_constants) { - return IsConstantValue(value) || - value.defining_instruction()->opcode() 
== HloOpcode::kParameter; + return ValueIsReadOnly(value); } return false; } @@ -332,6 +333,88 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, return Status::OK(); } +// Conservatively adds copies before root instruction of entry computation and +// each aliased parameter to resolve interference of aliased input and output +// buffer. We later rely on the CopyRemover to drop the unnecessary ones. +Status AddCopiesForAliasedInputOutputs(HloModule* module) { + HloComputation* entry = module->entry_computation(); + HloInstruction* root = entry->root_instruction(); + + ShapeTree output_indices_to_copy(root->shape()); + std::vector>> copied_parameters( + entry->num_parameters()); + bool has_alias = false; + for (auto* param : entry->parameter_instructions()) { + bool param_has_alias = false; + ShapeTree param_indices_to_copy(param->shape()); + + module->input_output_alias_config().ForEachAlias( + [&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) { + if (param_number == param->parameter_number()) { + param_has_alias = true; + *(param_indices_to_copy.mutable_element(param_index)) = true; + *(output_indices_to_copy.mutable_element(output_index)) = true; + } + }); + + if (!param_has_alias) { + continue; + } + + TF_RET_CHECK(param->parameter_number() < entry->num_parameters()); + TF_RET_CHECK(!copied_parameters[param->parameter_number()]); + + has_alias = true; + // Store a snapshot of users before DeepCopyInstruction, as + // DeepCopyInstruction introduces new users of the instruction. 
+ std::vector users = param->users(); + ShapeTree param_copy_tree(param->shape(), + /*init_value=*/nullptr); + TF_ASSIGN_OR_RETURN(HloInstruction * copied, + entry->DeepCopyInstruction( + param, ¶m_indices_to_copy, ¶m_copy_tree)); + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied)); + } + + copied_parameters[param->parameter_number()] = param_copy_tree; + } + + if (!has_alias) { + return Status::OK(); + } + + // Add copies before root instruction. + ShapeTree output_copy_tree(root->shape(), + /*init_value=*/nullptr); + + TF_ASSIGN_OR_RETURN(HloInstruction * root_copied, + root->parent()->DeepCopyInstruction( + root, &output_indices_to_copy, &output_copy_tree)); + + // Add control dependencies between the input/output copies. + TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus( + [&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& input_index) -> Status { + if (!copied_parameters[param_number]) { + return Status::OK(); + } + HloInstruction* from = + copied_parameters[param_number]->element(input_index); + HloInstruction* to = output_copy_tree.element(output_index); + + TF_RET_CHECK(from != nullptr); + TF_RET_CHECK(to != nullptr); + TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to)); + return Status::OK(); + })); + + entry->set_root_instruction(root_copied); + + return Status::OK(); +} + // Removes any control dependencies to or from the given instruction. Status StripControlDependenciesFrom(HloInstruction* instruction) { while (!instruction->control_successors().empty()) { @@ -359,7 +442,6 @@ class CopyRemover { const HloOrdering& ordering, HloModule* module) : module_(module), alias_analysis_(alias_analysis), - ordering_(ordering), buffer_value_tracker_(*module, alias_analysis, ordering) {} // Try to elide the given copy. 
The copy is elided if the instruction is not @@ -920,7 +1002,6 @@ class CopyRemover { HloModule* module_; const HloAliasAnalysis& alias_analysis_; - const HloOrdering& ordering_; // Object tracking the HLO values contained in each HLO buffer. BufferValueTracker buffer_value_tracker_; @@ -953,6 +1034,8 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { } } } + + TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index c097089e30d59936a32f69c49123c398f1611ea3..8866b5050bf1e7419dda6496ea95d034178d25d8 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -94,10 +94,12 @@ class CopyInsertion : public HloModulePass { Status VerifyNoLiveRangeInterference(const HloOrdering& ordering, HloModule* module); - private: + protected: // Override which requires the caller to pass in a call graph. - Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module); + virtual Status AddSpecialCaseCopies(const CallGraph& call_graph, + HloModule* module); + private: Status AddCopiesToResolveInterference(HloModule* module); // Backend specific function that decides whether a fusion can share buffer diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 892d0d7b547aaf1e7f1c55e4163d1e1fd9518def..e4e9d7ba05c115be9dd0eb53ebd7de208d514efb 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -94,7 +94,7 @@ TEST_F(CopyInsertionTest, SingleParameter) { EXPECT_THAT(x->users(), UnorderedElementsAre(tuple)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -114,7 +114,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -127,7 +127,7 @@ TEST_F(CopyInsertionTest, SingleConstant) { TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { // Verify that kCopy instructions which change layout and exist before // copy-insertion remain in the graph after copy-insertion. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = @@ -181,7 +181,7 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { builder.AddInstruction(HloInstruction::CreateTuple({constant2, x, add})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); InsertCopies(module.get()); @@ -217,7 +217,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2)); EXPECT_THAT(constant3->users(), UnorderedElementsAre(tuple2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); HloInstruction* old_root = module->entry_computation()->root_instruction(); @@ -238,7 +238,7 @@ TEST_F(CopyInsertionTest, BitcastParameter) { HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); @@ -261,7 +261,7 @@ TEST_F(CopyInsertionTest, BitcastConstant) { HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(constant->users(), UnorderedElementsAre(bitcast)); @@ -283,7 +283,7 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); builder.AddInstruction(HloInstruction::CreateTuple({bitcast})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); 
EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); @@ -310,7 +310,7 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) { ShapeUtil::MakeShape(F32, {42})}), "param0")); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(HloOpcode::kParameter, @@ -351,7 +351,7 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gte, module->entry_computation()->root_instruction()); @@ -388,7 +388,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(select->shape(), {0}), select, 0)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gte, module->entry_computation()->root_instruction()); @@ -403,7 +403,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { class WhileCopyInsertionTest : public CopyInsertionTest { protected: - WhileCopyInsertionTest() : module_(CreateNewModule()) {} + WhileCopyInsertionTest() : module_(CreateNewUnverifiedModule()) {} // Builds a While condition computation which reads the induction variable // from the tuple parameter, and returns a predicate indicating whether this @@ -1295,7 +1295,7 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { TEST_F(CopyInsertionTest, SwizzlingWhile) { // Test a while instruction with a body which permutes its tuple parameter // elements. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1351,13 +1351,225 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) { EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); } +TEST_F(CopyInsertionTest, CrossingParameters) { + // Test a case where two parameters' dataflow cross with each other while + // input and output are aliased with same index: + // + // (p0 , p1) + // | \ /| + // | \ / | + // alias X alias + // | / \ | + // | / \| + // (p1 , p0) + auto module = CreateNewVerifiedModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 4); +} + +TEST_F(CopyInsertionTest, ParametersAliasing) { + // Test a case where two parameters' dataflow don't interfere with each other + // while aliased. 
+ // + // (p0 , p1) + // | | + // | | + // alias alias + // | | + // | | + // (p0 , p1) + auto module = CreateNewVerifiedModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 0); +} + +TEST_F(CopyInsertionTest, ParameterWithNoAliasing) { + // Test a case where no parameter is aliased with result. 
In this case, copy + // should be added + // + // (p0 , p1) + // | | + // | | + // | | + // | | + // | | + // (p0 , p1) + auto module = CreateNewVerifiedModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module->AddEntryComputation(builder.Build()); + InsertCopies(module.get()); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(op::GetTupleElement(param, 0)), + op::Copy(op::GetTupleElement(param, 1)))); + + EXPECT_EQ(CountCopies(*module), 2); +} + +TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) { + // Test a case where one parameter is aliased with result while another one + // isn't. 
+ // + // (p0 , p1) + // | | + // | | + // alias | + // | | + // | | + // (p0 , p1) + auto module = CreateNewVerifiedModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + InsertCopies(module.get()); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::GetTupleElement(param, 0), + op::Copy(op::GetTupleElement(param, 1)))); + + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) { + // Test a case where one parameter is aliased with result while another one + // isn't. 
+ // + // +-- (p0 , p1) + // | | | + // | | | + // alias Negate Negate + // | | | + // | | | + // +-- (p0 , p1) + auto module = CreateNewVerifiedModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + + auto negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); + + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); + builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 0); +} + +TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) { + // Test a case where one parameter is aliased with result while another one + // isn't. 
+ // + // +-- (p0 , p1) + // | | | + // | | | + // alias Negate Negate + // | | | + // | Add----+ + // | | | + // +-- (p0 , p1) + auto module = CreateNewVerifiedModule(); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + + auto negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); + + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, negate0, negate1)); + builder.AddInstruction(HloInstruction::CreateTuple({add, negate1})); + module->AddEntryComputation(builder.Build()); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 0); +} + TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { // Test a while instruction with a body which permutes its tuple parameter // elements and applies one operation to one of the elements. The addition of // the operation (instruction) on the element makes the live range of the // respective input and output elements different than if the instruction were // not there (as in the SwizzlingWhile test above). 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1420,7 +1632,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { // the while body is a single constant (both loop state elements are the same // constant). This means no copies are necessary because both loop state // elements are the same so interchanging them is a no-op. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); @@ -1481,7 +1693,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { const Shape loop_state_shape = ShapeUtil::MakeTupleShape( {element_shape, element_shape, element_shape, element_shape}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, element_shape, "param_0")); @@ -1571,7 +1783,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) { TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { // Test a while body and condition which are each simply a constant (root of // computation is a constant). The body constant should be copied. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); @@ -1684,7 +1896,7 @@ void BM_SequentialWhiles(int num_iters, int num_whiles) { tensorflow::testing::StopTiming(); for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_SequentialWhiles"); @@ -1724,7 +1936,7 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { tensorflow::testing::StopTiming(); for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_ParallelWhiles"); @@ -1791,7 +2003,7 @@ std::unique_ptr MakeBenchmarkWhileBody( void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { tensorflow::testing::StopTiming(); HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); CopyInsertion copy_insertion; const Shape element_shape = ShapeUtil::MakeShape(F32, {}); std::vector tuple_params(num_tuple_inputs); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 58abb330a6e31e9b7a8081cd7964cf89a5b64a09..ce4c2a9cc69240b9565b35a3f2504d7fc9373917 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -51,6 +51,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", + "@com_google_absl//absl/base", "@com_google_absl//absl/memory", 
"@com_google_absl//absl/types:span", ], @@ -95,6 +96,7 @@ cc_library( "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:map_inliner", + "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", @@ -823,7 +825,6 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -845,7 +846,6 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -886,7 +886,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -960,17 +959,16 @@ tf_cc_test( srcs = ["cpu_copy_insertion_test.cc"], deps = [ ":cpu_copy_insertion", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -996,7 +994,6 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 73b03440cbb936017257b8a92f16dcc25d41e21c..796a7cf94d02b0ad42366387a9d3f8d589b8840a 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -61,19 +61,6 @@ Disabling these as a starting point. // TODO(b/64227304) Creating a custom pass pipeline will replace this. namespace { -class FilteredFunctionPassManager : public llvm::legacy::FunctionPassManager { - public: - FilteredFunctionPassManager(llvm::Module* m, bool disable_expensive_passes) - : llvm::legacy::FunctionPassManager(m), - disable_expensive_passes_(disable_expensive_passes) {} - void add(llvm::Pass* p) override { - llvm::legacy::FunctionPassManager::add(p); - } - - private: - bool disable_expensive_passes_; -}; - class FilteredPassManager : public llvm::legacy::PassManager { public: explicit FilteredPassManager(bool disable_expensive_passes) @@ -96,8 +83,7 @@ class FilteredPassManager : public llvm::legacy::PassManager { std::unique_ptr CompilerFunctor::operator()( llvm::Module& module) const { FilteredPassManager module_passes(disable_expensive_passes_); - FilteredFunctionPassManager function_passes(&module, - disable_expensive_passes_); + llvm::legacy::FunctionPassManager function_passes(&module); VLOG(2) << "IR before optimizations"; XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc 
b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 2083f440fdd971db1b675d005664d25e6de53dbe..c58175428fea6a2d38253c35de598b99a4281bf1 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -32,7 +32,7 @@ namespace cpu { using ::testing::ElementsAre; -class ConvCanonicalizationTest : public HloVerifiedTestBase { +class ConvCanonicalizationTest : public HloTestBase { public: ConvCanonicalizationTest() { for (int i = 0; i < 2; ++i) { @@ -87,7 +87,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { input, kernel, /*feature_group_count=*/1, conv_window_, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); @@ -96,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie()); + EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode()); @@ -150,7 +150,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { input, kernel, /*feature_group_count=*/1, conv_window_, dnums, DefaultPrecisionConfig(2))); - auto module = 
CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( @@ -158,7 +158,7 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie()); + EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 5834f672851f5379c56f6479fd463464c6f3791c..6374822c81bf42fd12829f57cf93c19457128219 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -76,6 +76,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -268,10 +269,11 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - pass.AddPass( - /*is_layout_sensitive=*/false, - [](const Shape&, const Shape&) { return false; }, - /*enable_dot_strength_reduction=*/false); + pipeline.AddPass(); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); + options.set_enable_dot_strength_reduction(false); + pass.AddPass(options); pass.AddPass(); // BatchNormExpander can create zero-sized ops, so zero-sized HLO @@ -327,12 +329,18 @@ Status 
CpuCompiler::RunHloPassesAfterLayoutAssn( { auto& pass = pipeline.AddPass>( "simplification after layout assignement"); - pass.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); - pass.AddPass>( - /*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return true; }, - /*enable_dot_strength_reduction=*/false); + // TODO(b/117156505): When the bug is fixed, the CPU backend should not + // produce layout changing elementwise operations. We will then pass + // LayoutAssignment::InstructionCanChangeLayout to the HLO verifier to + // enable stricter verification. + pass.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return true; }); + options.set_is_layout_sensitive(true); + options.set_enable_dot_strength_reduction(false); + pass.AddPass>(options); pass.AddPass(); pass.AddPass(/*is_layout_sensitive=*/true); } @@ -497,8 +505,8 @@ Status CreateHloProfilingArtifacts( HloCostAnalysis cost_analysis(shape_size_bytes); TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis)); - *hlo_profile_printer_data = - CreateHloProfilePrinterData(**hlo_profile_index_map, cost_analysis); + *hlo_profile_printer_data = CreateHloProfilePrinterData( + **hlo_profile_index_map, cost_analysis, entry_computation.name()); *computation_to_profile_idx = (*hlo_profile_index_map)->computation_to_profile_idx(); @@ -582,9 +590,9 @@ StatusOr> CpuCompiler::RunBackend( // Select an order for emitting the HLO instructions for each // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). 
- TF_ASSIGN_OR_RETURN( - HloSchedule schedule, - ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler)); + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(module.get(), BufferSizeBytesFunction(), + DFSMemoryScheduler)); // Run buffer allocation on the HLO graph. TF_ASSIGN_OR_RETURN( @@ -671,9 +679,12 @@ StatusOr> CpuCompiler::RunBackend( } StatusOr>> -CpuCompiler::CompileAheadOfTime(std::vector> modules, +CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& aot_options) { - TF_RET_CHECK(!modules.empty()); + TF_RET_CHECK(!module_group->empty()); + std::vector> modules = + module_group->ConsumeModules(); + std::call_once(llvm_command_line_options_initialized, &llvm_ir::InitializeLLVMCommandLineOptions, modules[0]->config()); @@ -771,7 +782,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, XLA_VLOG_LINES(2, module->ToString()); TF_ASSIGN_OR_RETURN(HloSchedule schedule, - ScheduleModule(*module, BufferSizeBytesFunction())); + ScheduleModule(module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index f2af923782df268e3e6da3895ec35579ab6aa51f..c67307548dda731f8fa56b8e6790e7e83f587113 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -142,7 +142,7 @@ class CpuCompiler : public LLVMCompiler { DeviceMemoryAllocator* device_allocator) override; StatusOr>> - CompileAheadOfTime(std::vector> modules, + CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& options) override; se::Platform::Id PlatformId() const override; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc index c9fb34be1cd582c71618c770c892058c233c571a..c085f85fb73e98e4c7ba15af8db8bb19c2499f5f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) { return count; } -class CpuCopyInsertionTest : public HloVerifiedTestBase { +class CpuCopyInsertionTest : public HloTestBase { protected: void InsertCopies(HloModule* module) { CpuCopyInsertion copy_insertion; @@ -65,7 +65,7 @@ class CpuCopyInsertionTest : public HloVerifiedTestBase { TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { // Test a while body and condition which are each simply a constant (root of // computation is a constant). Each constant should be copied. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param_0 = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); @@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { module->AddEntryComputation(builder.Build()); - InsertCopies(module); + InsertCopies(module.get()); EXPECT_EQ(CountCopies(*module), 3); @@ -103,7 +103,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { // Test a kCall instruction which calls a computation which produces a three // element tuple: one is a constant, one is a parameter, and one is produced // in the computation. The constant and parameter should be copied. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); @@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { module->AddEntryComputation(builder.Build()); - InsertCopies(module); + InsertCopies(module.get()); EXPECT_EQ(CountCopies(*subcomputation), 2); EXPECT_THAT(subcomputation->root_instruction(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 29abf38e439d919ff93629ed992cb3ff93a929bd..818b2b0d0db2893e11fa46c7867e6c74bbbb6905 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -51,8 +51,7 @@ namespace cpu { CpuExecutable::CpuExecutable( std::unique_ptr jit, std::unique_ptr assignment, - std::unique_ptr hlo_module, - const string& entry_function_name, + std::unique_ptr hlo_module, const string& entry_function_name, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 3c3c047bfe8ee0d1ad90ede2432a86264f47870b..3b91b15ba9b5603b50f78f489e9a3fdad354c083 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -49,7 +49,7 @@ class CpuExecutable : public Executable { public: CpuExecutable(std::unique_ptr jit, std::unique_ptr assignment, - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, const string& entry_function_name, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc 
b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc index be1208fb2df2a1a11a093810b5f6c2a83f468062..9cbfb88834bf51f4df54e97efe6cd7bf88b12334 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class CpuHloSupportCheckerTest : public HloVerifiedTestBase { +class CpuHloSupportCheckerTest : public HloTestBase { protected: CpuHloSupportChecker& checker() { return checker_; } @@ -42,10 +42,10 @@ TEST_F(CpuHloSupportCheckerTest, Add) { HloInstruction::CreateParameter(1, scalar_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module).status()); + TF_ASSERT_OK(checker().Run(module.get()).status()); } TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { @@ -57,10 +57,13 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { HloInstruction::CreateParameter(1, sparse_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( sparse_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); + // Since verifier is reporting sparse layouts as errors, we should + // use a regular HloModule instead of VerifiedHloModule to avoid + // verifier errors being triggered in the destructor. 
+ auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); - Status status = checker().Run(module).status(); + Status status = checker().Run(module.get()).status(); ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); EXPECT_THAT(status.error_message(), HasSubstr("CPU backend does not support")); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index f9cd61bea3dc86cadff99d4a90eca44c16520823..6f79ad7c1468f27c74d84770ec6358fbcd1c1f09 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -48,10 +48,15 @@ bool IsMatrixVectorDot(const HloInstruction* hlo) { (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1); } +bool HasExactlyOneUse(const HloInstruction& hlo_instr) { + return hlo_instr.user_count() == 1 && + absl::c_count(hlo_instr.users().front()->operands(), &hlo_instr) == 1; +} + bool CanBeOutputFused(const HloInstruction* producer, const HloInstruction* consumer) { return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) && - producer->user_count() == 1; + HasExactlyOneUse(*producer) == 1; } bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 7d99b914d4f5e5d27722bcd098d2ae0c54a36a23..527df0bd1c23bba74f32226e5622fed32f7dcf84 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -58,7 +58,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto 
computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -77,7 +77,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -98,7 +98,7 @@ TEST_F(InstructionFusionTest, DotOperationNoFusion_Bitcast) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -119,7 +119,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -138,7 +138,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -157,7 +157,7 
@@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { HloInstruction* dot = builder.AddInstruction( MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); @@ -321,7 +321,7 @@ TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -350,7 +350,7 @@ TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { builder.AddInstruction(HloInstruction::CreateUnary( dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -370,7 +370,7 @@ TEST_F(OpcodeFusionTest, Broadcast_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, broadcast1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -392,7 +392,7 @@ TEST_F(OpcodeFusionTest, DynamicSlice_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, dynamic_slice2)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -410,7 +410,7 @@ TEST_F(OpcodeFusionTest, Exponential_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, exp1)); - auto module = CreateNewModule(); + auto module = 
CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -429,7 +429,7 @@ TEST_F(OpcodeFusionTest, Reshape_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -447,7 +447,7 @@ TEST_F(OpcodeFusionTest, Reverse_Negate) { builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kNegate, reverse1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -466,7 +466,7 @@ TEST_F(OpcodeFusionTest, Slice_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {2}), HloOpcode::kNegate, slice1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -489,7 +489,7 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { builder.AddInstruction(HloInstruction::CreateUnary( result_shape, HloOpcode::kNegate, transpose2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( @@ -498,7 +498,7 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { } TEST_F(OpcodeFusionTest, UnaryMapOfExp) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); @@ -517,7 +517,7 @@ TEST_F(OpcodeFusionTest, UnaryMapOfExp) { } TEST_F(OpcodeFusionTest, BinaryMapOfExps) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); @@ 
-542,7 +542,7 @@ TEST_F(OpcodeFusionTest, BinaryMapOfExps) { } TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); @@ -573,7 +573,7 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { } TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50}); @@ -641,7 +641,7 @@ TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastUnary) { builder.AddInstruction( HloInstruction::CreateUnary(large_shape, HloOpcode::kExp, small_exp)); - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto did_fusion = CpuInstructionFusion().Run(module.get()); @@ -670,7 +670,7 @@ TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastBinary) { builder.AddInstruction(HloInstruction::CreateBinary( large_shape, HloOpcode::kAdd, small_exp, large_param)); - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto did_fusion = CpuInstructionFusion().Run(module.get()); @@ -712,7 +712,7 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name, } TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/1, /*k=*/50, /*n=*/19, /*add_extra_use_for_dot=*/false); @@ -725,7 +725,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); 
CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/1, /*add_extra_use_for_dot=*/false); @@ -738,7 +738,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/19, /*add_extra_use_for_dot=*/false); @@ -751,7 +751,7 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { } TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, /*k=*/50, /*n=*/1, /*add_extra_use_for_dot=*/true); @@ -763,6 +763,28 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { Not(op::Fusion())); } +TEST_F(InstructionFusionTest, + DotOperationFusion_DontOutputFuseDuplicateOperands) { + absl::string_view module_string = R"( +HloModule module + +ENTRY main { + a = f32[50,60]{1,0} parameter(0) + b = f32[60,1]{1,0} parameter(1) + c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT d = f32[50,1]{1,0} add(c, c) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool fused_something, + CpuInstructionFusion().Run(module.get())); + EXPECT_FALSE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + struct GatherLoopFusionTestSpec { string test_name; string hlo_computation_text; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 97659b88a7974d7caf91ab0d4741f3635e4dae4a..6c61b64758ede160e2d50e4429590a789ec253c3 100644 --- 
a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -73,7 +73,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { auto result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -114,7 +114,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -158,7 +158,7 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -192,7 +192,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { auto dot_result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -232,7 +232,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { auto dot_result = builder.AddInstruction( CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); - auto 
module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); @@ -353,7 +353,7 @@ static void AssertCorrectLayoutForDotOutputFusion( } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, @@ -365,7 +365,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, @@ -377,7 +377,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, @@ -389,7 +389,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, @@ -401,7 +401,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { } 
TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, @@ -413,7 +413,7 @@ TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { } TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_1) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewVerifiedModule(); TF_ASSERT_OK_AND_ASSIGN( DotOutputFusionLayoutAssignmentResult layout_assignment_result, RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index b8ace5702688096822573c7afae234cbcbe77b28..92debb83e33b1400a59e5eef0f90971392ab7b22 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -22,7 +22,6 @@ limitations under the License. namespace { const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; -const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; const char* const kXlaEnableExperimentalLlvmIrGemm = "xla_enable_experimental_llvm_ir_gemm"; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 1cc2844470376ceb61601f6d1361def84eac5b45..1457582ac19c27e5c3150b4667e6af505345a6bd 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/base/casts.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" @@ -183,7 +183,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( // Note: OSS build didn't like implicit conversion from // literal_shape.dimensions() to the array slice on 2017-07-10. absl::Span dimensions( - tensorflow::bit_cast(literal_shape.dimensions().data()), + absl::bit_cast(literal_shape.dimensions().data()), literal_shape.dimensions().size()); TF_ASSIGN_OR_RETURN( Shape received_shape, diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 99fa707c959854e50c6d954fe92b87e93e267dc6..97f9b85a606e140fd7f3b1e3ecfb0dd5ba289f03 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -1546,10 +1546,8 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0}; } -// Return whether the given shape is a matrix with no padding. -static bool IsRank2WithNoPadding(const Shape& shape) { - return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); -} +// Return whether the given shape is rank 2. +static bool IsRank2(const Shape& shape) { return ShapeUtil::Rank(shape) == 2; } // In a gemm operation where output = lhs * rhs, check whether the given shapes // are valid for the operation. 
@@ -1565,8 +1563,7 @@ static bool AreValidGemmShapes( return false; } - if (!(IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape))) { + if (!(IsRank2(lhs_shape) && IsRank2(rhs_shape) && IsRank2(output_shape))) { return false; } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index a70abb117acd2917e7273921e1919b0e03b6cd63..4032c2da2f33ee61da8771ae6225a14172cbe6e8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -54,6 +54,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" @@ -110,7 +111,7 @@ IrEmitter::IrEmitter( StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order) { + const std::vector* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]; ordered? " << (instruction_order != nullptr); @@ -139,7 +140,7 @@ StatusOr IrEmitter::EmitComputation( // readcyclecounter if it is unavailable. 
bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 || arch_type_ == llvm::Triple::ArchType::x86_64; - profiling_state_ = ProfilingState(use_rdtscp, GetProfileCountersArgument()); + profiling_state_ = ProfilingState(use_rdtscp); if (instruction_order == nullptr) { TF_RETURN_IF_ERROR(computation->Accept(this)); } else { @@ -493,53 +494,44 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { return Status::OK(); } -Status IrEmitter::HandleSort(HloInstruction* sort) { +Status IrEmitter::HandleSort(HloInstruction* hlo) { + const HloSortInstruction* sort = Cast(hlo); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); - auto keys = sort->operand(0); - auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; - ShapeIndex keys_shape_index({}); - ShapeIndex values_shape_index({}); - if (values != nullptr) { - keys_shape_index = ShapeIndex({0}); - values_shape_index = ShapeIndex({1}); - } - auto keys_destination = GetAllocationSlice(*sort, keys_shape_index); - auto keys_destination_address = - EmitBufferPointer(keys_destination, keys->shape()); - auto values_destination = GetAllocationSlice(*sort, values_shape_index); - llvm::Value* values_destination_address = nullptr; - - // The sort is implemented in-place, therefore we first copy the operand - // buffer to the output buffer if they are not the same. 
- if (keys_destination != GetAllocationSlice(*keys)) { - int64 primitive_type_size = - ShapeUtil::ByteSizeOfPrimitiveType(keys->shape().element_type()); - auto source_buffer = GetEmittedValueFor(keys); - int64 keys_size = ByteSizeOf(keys->shape()); - MemCpy(keys_destination_address, /*DstAlign=*/primitive_type_size, - source_buffer, - /*SrcAlign=*/primitive_type_size, keys_size); - } - if (values != nullptr) { - values_destination_address = - EmitBufferPointer(values_destination, values->shape()); - if (values_destination != GetAllocationSlice(*values)) { + Shape keys_shape = sort->keys()->shape(); + std::vector destination_addresses(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + ShapeIndex shape_index = + sort->values_count() > 0 ? ShapeIndex({i}) : ShapeIndex({}); + const HloInstruction* operand = sort->operand(i); + // We assume that the layout of all involved operands and outputs is the + // same. + TF_RET_CHECK( + LayoutUtil::LayoutsInShapesEqual(keys_shape, operand->shape())); + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( + keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index))); + + // The sort is implemented in-place, therefore we first copy the operand + // buffer to the output buffer if they are not the same. 
+ auto destination_buffer = GetAllocationSlice(*sort, shape_index); + destination_addresses[i] = + EmitBufferPointer(destination_buffer, operand->shape()); + auto source_address = GetAllocationSlice(*operand); + if (destination_buffer != source_address) { int64 primitive_type_size = - ShapeUtil::ByteSizeOfPrimitiveType(values->shape().element_type()); - auto source_buffer = GetEmittedValueFor(values); - int64 values_size = ByteSizeOf(values->shape()); - MemCpy(values_destination_address, /*DstAlign=*/primitive_type_size, + ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type()); + auto source_buffer = GetEmittedValueFor(operand); + int64 size = ByteSizeOf(operand->shape()); + MemCpy(destination_addresses[i], /*DstAlign=*/primitive_type_size, source_buffer, - /*SrcAlign=*/primitive_type_size, values_size); + /*SrcAlign=*/primitive_type_size, size); } } // Normalize the shape and the dimension to sort. Shape normalized_keys_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - keys->shape()); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(keys_shape); int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical( - keys->shape().layout())[sort->dimensions(0)]; + keys_shape.layout())[sort->sort_dimension()]; int64 sort_dimension_elements = normalized_keys_shape.dimensions(physical_dimension_to_sort); @@ -553,7 +545,7 @@ Status IrEmitter::HandleSort(HloInstruction* sort) { lower_dimensions *= normalized_keys_shape.dimensions(i); } - PrimitiveType keys_type = keys->shape().element_type(); + PrimitiveType keys_type = keys_shape.element_type(); const char* fn_name = nullptr; llvm::Type* keys_native_type = nullptr; switch (keys_type) { @@ -614,28 +606,48 @@ Status IrEmitter::HandleSort(HloInstruction* sort) { llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get( b_.getVoidTy(), {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), - b_.getInt8PtrTy(), b_.getInt32Ty()}, + 
b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(), + b_.getInt32Ty()->getPointerTo()}, /*isVarArg=*/false); auto* key_value_sort_func = llvm::cast( module_->getOrInsertFunction(fn_name, key_value_sort_type)); key_value_sort_func->setCallingConv(llvm::CallingConv::C); key_value_sort_func->setDoesNotThrow(); - key_value_sort_func->setOnlyAccessesArgMemory(); + llvm::Value* values; + llvm::Value* sizes; + if (sort->values_count() == 0) { + values = llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()); + sizes = llvm::Constant::getNullValue(b_.getInt32Ty()->getPointerTo()); + } else { + values = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b_.getInt8PtrTy(), b_.getInt32(sort->values_count()), + "cc_values_alloca", &b_); + sizes = llvm_ir::EmitAllocaAtFunctionEntryWithCount( + b_.getInt32Ty(), b_.getInt32(sort->values_count()), "cc_sizes_alloca", + &b_); + for (int64 i = 0; i < sort->values_count(); ++i) { + llvm::Value* value_as_i8ptr = + PointerCast(destination_addresses[i + 1], b_.getInt8PtrTy()); + llvm::Value* slot_in_values_alloca = + ConstInBoundsGEP1_32(b_.getInt8PtrTy(), values, i); + Store(value_as_i8ptr, slot_in_values_alloca); + llvm::Value* slot_in_sizes_alloca = + ConstInBoundsGEP1_32(b_.getInt32Ty(), sizes, i); + llvm::Value* size = b_.getInt32(ShapeUtil::ByteSizeOfPrimitiveType( + sort->operand(i + 1)->shape().element_type())); + Store(size, slot_in_sizes_alloca); + } + } + Call(key_value_sort_func, - {PointerCast(keys_destination_address, keys_native_type), + {PointerCast(destination_addresses[0], keys_native_type), b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), - b_.getInt64(lower_dimensions), - values != nullptr - ? PointerCast(values_destination_address, b_.getInt8PtrTy()) - : llvm::Constant::getNullValue(b_.getInt8PtrTy()), - b_.getInt32(values != nullptr ? 
ShapeUtil::ByteSizeOfPrimitiveType( - values->shape().element_type()) - : 0)}); - - if (values != nullptr) { - llvm_ir::EmitTuple(GetIrArrayFor(sort), - {keys_destination_address, values_destination_address}, - &b_, module_); + b_.getInt64(lower_dimensions), values, + b_.getInt32(sort->values_count()), sizes}); + + if (sort->values_count() > 0) { + llvm_ir::EmitTuple(GetIrArrayFor(sort), destination_addresses, &b_, + module_); } return Status::OK(); } @@ -688,8 +700,25 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); - input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); + input_index[i] = NSWSub( + NSWAdd(strided_index, + NSWMul(window_index[i], + b_.getInt64(window.dimensions(i).window_dilation()))), + b_.getInt64(window.dimensions(i).padding_low())); + + // We need to verify that we are not in the dilated base area. + llvm::Value* dilation_condition = ICmpEQ( + SRem(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())), + b_.getInt64(0)); + if (in_bounds_condition == nullptr) { + in_bounds_condition = dilation_condition; + } else { + in_bounds_condition = And(in_bounds_condition, dilation_condition); + } + + // Apply base dilation to the index. + input_index[i] = + SDiv(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())); // We need to check if 0 <= input_index[i] < bound, as otherwise we are in // the padding so that we can skip the computation. That is equivalent to @@ -728,12 +757,6 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { /*operands=*/{reduce_window->operand(0)}, /*supported_types=*/{F32, BF16, S32, F16})); - // TODO(b/31410564): Implement dilation for reduce-window. 
- if (window_util::HasDilation(reduce_window->window())) { - return Unimplemented( - "Dilation for ReduceWindow is not implemented on CPU."); - } - // Pseudo code for reduce window: // // for (coordinates O in the output) @@ -1356,33 +1379,6 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { return Status::OK(); } -// Fills up the free variables in 'index_with_free_var' with values from -// 'filler_index'. The size of free variables must be the same as the -// size of 'filler_index'. -// -// This is often used after dimension reduction, where -// 'index_with_free_var' has one or more dimensions reduced, which serves as -// free variables (represented as nullptr). For example, if we have a 4 -// dimensional input and index for the dimension being reduced is -// 2 (third dimension), we will have an index like [i, j, NULL, k] -// after reduced dimension. -// -// Here we fill up that free variable by 'filler_index', which contains -// the value in the reduced dimension. -static llvm_ir::IrArray::Index FillReducedDimensionIndex( - llvm_ir::IrArray::Index index_with_free_var, - llvm_ir::IrArray::Index filler_index) { - llvm_ir::IrArray::Index::const_iterator it = filler_index.begin(); - - for (size_t i = 0; i < index_with_free_var.size(); ++i) { - if (index_with_free_var[i] == nullptr) { - index_with_free_var[i] = *it++; - } - } - CHECK(filler_index.end() == it); - return index_with_free_var; -} - Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); return EmitTargetAddressForOp(parameter); @@ -1513,7 +1509,8 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( case HloOpcode::kMaximum: return [root_is_floating_point, root_is_signed]( - llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { + llvm::IRBuilder<>* b, llvm::Value* lhs, + llvm::Value* rhs) -> llvm::Value* { if (root_is_floating_point) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum, {lhs, rhs}, 
{lhs->getType()}, b); @@ -1528,7 +1525,8 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( case HloOpcode::kMinimum: return [root_is_floating_point, root_is_signed]( - llvm::IRBuilder<>* b, llvm::Value* lhs, llvm::Value* rhs) { + llvm::IRBuilder<>* b, llvm::Value* lhs, + llvm::Value* rhs) -> llvm::Value* { if (root_is_floating_point) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum, {lhs, rhs}, {lhs->getType()}, b); @@ -2169,30 +2167,22 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { return Status::OK(); } -// If `hlo` is a Transpose, returns its operand; otherwise returns `hlo` itself. -static const HloInstruction* StripTranspose(const HloInstruction& hlo) { - if (hlo.IsRank2Transpose()) { - return hlo.operand(0); - } - return &hlo; -} - Status IrEmitter::HandleFusion(HloInstruction* fusion) { auto* root = fusion->fused_expression_root(); if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); - // Delegate to common implementation of fused in-place dynamic-update-slice. 
- auto operands = GetIrArraysForOperandsOf(fusion); return llvm_ir::EmitFusedDynamicUpdateSliceInPlace( - fusion, operands, GetIrArrayFor(fusion), &elemental_emitter, &b_); + fusion, GetGeneratorForOperandIrArrays(fusion), GetIrArrayFor(fusion), + &elemental_emitter, &b_); } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) { VLOG(3) << "HandleFusion kLoop"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); auto operands = GetIrArraysForOperandsOf(fusion); - FusedIrEmitter fused_emitter(operands, &elemental_emitter); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion), + &elemental_emitter); TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator()); @@ -2392,14 +2382,8 @@ StatusOr IrEmitter::EmitFastConcatenate( *failure_reason = "operand has mismatching layouts"; return false; } - if (LayoutUtil::IsPadded(op->shape())) { - *failure_reason = "operand has padded layout"; - return false; - } } - CHECK(!LayoutUtil::IsPadded(concatenate->shape())); - // We split the dimensions into three categories: the dimension over which we // are concatenating (concat_dim), the dimensions that are minor to it // (inner_dims) and the dimensions that are major to it (outer_dims). @@ -2581,10 +2565,17 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { return Status::OK(); } -Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) { - TF_RET_CHECK(ByteSizeOf(gen_token->shape()) == 0); +Status IrEmitter::HandleAfterAll(HloInstruction* after_all) { + TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0); // No code to generate, but we need to emit an address for book-keeping. 
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(gen_token)); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all)); + return Status::OK(); +} + +Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) { + // AddDependency just forwards its zero-th operand. + emitted_value_[add_dependency] = + GetEmittedValueFor(add_dependency->operand(0)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 586f27b104ed706a3b128903c6a90abbf3667e59..559a8162a2d53f28ea6817653503c216af90a610 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -59,6 +59,9 @@ namespace cpu { class IrEmitter : public DfsHloVisitorWithDefault, public IrBuilderMixin { public: + using GeneratorForOperandIrArrays = + std::function()>; + // Create a new LLVM IR emitter. // // hlo_module: the HLO module we are emitting IR for. @@ -98,7 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order); + const std::vector* instruction_order); llvm::IRBuilder<>* b() { return &b_; } @@ -156,7 +159,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConditional(HloInstruction* conditional) override; Status HandleScatter(HloInstruction* scatter) override; - Status HandleAfterAll(HloInstruction* gen_token) override; + Status HandleAfterAll(HloInstruction* after_all) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; Status HandleRng(HloInstruction* rng) override; Status FinishVisit(HloInstruction* root) override; @@ -208,6 +212,11 @@ class IrEmitter : public DfsHloVisitorWithDefault, std::vector GetIrArraysForOperandsOf( const HloInstruction* hlo); + GeneratorForOperandIrArrays 
GetGeneratorForOperandIrArrays( + HloInstruction* unnested_hlo) { + return [=]() { return GetIrArraysForOperandsOf(unnested_hlo); }; + } + // Augments IrArray with aliasing information. void AddAliasingInformationToIrArray(const HloInstruction& hlo, llvm_ir::IrArray* array) { @@ -459,9 +468,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, // profiling a computation. class ProfilingState { public: - ProfilingState() : use_rdtscp_(false), prof_counters_(nullptr) {} - ProfilingState(bool use_rdtscp, llvm::Value* prof_counters) - : use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {} + ProfilingState() : use_rdtscp_(false) {} + explicit ProfilingState(bool use_rdtscp) : use_rdtscp_(use_rdtscp) {} // Record the cycle counter before an HLO executes. void RecordCycleStart(llvm::IRBuilder<>* b, HloInstruction* hlo); @@ -486,9 +494,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, // intrinsic? bool use_rdtscp_; - // The argument which corresponds to the profile counter buffer. - llvm::Value* prof_counters_; - // The first read cycle counter in the program. llvm::Value* first_read_cycle_start_ = nullptr; diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index cef5e57b0b12b7ae93af0d2508b2b9d6a592d390..f9722ffadac801521ddcbb568dd4435fd02e951b 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "llvm/Transforms/Utils/Cloning.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index fad76338a57cd9eb21d9469ca8552efa8ea0129b..f0b65046c14ccec5336abf7c4d05d1d755f783bd 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -17,13 +17,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class ParallelTaskAssignmentTest : public HloVerifiedTestBase { +class ParallelTaskAssignmentTest : public HloTestBase { protected: const HloCostAnalysis::ShapeSizeFunction shape_size_func_ = cpu::CpuExecutable::ShapeSizeBytes; @@ -35,7 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) { + : HloTestBase(), target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} @@ -57,8 +57,9 @@ TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr 
m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } @@ -84,8 +85,9 @@ TEST_F(ParallelTaskAssignmentTest, } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } @@ -100,8 +102,9 @@ TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } @@ -116,8 +119,9 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); EXPECT_FALSE(changed); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index e0e7deb98e579c090c8fae320a3ba8a3ce8dbe5f..722aa3120ef4d8c957873ac58c361f19632dde1f 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -41,66 +42,72 @@ void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { std::sort(row_to_sort, row_to_sort + num_elements); } -// For floating point numbers, we want a total order comparator. 
-NaN and NaN -// should appear at the beginning and end of the ordering, and -0.0 should -// appear before 0.0. Also we want to have a stable sort, so if the keys are the -// same, we compare the index values. -template -bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) { - bool lhs_is_negative = std::signbit(lhs); - bool rhs_is_negative = std::signbit(rhs); - // If the signs are different, we can just compare the signs. - if (lhs_is_negative != rhs_is_negative) { - return lhs_is_negative && !rhs_is_negative; - } - bool lhs_nan = std::isnan(lhs); - bool rhs_nan = std::isnan(rhs); - // Exactly one number is nan? - if (lhs_nan != rhs_nan) { - if (lhs_nan) { - return lhs_is_negative; - } - return !rhs_is_negative; +// We would like a total order of floating point numbers so that the +// sort has a predictable behavior in the presence of NaNs. Rather +// than using floating point comparison, we use the following trick: +// If f is a float, and +// x = bit_cast(f); +// y = x < 0 ? 0x7FFFFFFF - x : x; +// then y is ordered as an int32 such that finite values have the +// obvious order, -0 is ordered before 0, and -NaN and NaN appear at +// the beginning and end of the ordering. 
+template +CastType Convert(KeyType value) { + CastType casted_value; + memcpy(&casted_value, &value, sizeof(CastType)); + if (casted_value < 0) { + return static_cast(std::numeric_limits::max()) - + casted_value; } - if (lhs != rhs) { - return lhs < rhs; - } - return lhs_index < rhs_index; + return casted_value; +} + +template +bool LessThan(KeyType lhs, KeyType rhs) { + return Convert(lhs) < + Convert(rhs); } template <> void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); - }); + std::stable_sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan(lhs.first, rhs.first); + }); } template <> void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); - }); + std::stable_sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan(lhs.first, rhs.first); + }); } template <> void KeyValueSort(std::pair* row_to_sort, int64 num_elements) { - std::sort(row_to_sort, row_to_sort + num_elements, - [](const std::pair& lhs, - const std::pair& rhs) -> bool { - return LessThan( - Eigen::half_impl::half_to_float(lhs.first), lhs.second, - Eigen::half_impl::half_to_float(rhs.first), rhs.second); - }); + std::stable_sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair& lhs, + const std::pair& rhs) -> bool { + return LessThan( + Eigen::half_impl::half_to_float(lhs.first), + Eigen::half_impl::half_to_float(rhs.first)); + }); } template -void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values, - int32 values_primitive_type_size_in_bytes) { +void 
KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char** values, + int32 values_count, + int32* values_primitive_type_size_in_bytes) { + // 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT + // code, so msan can't tell they are initialized. + TF_ANNOTATE_MEMORY_IS_INITIALIZED(values, values_count * sizeof(char*)); + TF_ANNOTATE_MEMORY_IS_INITIALIZED(values_primitive_type_size_in_bytes, + values_count * sizeof(int32)); + // High-level idea of the iteration/sorting logic: // Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the // dimension to sort, c is the product of the more minor dimensions (set to 1 @@ -129,7 +136,7 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values, index % sort_dimension_offset + (index - index % sort_dimension_offset) * sort_dimension_elements; // TODO(b/26783907): We could define a custom iterator class that references - // both arrays. Then we could avoid the intermediate copy. However this + // all arrays. Then we could avoid the intermediate copy. However this // would become more complicated, and it is not clear if the benefit is high // enough. for (int64 i = 0; i < sort_dimension_elements; ++i) { @@ -140,97 +147,109 @@ void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values, for (int64 i = 0; i < sort_dimension_elements; ++i) { keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first; } - if (values == nullptr) { - continue; - } // Reorder the values according to the order defined by the keys. 
- for (int64 i = 0; i < sort_dimension_elements; ++i) { - int64 memory_index = - (base_offset + row_to_sort[i].second * sort_dimension_offset) * - values_primitive_type_size_in_bytes; - - reordered_values[i] = std::string(values + memory_index, - values_primitive_type_size_in_bytes); - } - for (int64 i = 0; i < sort_dimension_elements; ++i) { - int64 memory_index = (base_offset + i * sort_dimension_offset) * - values_primitive_type_size_in_bytes; - memcpy(values + memory_index, reordered_values[i].c_str(), - values_primitive_type_size_in_bytes); + for (int32 idx = 0; idx < values_count; ++idx) { + for (int64 i = 0; i < sort_dimension_elements; ++i) { + int64 memory_index = + (base_offset + row_to_sort[i].second * sort_dimension_offset) * + values_primitive_type_size_in_bytes[idx]; + + reordered_values[i] = + std::string(values[idx] + memory_index, + values_primitive_type_size_in_bytes[idx]); + } + for (int64 i = 0; i < sort_dimension_elements; ++i) { + int64 memory_index = (base_offset + i * sort_dimension_offset) * + values_primitive_type_size_in_bytes[idx]; + memcpy(values[idx] + memory_index, reordered_values[i].c_str(), + values_primitive_type_size_in_bytes[idx]); + } } } } } // namespace TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED( - bool* keys, int64 a, int64 b, int64 c, char* values, - int32 values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); + bool* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, + int32* values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_count, + values_primitive_type_size_in_bytes); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8( - int8* keys, int64 a, int64 b, int64 c, char* values, - int32 values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); + int8* keys, int64 a, int64 b, int64 c, char** values, 
int32 values_count, + int32* values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_count, + values_primitive_type_size_in_bytes); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8( - uint8* keys, int64 a, int64 b, int64 c, char* values, - int32 values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); + uint8* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, + int32* values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_count, + values_primitive_type_size_in_bytes); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16( - int16* keys, int64 a, int64 b, int64 c, char* values, - int32 values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); + int16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, + int32* values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_count, + values_primitive_type_size_in_bytes); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16( - uint16* keys, int64 a, int64 b, int64 c, char* values, - int32 values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); + uint16* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, + int32* values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_count, + values_primitive_type_size_in_bytes); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16( - Eigen::half* keys, int64 a, int64 b, int64 c, char* values, - int32 values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); + Eigen::half* keys, int64 a, int64 b, int64 c, char** values, + int32 values_count, int32* values_primitive_type_size_in_bytes) { + 
KeyValueSortImpl(keys, a, b, c, values, values_count, + values_primitive_type_size_in_bytes); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32( - int32* keys, int64 a, int64 b, int64 c, char* values, - int32 values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); + int32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, + int32* values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_count, + values_primitive_type_size_in_bytes); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32( - uint32* keys, int64 a, int64 b, int64 c, char* values, - int32 values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); + uint32* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, + int32* values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_count, + values_primitive_type_size_in_bytes); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32( - float* keys, int64 a, int64 b, int64 c, char* values, - int32 values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); + float* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, + int32* values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_count, + values_primitive_type_size_in_bytes); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64( - int64* keys, int64 a, int64 b, int64 c, char* values, - int32 values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); + int64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, + int32* values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_count, + 
values_primitive_type_size_in_bytes); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64( - uint64* keys, int64 a, int64 b, int64 c, char* values, - int32 values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); + uint64* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, + int32* values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_count, + values_primitive_type_size_in_bytes); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64( - double* keys, int64 a, int64 b, int64 c, char* values, - int32 values_primitive_type_size_in_bytes) { - KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); + double* keys, int64 a, int64 b, int64 c, char** values, int32 values_count, + int32* values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_count, + values_primitive_type_size_in_bytes); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h index 28e35e82c18cbf078f8a1e7f5b818bf839d3d3df..7821099386969e855ea1737cf53ef49c15c6e93b 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h @@ -22,67 +22,75 @@ limitations under the License. extern "C" { // 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b' -// dimension of 'keys' is sorted into ascending order. 'values' can be nullptr. -// If 'values' is not nullptr, the elements in 'values' are reordered in such a -// way that if the element at index 'i' in 'keys' was moved to index 'j', the -// element at index 'i' in 'values' is also moved to index 'j' (which means that -// the same elements correspond to each other as before). +// dimension of 'keys' is sorted into ascending order. 
If 'values_count' is <= +// 0, 'values' and 'values_primitive_type_size_in_bytes' can be nullptr. +// If 'values_count' > 0, they contain exactly 'values_count' many elements. +// Each element of 'values' also represents a 3-dimensional shape with +// dimensions [a, b, c], and the size of the primitive type of the i-th shape +// has exactly 'values_primitive_type_size_in_bytes[i]' bytes. The elements in +// each 'values' shape are reordered in such a way that if the element at index +// 'i' in 'keys' was moved to index 'j', the element at index 'i' in a 'values' +// shape is also moved to index 'j' (which means that the same elements +// correspond to each other as before). extern void __xla_cpu_runtime_KeyValueSortPRED( bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, - char* values, tensorflow::int32 values_primitive_type_size_in_bytes); + char** values, tensorflow::int32 values_count, + tensorflow::int32* values_primitive_type_size_in_bytes); extern void __xla_cpu_runtime_KeyValueSortS8( tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char* values, - tensorflow::int32 values_primitive_type_size_in_bytes); + tensorflow::int64 c, char** values, tensorflow::int32 values_count, + tensorflow::int32* values_primitive_type_size_in_bytes); extern void __xla_cpu_runtime_KeyValueSortU8( tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char* values, - tensorflow::int32 values_primitive_type_size_in_bytes); + tensorflow::int64 c, char** values, tensorflow::int32 values_count, + tensorflow::int32* values_primitive_type_size_in_bytes); extern void __xla_cpu_runtime_KeyValueSortS16( tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char* values, - tensorflow::int32 values_primitive_type_size_in_bytes); + tensorflow::int64 c, char** values, tensorflow::int32 values_count, + tensorflow::int32* values_primitive_type_size_in_bytes); 
extern void __xla_cpu_runtime_KeyValueSortU16( tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char* values, - tensorflow::int32 values_primitive_type_size_in_bytes); + tensorflow::int64 c, char** values, tensorflow::int32 values_count, + tensorflow::int32* values_primitive_type_size_in_bytes); extern void __xla_cpu_runtime_KeyValueSortF16( Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char* values, - tensorflow::int32 values_primitive_type_size_in_bytes); + tensorflow::int64 c, char** values, tensorflow::int32 values_count, + tensorflow::int32* values_primitive_type_size_in_bytes); extern void __xla_cpu_runtime_KeyValueSortS32( tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char* values, - tensorflow::int32 values_primitive_type_size_in_bytes); + tensorflow::int64 c, char** values, tensorflow::int32 values_count, + tensorflow::int32* values_primitive_type_size_in_bytes); extern void __xla_cpu_runtime_KeyValueSortU32( tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char* values, - tensorflow::int32 values_primitive_type_size_in_bytes); + tensorflow::int64 c, char** values, tensorflow::int32 values_count, + tensorflow::int32* values_primitive_type_size_in_bytes); extern void __xla_cpu_runtime_KeyValueSortF32( float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, - char* values, tensorflow::int32 values_primitive_type_size_in_bytes); + char** values, tensorflow::int32 values_count, + tensorflow::int32* values_primitive_type_size_in_bytes); extern void __xla_cpu_runtime_KeyValueSortS64( tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char* values, - tensorflow::int32 values_primitive_type_size_in_bytes); + tensorflow::int64 c, char** values, tensorflow::int32 values_count, + tensorflow::int32* values_primitive_type_size_in_bytes); extern 
void __xla_cpu_runtime_KeyValueSortU64( tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b, - tensorflow::int64 c, char* values, - tensorflow::int32 values_primitive_type_size_in_bytes); + tensorflow::int64 c, char** values, tensorflow::int32 values_count, + tensorflow::int32* values_primitive_type_size_in_bytes); extern void __xla_cpu_runtime_KeyValueSortF64( double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, - char* values, tensorflow::int32 values_primitive_type_size_in_bytes); + char** values, tensorflow::int32 values_count, + tensorflow::int32* values_primitive_type_size_in_bytes); } #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index 1a3d82de954318368d61e3feeb0345dc592dcd8b..7d8e51f909e3db699b745f94a6c625407bc4a6e3 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -19,14 +19,14 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace cpu { namespace { -class ShapePartitionAssignerTest : public HloVerifiedTestBase { +class ShapePartitionAssignerTest : public HloTestBase { protected: typedef std::vector Vec; @@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { expected_partitions); } -class ShapePartitionIteratorTest : public HloVerifiedTestBase { +class ShapePartitionIteratorTest : public HloTestBase { protected: typedef std::vector> Partition; }; @@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { } } -class RandomShapePartitionIteratorTest : public HloVerifiedTestBase { +class RandomShapePartitionIteratorTest : public HloTestBase { protected: typedef std::vector> Partition; RandomShapePartitionIteratorTest() diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 9ec0c8f65705db335379649def746921e6b05bea..efccadedf27181a4cddf4f1dc3610f7c6db1d821 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -108,15 +108,15 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, [](llvm::Error Err) { cantFail(std::move(Err), "lookupFlags failed"); })), - object_layer_(execution_session_, - [this](llvm::orc::VModuleKey) { - llvm::orc::RTDyldObjectLinkingLayer::Resources result; - result.MemMgr = - std::make_shared( - orc_jit_memory_mapper::GetInstance()); - result.Resolver = symbol_resolver_; - return result; - }), + object_layer_( + execution_session_, + [this](llvm::orc::VModuleKey) { + llvm::orc::LegacyRTDyldObjectLinkingLayer::Resources result; + result.MemMgr = std::make_shared( + orc_jit_memory_mapper::GetInstance()); + result.Resolver = 
symbol_resolver_; + return result; + }), compile_layer_(object_layer_, CompilerFunctor(target_machine_.get(), &disassembler_, opt_level, optimize_for_size, @@ -128,8 +128,18 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, } llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { - void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); + void* func_addr = nullptr; + if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) { + // On Mac OS X, 'name' may have a leading underscore prefix, even though the + // registered name may not. + std::string stripped_name(name.begin() + 1, name.end()); + func_addr = CustomCallTargetRegistry::Global()->Lookup(stripped_name); + } else { + func_addr = CustomCallTargetRegistry::Global()->Lookup(name); + } + if (func_addr == nullptr) { + VLOG(2) << "Unable to resolve runtime symbol: " << name; return nullptr; } llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index d74b63fcf45bd70cd18ee41f1e9714ba6a222abd..78406ba143570183aea09d79db3f9b708c21bf70 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -44,9 +44,9 @@ namespace cpu { // it's added to the JIT. class SimpleOrcJIT { public: - using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer; + using ObjLayerT = llvm::orc::LegacyRTDyldObjectLinkingLayer; using CompileFtor = std::function; - using CompileLayerT = llvm::orc::IRCompileLayer; + using CompileLayerT = llvm::orc::LegacyIRCompileLayer; using VModuleKeyT = llvm::orc::VModuleKey; // Create a new JIT, targeting the host architecture. 
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 4b129c95d46d8b5a119e5d23eef387daf7863cce..382dfd0d99df87bbadfe541ddaa32cd6da8e8068 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -48,7 +48,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index 18ee25ba9158c28baaf01492c290638b9673f1ec..f8f5f392da8ab3348e63185aecf7b639daacaa42 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -50,7 +50,7 @@ class CpuEigenDotOperationTest /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(entry_computation)); CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index 00a7aa2ad2f6bac4877302296ccb76222557535c..e30f95311fce229f9c559d3bb40142151e8bf3e3 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -46,7 +46,7 @@ class CpuExternalConstantsTest : public CpuCodegenTest { builder.AddInstruction( 
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, constant)); - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CompileAndVerifyIr(std::move(module), filecheck_pattern, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 1deb412064b02988a8d4a6d726969c948d354d47..04a81dfd35f459ff1fdb3181dc8fc65c62a37d4f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" @@ -34,7 +34,7 @@ namespace xla { namespace cpu { namespace { -class CpuFusionTest : public HloVerifiedTestBase { +class CpuFusionTest : public HloTestBase { protected: CpuFusionTest() {} @@ -57,11 +57,11 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { builder.AddInstruction( HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, add1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. 
@@ -104,11 +104,11 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { builder.AddInstruction( HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -131,7 +131,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { // Test a chain of fusible ops with a non-fusible op (a reduce) thrown in the // middle. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); Shape vshape = input_literal.shape(); @@ -183,7 +183,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -250,12 +250,12 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); // Create computation and module. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); // Run fusion. 
CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); auto fusion1 = result->operand(0); auto fusion2 = result->operand(1); @@ -310,11 +310,11 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({negate1, negate2, exp2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); // The only fusion instruction should be operand 0 of the tuple (formerly // negate1). diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index a434c04a980b9b3cd849792b97a0d9e965ba09f2..9b10c49f4f547edfb2164f98c49cceb031148bdc 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -91,7 +91,7 @@ TEST_P(CpuUnaryIntrinsicTest, DoIt) { /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); string check_lines{spec.check_lines.data(), spec.check_lines.size()}; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index 3b87683ffffefd2aa24dd234cc072425bef00a24..fa0e09ff6b5694c0e97963b83c6e541b858a1376 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -63,7 +63,7 @@ CHECK-NOT: private constant [48 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - 
ParseHloString(hlo_text)); + ParseAndReturnVerifiedModule(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", @@ -104,14 +104,14 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [4 x i8] -CHECK: private constant [8 x i8] +CHECK-DAG: private constant [4 x i8] +CHECK-DAG: private constant [8 x i8] CHECK-NOT: private constant [4 x i8] CHECK-NOT: private constant [8 x i8] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_text)); + ParseAndReturnVerifiedModule(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index b35fd9dad877c319c3d0110c96a00aeefa78769e..a7702c2aeeaff8a46a2c4f2785ccb873ea2c08e5 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -56,7 +56,7 @@ TEST_F(CpuNoAliasTest, Concat) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it. diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h index 990ff94ba2338cb663b655ca3106bda83ab718a3..70008947f371d25e95d02839c30ba822fce7a292 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h @@ -23,6 +23,7 @@ limitations under the License. 
#include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc index e727ba49cb6321e499b5d50d5f45e7f7f6bb6fef..64fb50318394918b277fd717994f5366d762ac36 100644 --- a/tensorflow/compiler/xla/service/defuser_test.cc +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -18,19 +18,19 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class DefuserTest : public HloVerifiedTestBase { +class DefuserTest : public HloTestBase { protected: // Returns the number of fusion instructions in the module. 
- int FusionCount() { + int FusionCount(const HloModule* m) { int count = 0; - for (HloComputation* computation : module().computations()) { + for (HloComputation* computation : m->computations()) { if (computation->IsFusionComputation()) { count++; } @@ -43,6 +43,7 @@ class DefuserTest : public HloVerifiedTestBase { }; TEST_F(DefuserTest, NoFusionInstruction) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -51,13 +52,14 @@ TEST_F(DefuserTest, NoFusionInstruction) { builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); - module().AddEntryComputation(builder.Build()); - EXPECT_EQ(0, FusionCount()); + m->AddEntryComputation(builder.Build()); + EXPECT_EQ(0, FusionCount(m.get())); - EXPECT_FALSE(defuser_.Run(&module()).ValueOrDie()); + EXPECT_FALSE(defuser_.Run(m.get()).ValueOrDie()); } TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -66,21 +68,22 @@ TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) { auto add = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(1, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(1, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Add(op::Parameter(), op::Parameter())); } 
TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -91,21 +94,22 @@ TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) { builder.AddInstruction( HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Fusion())); - EXPECT_EQ(1, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(1, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Add(op::Parameter(), op::Parameter()))); } TEST_F(DefuserTest, NonTrivialFusionInstruction) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -128,22 +132,23 @@ TEST_F(DefuserTest, NonTrivialFusionInstruction) { auto add2 = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction( {add2, constant, div, mul, sub, negate, add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(1, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(1, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); 
EXPECT_THAT(computation->root_instruction(), op::Add(op::Constant(), op::Divide())); } TEST_F(DefuserTest, MultipleFusionInstructions) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -166,7 +171,7 @@ TEST_F(DefuserTest, MultipleFusionInstructions) { auto add2 = builder.AddInstruction( HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add2, constant, div, mul}, HloInstruction::FusionKind::kLoop); computation->CreateFusionInstruction({sub, negate, add}, @@ -174,15 +179,16 @@ TEST_F(DefuserTest, MultipleFusionInstructions) { EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(2, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(2, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Add(op::Constant(), op::Divide())); } TEST_F(DefuserTest, NestedFusionInstructions) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); @@ -193,7 +199,7 @@ TEST_F(DefuserTest, NestedFusionInstructions) { auto negate = builder.AddInstruction( HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); auto outer_fusion = computation->CreateFusionInstruction( {negate, add}, HloInstruction::FusionKind::kLoop); HloInstruction* fused_negate = outer_fusion->fused_expression_root(); @@ -203,9 +209,9 @@ TEST_F(DefuserTest, 
NestedFusionInstructions) { EXPECT_THAT(computation->root_instruction(), op::Fusion()); - EXPECT_EQ(2, FusionCount()); - EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); - EXPECT_EQ(0, FusionCount()); + EXPECT_EQ(2, FusionCount(m.get())); + EXPECT_TRUE(defuser_.Run(m.get()).ValueOrDie()); + EXPECT_EQ(0, FusionCount(m.get())); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Add())); } diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index b3549acfc291a54b2345b006310613c3a45a4b47..ed37099a5428075928ec98b134632867d58bbfe7 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_normalization.h" #include "tensorflow/compiler/xla/service/defuser.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" namespace xla { @@ -45,6 +46,7 @@ class ControlDepRemover : public HloModulePass { Despecializer::Despecializer() : pipeline_("despecializer") { // TODO(b/70588125): Also deal with window reversal in a fast way. + pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index edbcb25247421cdb50a845df1ec8b1851970efe3..e1e3b156fb34fd128864ed34c6d9d055294672bf 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/numbers.h" namespace xla { @@ -39,6 +40,10 @@ StatusOr StreamExecutorMemoryAllocator::Allocate( "Failed to allocate request for %s (%uB) on device ordinal %d", tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal); } + VLOG(3) << absl::StreamFormat( + "Allocated %s (%uB) on device ordinal %d: %p", + tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal, + result.opaque()); return OwningDeviceMemory(result, device_ordinal, this); } @@ -47,6 +52,8 @@ Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal, if (!mem.is_null()) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor, GetStreamExecutor(device_ordinal)); + VLOG(3) << absl::StreamFormat("Freeing %p on device ordinal %d", + mem.opaque(), device_ordinal); stream_executor->Deallocate(&mem); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc index 3e7373adc5ab8a60fd18348ce2477175aaaa8fd4..c54f81e6915a286757e59821c2684a7271889816 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -50,7 +50,7 @@ void DfsHloVisitorBase::SetVisiting( const HloInstruction& instruction) { VLOG(3) << "marking HLO " << &instruction << " as visiting: "; DCHECK(NotVisited(instruction)); - visit_state_.SetState(instruction.unique_id(), VisitState::kVisiting); + visit_state_[instruction.unique_id()] = VisitState::kVisiting; } template @@ -58,7 +58,7 @@ void DfsHloVisitorBase::SetVisited( const HloInstruction& instruction) { VLOG(3) << "marking HLO " << &instruction << " as visited: "; DCHECK(NotVisited(instruction) || IsVisiting(instruction)); - visit_state_.SetState(instruction.unique_id(), VisitState::kVisited); + visit_state_[instruction.unique_id()] = 
VisitState::kVisited; } template diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 68d01d75a2ed3d7eaadb03a46ba3bd20f43a9ffc..e84bf00153aa28df29d8df486b92654feab4afbf 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" @@ -107,6 +108,7 @@ class DfsHloVisitorBase { virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; + virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -249,6 +251,7 @@ class DfsHloVisitorBase { virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; + virtual Status HandleAddDependency(HloInstructionPtr add_dependency) = 0; virtual Status HandleAfterAll(HloInstructionPtr token) = 0; // Invoked to inform the visitor that the traversal has completed, and that @@ -263,21 +266,25 @@ class DfsHloVisitorBase { kVisited = 2, }; - VisitState GetVisitState(int id) { return visit_state_.GetState(id); } + VisitState GetVisitState(int id) { + auto iter = visit_state_.find(id); + if (iter == visit_state_.end()) { + return VisitState::kNotVisited; + } + return iter->second; + } VisitState GetVisitState(const HloInstruction& instruction); // Resize internal state if necessary to hold state for ids <= num. // This call is purely a performance hint and can be omitted without // affecting correctness. 
- void ReserveVisitStates(int num) { visit_state_.Reserve(num); } + void ReserveVisitStates(int num) { visit_state_.reserve(num); } // Useful when we want to visit the same computation more than once with the // same visitor. - void ResetVisitStates() { visit_state_.Reset(); } + void ResetVisitStates() { visit_state_.clear(); } - void SetVisitState(int id, VisitState state) { - visit_state_.SetState(id, state); - } + void SetVisitState(int id, VisitState state) { visit_state_[id] = state; } // Sets the visitation state of the given instruction as kVisiting. // @@ -326,44 +333,7 @@ class DfsHloVisitorBase { virtual Status Postprocess(HloInstructionPtr hlo); private: - class DFSVisitStates { - public: - DFSVisitStates() {} - void Reserve(uint64 num) { - states_.reserve((num + kStatesPerWord - 1) / kStatesPerWord); - } - VisitState GetState(uint64 id) { - uint64 word_index = id / kStatesPerWord; - if (word_index >= states_.size()) { - return VisitState::kNotVisited; - } - static_assert(static_cast(VisitState::kVisited) < 3, - "VisitState must fit in two bits"); - uint64 w = states_[word_index]; - uint32 shift = 2 * (id % kStatesPerWord); // 2 bits per state - return static_cast((w >> shift) & 0x3); - } - void SetState(uint64 id, VisitState state) { - uint64 word_index = id / kStatesPerWord; - if (word_index >= states_.size()) { - states_.resize(word_index + 1, 0); - } - uint64* w = &states_[word_index]; - uint32 shift = 2 * (id % kStatesPerWord); // 2 bits per state - uint64 mask = 0x3ull << shift; - *w = (*w & ~mask) | (static_cast(state) << shift); - DCHECK_EQ(GetState(id), state); - } - void Reset() { states_.clear(); } - - private: - static const uint32 kStatesPerWord = sizeof(uint64) / 2 /*bits per entry*/; - // Map from id to two-bit states. 
We store 32 such states per 64-bit - // value - std::vector states_; - }; - - DFSVisitStates visit_state_; + absl::flat_hash_map visit_state_; TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorBase); }; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 4cd10ab06cd3b804406607212d3f3c316d6cff95..80ea5be298aea44a0f424398da74c4e478f10346 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -203,6 +203,12 @@ class DfsHloVisitorWithDefaultBase Status HandleAfterAll(HloInstructionPtr token) override { return DefaultAction(token); } + Status HandleGetDimensionSize(HloInstructionPtr get_size) override { + return DefaultAction(get_size); + } + Status HandleAddDependency(HloInstructionPtr add_dependency) override { + return DefaultAction(add_dependency); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc new file mode 100644 index 0000000000000000000000000000000000000000..6d0472689bf48092ceef2e9792c1358687d707ec --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -0,0 +1,459 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +namespace { +bool IsTrivialWindowDimension(const WindowDimension& window_dimension) { + return window_dimension.size() == 1 && window_dimension.stride() == 1 && + window_dimension.padding_low() == 0 && + window_dimension.padding_high() == 0 && + window_dimension.window_dilation() == 1 && + window_dimension.base_dilation() == 1; +} +} // namespace + +class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { + public: + explicit DynamicDimensionInferenceVisitor( + const DynamicParameterBinding& param_bindings, + DynamicDimensionInference* parent) + : param_bindings_(param_bindings), parent_(parent) {} + + Status DefaultAction(HloInstruction* hlo) override; + + static Status Run(HloComputation* computation, + const DynamicParameterBinding& param_bindings, + DynamicDimensionInference* parent) { + DynamicDimensionInferenceVisitor visitor(param_bindings, parent); + return computation->Accept(&visitor); + } + + Status HandleParameter(HloInstruction* hlo) override; + + Status HandleReduce(HloInstruction* hlo) override; + + Status HandleDot(HloInstruction* hlo) override; + + Status HandleTranspose(HloInstruction* hlo) override; + + Status HandleReshape(HloInstruction* hlo) override; + + Status HandlePad(HloInstruction* hlo) override; + + Status HandleBroadcast(HloInstruction* hlo) override; + + Status HandleGetDimensionSize(HloInstruction* hlo) override; + + Status HandleSelect(HloInstruction* hlo) override; + + Status HandleConvolution(HloInstruction* hlo) override; + + Status HandleReduceWindow(HloInstruction* hlo) override; + 
+ Status HandleSelectAndScatter(HloInstruction* hlo) override; + + Status HandleGetTupleElement(HloInstruction* hlo) override; + + Status HandleElementwiseUnary(HloInstruction* hlo) override; + + Status HandleElementwiseBinary(HloInstruction* hlo) override; + + private: + using OperandDynamicDimensionFn = std::function; + + Status ForEachOperandDynamicDimension(HloInstruction* inst, + const OperandDynamicDimensionFn&); + + // Pass through a dynamic dimension from the input to the output with the same + // value and index in the shape. This is a helper function to handle trivial + // instructions like elementwise operations. + Status PassThroughDynamicDimension(HloInstruction*); + + // The dynamic parameter bindings of this computation. + const DynamicParameterBinding& param_bindings_; + + // A pointer to DynamicDimensionInference, used to update the dynamic mapping. + DynamicDimensionInference* parent_; +}; + +Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + return UnimplementedStrCat( + "Asked to propagate a dynamic dimension from hlo ", + operand->ToString(), "@", index.ToString(), "@", dimension, + " to hlo ", hlo->ToString(), ", which is not implemented."); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleGetTupleElement( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + if (hlo->tuple_index() == index[0]) { + ShapeIndex new_index = + ShapeIndexView(index).ConsumeFront().ToShapeIndex(); + parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size); + } + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + 
hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + int64 broadcast_dim = hlo->dimensions(dimension); + parent_->SetDynamicSize(hlo, index, broadcast_dim, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + if (operand_index != 0) { + return Unimplemented( + "Dynamic dimension on padding value is not supported"); + } + const PaddingConfig_PaddingConfigDimension& padding_config = + hlo->padding_config().dimensions(dimension); + if (padding_config.interior_padding() == 0 && + padding_config.edge_padding_low() == 0 && + padding_config.edge_padding_high() == 0) { + parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size); + return Status::OK(); + } else { + return Unimplemented( + "Dynamic dimension propagation on padding dimension is not " + "supported."); + } + }); +} + +Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reduce = hlo; + int64 operand_count = reduce->operand_count(); + CHECK_EQ(operand_count % 2, 0); + if (operand_index >= operand_count / 2) { + // Init values doesn't have dynamic size. + return Status::OK(); + } + if ((absl::c_count(reduce->dimensions(), dimension) != 0)) { + // Dimension is to be reduce, stop tracing. + return Status::OK(); + } + + // Find out the new dynamic dimension after reduce. 
+ int64 dimensions_not_reduced_count = 0; + for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { + if (dimension == i) { + parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count, + dynamic_size); + + return Status::OK(); + } + if (absl::c_count(reduce->dimensions(), i) == 0) { + dimensions_not_reduced_count++; + } + } + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* dot = hlo; + const DotDimensionNumbers& dimension_numbers = + dot->dot_dimension_numbers(); + // A map from the operand dimensions to result dimension. + absl::flat_hash_map result_dim_mapping; + int64 current_result_dims = 0; + std::unordered_set batch_dims( + dimension_numbers.rhs_batch_dimensions().begin(), + dimension_numbers.rhs_batch_dimensions().end()); + + for (int64 i : dimension_numbers.rhs_batch_dimensions()) { + result_dim_mapping[i] = current_result_dims++; + } + + for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(0)->shape()); i++) { + if (!absl::c_linear_search( + dimension_numbers.lhs_contracting_dimensions(), i)) { + if (operand_index == 0) { + result_dim_mapping[i] = current_result_dims; + } + current_result_dims++; + } + } + + for (int64 i = 0; i < ShapeUtil::Rank(dot->operand(1)->shape()); i++) { + if (!absl::c_linear_search( + dimension_numbers.rhs_contracting_dimensions(), i) && + !absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), + i)) { + if (operand_index == 1) { + result_dim_mapping[i] = current_result_dims; + } + current_result_dims++; + } + } + + // Check if the operand dim is in the result shape. If so, add another + // work item to trace that dimension. 
+ auto iter = result_dim_mapping.find(dimension); + if (iter != result_dim_mapping.end()) { + parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size); + } + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + parent_->SetDynamicSize(hlo, {}, hlo->dimensions()[dimension], + dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleConvolution( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* conv = hlo; + const ConvolutionDimensionNumbers& dimension_numbers = + conv->convolution_dimension_numbers(); + + if (operand_index == 0) { + if (dimension == dimension_numbers.input_batch_dimension()) { + parent_->SetDynamicSize(conv, {}, + dimension_numbers.output_batch_dimension(), + dynamic_size); + return Status::OK(); + } + + if (dimension == dimension_numbers.input_feature_dimension()) { + return Status::OK(); + } + } else { + if (dimension == dimension_numbers.kernel_input_feature_dimension()) { + return Status::OK(); + } + } + + return Unimplemented("Dynamic Spatial Convolution is not supported: %s", + conv->ToString()); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleGetDimensionSize( + HloInstruction*) { + // Dynamic dimension doesn't propagate through GetDimensionSize: + // + // Input: F32[x, y, z] + // | + // GetDimensionSize(1): U32[] + // + // The returned value is a scalar, which doesn't have any dynamic dimension in + // the shape (although the value contains the real size of the dynamic + // dimension of the input). 
+ return Status::OK(); +} + +Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + parent_->SetDynamicSize(hlo, index, dimension, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleElementwiseUnary( + HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleSelect(HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleElementwiseBinary( + HloInstruction* hlo) { + return PassThroughDynamicDimension(hlo); +} + +Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reshape = hlo; + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(operand->shape(), + reshape->shape()); + for (auto& unmodified : unmodified_dims) { + if (unmodified.first == dimension) { + parent_->SetDynamicSize(reshape, {}, unmodified.second, + dynamic_size); + return Status::OK(); + } + } + return Unimplemented( + "Dynamic Reshape on modified dimensions is yet not supported: %s", + reshape->ToString()); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleReduceWindow( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* reduce_window = hlo; + const WindowDimension& window_dimension = + reduce_window->window().dimensions(dimension); + + if (!IsTrivialWindowDimension(window_dimension)) { + return Unimplemented( + "Dynamic Spatial reduce 
window is not supported: %s", + reduce_window->ToString()); + } + + parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size); + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter( + HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, + int64 operand_index, HloInstruction* dynamic_size) { + HloInstruction* select_and_scatter = hlo; + const WindowDimension& window_dimension = + select_and_scatter->window().dimensions(dimension); + + if (!IsTrivialWindowDimension(window_dimension)) { + return Unimplemented( + "Dynamic Spatial select and scatter is not supported: %s", + select_and_scatter->ToString()); + } + + parent_->SetDynamicSize(select_and_scatter, {}, dimension, + dynamic_size); + + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) { + return param_bindings_.ForEachBinding( + [&](const DynamicParameterBinding::DynamicParameter& dynamic_parameter, + const DynamicParameterBinding::DynamicDimension& dynamic_dimension) { + if (dynamic_dimension.parameter_num != hlo->parameter_number()) { + return Status::OK(); + } + HloComputation* computation = hlo->parent(); + HloInstruction* target_parameter = + computation->parameter_instruction(dynamic_dimension.parameter_num); + + HloInstruction* dynamic_size = + computation->parameter_instruction(dynamic_parameter.parameter_num); + for (int64 i : dynamic_parameter.parameter_index) { + dynamic_size = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(dynamic_size->shape(), {i}), + dynamic_size, i)); + } + + parent_->SetDynamicSize(target_parameter, + dynamic_dimension.parameter_index, + dynamic_dimension.dimension, dynamic_size); + return Status::OK(); + }); +} + +Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension( + HloInstruction* inst, const 
OperandDynamicDimensionFn& fn) { + for (int64 operand_index = 0; operand_index < inst->operand_count(); + ++operand_index) { + auto iter = + parent_->per_hlo_dynamic_dimensions_.find(inst->operand(operand_index)); + if (iter != parent_->per_hlo_dynamic_dimensions_.end()) { + for (auto& dynamic_dimension : iter->second) { + HloInstruction* dynamic_size = parent_->GetDynamicSize( + dynamic_dimension.inst, dynamic_dimension.index, + dynamic_dimension.dim); + TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index, + dynamic_dimension.dim, operand_index, + dynamic_size)); + } + } + } + return Status::OK(); +} + +/* static */ +StatusOr DynamicDimensionInference::Run( + HloModule* module) { + VLOG(0) << "Param Config " << module->dynamic_parameter_binding().ToString(); + DynamicDimensionInference inference(module); + TF_RETURN_IF_ERROR(inference.AnalyzeDynamicDimensions()); + return inference; +} + +DynamicDimensionInference::DynamicDimensionInference(HloModule* module) + : module_(module) {} + +Status DynamicDimensionInference::AnalyzeDynamicDimensions() { + return DynamicDimensionInferenceVisitor::Run( + module_->entry_computation(), module_->dynamic_parameter_binding(), this); +} + +HloInstruction* DynamicDimensionInference::GetDynamicSize( + HloInstruction* inst, const ShapeIndex& index, int64 dim) const { + auto iter = dynamic_mapping_.find(DynamicDimension{inst, index, dim}); + if (iter != dynamic_mapping_.end()) { + return iter->second; + } + return nullptr; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h new file mode 100644 index 0000000000000000000000000000000000000000..164d15bf111a92e3da957f609b54ee0662ef18b1 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// DynamicDimensionInference analyzes each HLO instruction in a graph and +// inferences which dimensions are dynamic and which scalar instructions +// represent the runtime real size of those dynamic dimensions. +class DynamicDimensionInference { + public: + static StatusOr Run(HloModule* module); + + string ToString() const; + + // If the dimension `dim` of instruction `inst` at `index` has a dynamic size, + // returns a scalar HloInstruction that represents the runtime size of that + // dimension. Otherwise returns nullptr. 
+ HloInstruction* GetDynamicSize(HloInstruction* inst, const ShapeIndex& index, + int64 dim) const; + + friend class DynamicDimensionInferenceVisitor; + + private: + explicit DynamicDimensionInference(HloModule* module); + + // DynamicDimension is used as a key in the dynamic key-value mapping. It + // unambiguously represents a dynamic dimension of a instruction at a given + // index. + struct DynamicDimension { + // HloInstruction that holds the dimension. + HloInstruction* inst; + // Subshape of the instruction that holds the dimension. + ShapeIndex index; + // The dimension number of the dynamic dimension at given index of a given + // instruction. + int64 dim; + + // Artifacts needed to make this struct able to be used as a `key` in absl + // maps. "friend" keywords are added so these functions can be found through + // ADL. + template + friend H AbslHashValue(H h, const DynamicDimension& m) { + return H::combine(std::move(h), m.inst, m.index, m.dim); + } + + friend bool operator==(const DynamicDimension& lhs, + const DynamicDimension& rhs) { + return lhs.inst == rhs.inst && lhs.index == rhs.index && + lhs.dim == rhs.dim; + } + }; + + // Update the dynamic mapping so that we know dimension `dim` of instruction + // `inst` at `index` has a dynamic size, and its runtime size is represented + // by a scalar instruction `size`. + void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim, + HloInstruction* size) { + dynamic_mapping_.try_emplace(DynamicDimension{inst, index, dim}, size); + auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst); + iter.first->second.emplace(DynamicDimension{inst, index, dim}); + } + + // AnalyzeDynamicDimensions starts the analysis of the dynamic dimensions in + // module_. + Status AnalyzeDynamicDimensions(); + + // HloModule being analyzed. + HloModule* module_; + + // dynamic_mapping_ holds the result of the analysis. 
It maps a dynamic + // dimension to a scalar HloInstruction that represents the real dynamic size + // of the dynamic dimension. + using DynamicMapping = absl::flat_hash_map; + DynamicMapping dynamic_mapping_; + + using PerHloDynamicDimensions = + absl::flat_hash_map>; + PerHloDynamicDimensions per_hlo_dynamic_dimensions_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_DIMENSION_INFERENCE_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea9ebed45d99797ce4f80376ec3d0b758da3ca17 --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -0,0 +1,535 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class DynamicDimensionInferenceTest : public HloTestBase { + protected: + DynamicDimensionInferenceTest() : HloTestBase() { + module_ = CreateNewVerifiedModule(); + } + + Status RunInference() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); + TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference, + DynamicDimensionInference::Run(module_.get())); + + inference_ = absl::make_unique(inference); + return Status::OK(); + } + + HloComputation* GetAdd() { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + return 
module_->AddEmbeddedComputation(embedded_builder.Build()); + } + + std::unique_ptr module_; + std::unique_ptr inference_; + const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {}); +}; + +TEST_F(DynamicDimensionInferenceTest, ParamTest) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "param")); + auto param2 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "param")); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(param, {}, 1), param2); + EXPECT_EQ(inference_->GetDynamicSize(param, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(param2, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ParamTestTuple) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({input_shape, scalar_shape_}), "param")); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{0, {1}}, + DynamicParameterBinding::DynamicDimension{0, {0}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_EQ(inference_->GetDynamicSize(param, {0}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, GetTupleElement) { + // When data flows through GTE, the dynamic dimension size keeps the + // same, and the index has its front popped. 
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({input_shape, scalar_shape_}), "param")); + + auto gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(input_shape, param, 0)); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{0, {1}}, + DynamicParameterBinding::DynamicDimension{0, {0}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_THAT(inference_->GetDynamicSize(param, {0}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_THAT(inference_->GetDynamicSize(gte, {}, 1), + op::GetTupleElement(param, 1)); + + EXPECT_EQ(inference_->GetDynamicSize(param, {0}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ElementwiseTest) { + // When data flows through elementwise, the dynamic dimension size keeps the + // same. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto* negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. 
+ TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(negate, {}, 1), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceTestI) { + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + reduce_shape, negate, init, {0, 2}, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceTestII) { + // Same as ReduceTestI, but only reduce one dimension. 
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto reduce_shape = ShapeUtil::MakeShape(F32, {1, 2}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param)); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto reduce = builder.AddInstruction( + HloInstruction::CreateReduce(reduce_shape, negate, init, {1}, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, DotTest) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto xz_shape = ShapeUtil::MakeShape(F32, {xdim, zdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + 
dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(xz_shape, a_param, b_param, dot_dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for non-contracting dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + // Set up binding for contracting dimensions. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size_param); + EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) { + auto builder = HloComputation::Builder(TestName()); + constexpr int xdim = 3; + constexpr int ydim = 2; + constexpr int zdim = 1; + auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim}); + auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim}); + auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, xy_shape, "A")); + auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, yz_shape, "B")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0); + + dnums.set_kernel_input_feature_dimension(0); + dnums.set_kernel_output_feature_dimension(1); + dnums.set_input_batch_dimension(0); + 
dnums.set_output_batch_dimension(1); + dnums.set_output_feature_dimension(0); + + Window window; + + auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + zx_shape, a_param, b_param, /*feature_group_count=*/1, window, dnums, + HloTestBase::DefaultPrecisionConfig(2))); + + module_->AddEntryComputation(builder.Build()); + + // Set up dynamic parameter binding for non-contracting dimension. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + // Set up binding for contracting dimensions. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 0), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, TransposeTest) { + // Test the ability to trace unmodified dimensions + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3}); + auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 1}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param_1 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + auto* size_param_2 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/2, scalar_shape_, "size_param")); + auto* size_param_3 = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/3, scalar_shape_, "size_param")); + + auto* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(output_shape, a_param, {2, 1, 0})); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + 
DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{2, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{3, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_2); + EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_1); +} + +TEST_F(DynamicDimensionInferenceTest, ReshapeTest) { + // Test the ability to trace unmodified reshape dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}); + auto output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto* reshape = builder.AddInstruction( + HloInstruction::CreateReshape(output_shape, a_param)); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 2})); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 3})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 2), nullptr); + 
EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 3), size_param); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 4), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 5), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ReshapeTestUnimplemented) { + // Test the ability to trace unmodified reshape dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6}); + auto output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + builder.AddInstruction(HloInstruction::CreateReshape(output_shape, a_param)); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 1})); + + Status status = RunInference(); + EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); +} + +TEST_F(DynamicDimensionInferenceTest, BroadcastTest) { + // Test the ability to trace broadcast dimension. 
+ auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2}); + auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 4}); + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(output_shape, a_param, {1})); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 1), size_param); + EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 2), nullptr); +} + +TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) { + // Test the ability to trace reduce window batch dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); + auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + + Window window; + // First dimension is unchanged. + WindowDimension* batch_dim = window.add_dimensions(); + batch_dim->set_size(1); + batch_dim->set_stride(1); + batch_dim->set_padding_low(0); + batch_dim->set_padding_high(0); + batch_dim->set_window_dilation(1); + batch_dim->set_base_dilation(1); + + // Second and third dimension are reduced. 
+ for (int64 i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(2); + dim->set_stride(2); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + output_shape, a_param, init, window, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) { + // Test the ability to trace select and scatter batch dimensions. + auto builder = HloComputation::Builder(TestName()); + auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4}); + auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + + Window window; + // First dimension is unchanged. + WindowDimension* batch_dim = window.add_dimensions(); + batch_dim->set_size(1); + batch_dim->set_stride(1); + batch_dim->set_padding_low(0); + batch_dim->set_padding_high(0); + batch_dim->set_window_dilation(1); + batch_dim->set_base_dilation(1); + + // Second and third dimension are reduced. 
+ for (int64 i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(2); + dim->set_stride(2); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + + auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, input_shape, "A")); + auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, scalar_shape_, "size_param")); + + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + + auto* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + output_shape, a_param, init, window, GetAdd())); + + module_->AddEntryComputation(builder.Build()); + + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc new file mode 100644 index 0000000000000000000000000000000000000000..c8bfc8905064bcd7b68fe259fbcc1546ff083dbd --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.cc @@ -0,0 +1,138 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +Status DynamicParameterBinding::Bind( + const DynamicParameter& dynamic_parameter, + const DynamicDimension& dynamic_dimension) { + auto result = bindings_.emplace(dynamic_dimension, dynamic_parameter); + TF_RET_CHECK(result.second); + return Status::OK(); +} + +absl::optional<DynamicParameterBinding::DynamicParameter> +DynamicParameterBinding::GetBinding(const DynamicDimension& dynamic_dimension) { + auto param_iter = bindings_.find(dynamic_dimension); + if (param_iter == bindings_.end()) { + return absl::nullopt; + } + return param_iter->second; +} + +DynamicParameterBindingProto DynamicParameterBinding::ToProto() const { + DynamicParameterBindingProto result; + for (const auto& binding : bindings_) { + const DynamicDimension& dynamic_dimension = binding.first; + const DynamicParameter& dynamic_param = binding.second; + DynamicParameterBindingProto::Binding binding_proto; + binding_proto.set_dynamic_param_num(dynamic_param.parameter_num); + for (int64 i : dynamic_param.parameter_index) { + binding_proto.add_dynamic_param_index(i); + } + + binding_proto.set_target_param_num(dynamic_dimension.parameter_num); + + for (int64 i : dynamic_dimension.parameter_index) { + binding_proto.add_target_param_index(i); + } + + binding_proto.set_target_param_dim_num(dynamic_dimension.dimension); + result.add_entries()->Swap(&binding_proto); + } + return result; +} + +StatusOr<DynamicParameterBinding> DynamicParameterBinding::CreateFromProto( + const DynamicParameterBindingProto& proto) { + DynamicParameterBinding result; + for (const DynamicParameterBindingProto::Binding& binding :
proto.entries()) { + int64 dynamic_param_num = binding.dynamic_param_num(); + ShapeIndex dynamic_param_index(binding.dynamic_param_index().begin(), + binding.dynamic_param_index().end()); + int64 target_param_num = binding.target_param_num(); + ShapeIndex target_param_index(binding.target_param_index().begin(), + binding.target_param_index().end()); + int64 target_dim_num = binding.target_param_dim_num(); + + TF_RETURN_IF_ERROR( + result.Bind(DynamicParameter{dynamic_param_num, dynamic_param_index}, + DynamicDimension{target_param_num, target_param_index, + target_dim_num})); + } + + return result; +} + +string DynamicParameterBinding::ToString() const { + std::vector<string> pieces; + pieces.push_back("DynamicParameterBinding: "); + for (const auto& binding : bindings_) { + const DynamicDimension& dynamic_dimension = binding.first; + const DynamicParameter& dynamic_param = binding.second; + pieces.push_back(absl::StrFormat( + " -- Input param number %lld at %s has dim %lld as dynamic" + " dimension, which is represented by param number %lld at " + "%s", + dynamic_dimension.parameter_num, + dynamic_dimension.parameter_index.ToString(), + dynamic_dimension.dimension, dynamic_param.parameter_num, + dynamic_param.parameter_index.ToString())); + } + return absl::StrJoin(pieces, "\n"); +} + +Status DynamicParameterBinding::ForEachBinding(BindingFn fn) const { + for (const auto& binding : bindings_) { + TF_RETURN_IF_ERROR(fn(binding.second, binding.first)); + } + return Status::OK(); +} + +Status DynamicParameterBinding::Verify(const HloModule& module) const { + const HloComputation* entry = module.entry_computation(); + return ForEachBinding([&](const DynamicParameter& dynamic_parameter, + const DynamicDimension& dynamic_dimension) + -> Status { + TF_RET_CHECK(dynamic_parameter.parameter_num < entry->num_parameters()); + TF_RET_CHECK(dynamic_dimension.parameter_num < entry->num_parameters()); + TF_RET_CHECK(ShapeUtil::IndexIsValid(
entry->parameter_instruction(dynamic_parameter.parameter_num)->shape(), + dynamic_parameter.parameter_index)); + TF_RET_CHECK(ShapeUtil::IndexIsValid( + entry->parameter_instruction(dynamic_dimension.parameter_num)->shape(), + dynamic_dimension.parameter_index)); + TF_RET_CHECK( + dynamic_dimension.dimension < + ShapeUtil::Rank(ShapeUtil::GetSubshape( + entry->parameter_instruction(dynamic_dimension.parameter_num) + ->shape(), + dynamic_dimension.parameter_index))); + return Status::OK(); + }); +} + +std::ostream& operator<<(std::ostream& out, + const DynamicParameterBinding& binding) { + out << binding.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding.h b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h new file mode 100644 index 0000000000000000000000000000000000000000..dd474d8eed1b2c30ddb8f624a864198c74eacaba --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding.h @@ -0,0 +1,125 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ + +#include <functional> + +#include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +class HloModule; +// We currently use an explicit API that takes an extra parameter to indicate +// the runtime size of a dynamic dimension. DynamicParameterBinding indicates +// the relationship between parameters: We can have a dynamic parameter that +// points to another target parameter to indicate that the target parameter is +// dynamic. +// +// +// TODO(b/119520625): Remove this API once we have more dynamic shape infra +// ready. +class DynamicParameterBinding { + public: + // DynamicParameter represents a special parameter that is used to represent + // the runtime size of a dimension of another parameter. A dynamic parameter + // has to be a scalar value. + struct DynamicParameter { + // The parameter number of dynamic parameter. + int64 parameter_num; + // The index of the parameter. + ShapeIndex parameter_index; + }; + + // DynamicDimension represents a dimension whose size is determined at + // runtime. A DynamicDimension's runtime size is determined by the bound + // DynamicParameter using `DynamicParameterBinding::Bind` method. + struct DynamicDimension { + // The parameter number of dynamic dimension. + int64 parameter_num; + // The subshape index of the parameter. + ShapeIndex parameter_index; + // The dimension number in the subshape. + int64 dimension; + + // The "friend" keyword is added so these functions can be found by ADL.
+ template <typename H> + friend H AbslHashValue(H h, const DynamicDimension& m) { + return H::combine(std::move(h), m.parameter_num, m.parameter_index, + m.dimension); + } + + friend bool operator==(const DynamicDimension& lhs, + const DynamicDimension& rhs) { + return lhs.parameter_num == rhs.parameter_num && + lhs.parameter_index == rhs.parameter_index && + lhs.dimension == rhs.dimension; + } + }; + + DynamicParameterBinding() = default; + + virtual ~DynamicParameterBinding() = default; + + // Adds binding which indicates that the dimension indicated by + // `dynamic_dimension` is dynamic, and its runtime size is represented by + // `dynamic_parameter`. + Status Bind(const DynamicParameter& dynamic_parameter, + const DynamicDimension& dynamic_dimension); + + // Returns the parameter and the index representing the runtime size of + // the dimension described by `dynamic_dimension`. + // + // Returns nullopt if the binding is not set. + absl::optional<DynamicParameter> GetBinding( + const DynamicDimension& dynamic_dimension); + + using BindingFn = + std::function<Status(const DynamicParameter& dynamic_parameter, + const DynamicDimension& dynamic_dimension)>; + + // Iterate through each binding. + Status ForEachBinding(BindingFn fn) const; + + DynamicParameterBindingProto ToProto() const; + + static StatusOr<DynamicParameterBinding> CreateFromProto( + const DynamicParameterBindingProto& proto); + + string ToString() const; + + // Verifies that the given binding is valid for the given module. + // Specifically, the binding's parameter and parameter size should be valid. + Status Verify(const HloModule& module) const; + + private: + // Keeps track of mappings from DynamicDimension to DynamicParameter. The + // direction is chosen so that we can easily query if a dimension is + // dynamic and which dynamic parameter represents the real size of that + // dimension.
+ absl::flat_hash_map<DynamicDimension, DynamicParameter> bindings_; +}; + +std::ostream& operator<<(std::ostream& out, + const DynamicParameterBinding& binding); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DYNAMIC_PARAMETER_BINDING_H_ diff --git a/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc b/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..83a6d83dffde7995bd8e43917d13c5fd2705ba6f --- /dev/null +++ b/tensorflow/compiler/xla/service/dynamic_parameter_binding_test.cc @@ -0,0 +1,153 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" + +#include <memory> +#include <string> + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { +class DynamicParameterBindingTest : public HloTestBase {}; + +TEST_F(DynamicParameterBindingTest, SimpleBinding) { + // 'b' is a dynamic shape; 'a' represents the real size of b's first + // dimension. + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[10] parameter(1) + ROOT root = (f32[], f32[10]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + + DynamicParameterBinding binding; + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {}}, + DynamicParameterBinding::DynamicDimension{1, {}, 0})); + + absl::optional<DynamicParameterBinding::DynamicParameter> param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1, + /*parameter_index=*/{}, + /*dimension=*/0}); + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({})); + TF_EXPECT_OK(binding.Verify(*module)); +} + +TEST_F(DynamicParameterBindingTest, TupleBinding) { + // 'gte2' is a dynamic shape; 'gte1' represents the real size of gte2's first + // dimension.
+ const string module_str = R"( +HloModule TEST + +ENTRY main { + param = (f32[], f32[10]) parameter(0) + gte1 = f32[] get-tuple-element(%param), index=0 + gte2 = f32[10] get-tuple-element(%param), index=1 + ROOT root = (f32[], f32[10]) tuple(%gte1, %gte2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + + DynamicParameterBinding binding; + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 0})); + + absl::optional<DynamicParameterBinding::DynamicParameter> param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + TF_EXPECT_OK(binding.Verify(*module)); +} + +TEST_F(DynamicParameterBindingTest, TupleBindingWithMultiDimension) { + // 'gte2' is a dynamic shape; 'gte1' represents the real size of both of + // gte2's dimensions.
+ const string module_str = R"( +HloModule TEST + +ENTRY main { + param = (f32[], f32[10, 10]) parameter(0) + gte1 = f32[] get-tuple-element(%param), index=0 + gte2 = f32[10, 10] get-tuple-element(%param), index=1 + ROOT root = (f32[], f32[10, 10]) tuple(%gte1, %gte2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + + DynamicParameterBinding binding; + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 0})); + + TF_EXPECT_OK( + binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}}, + DynamicParameterBinding::DynamicDimension{0, {1}, 1})); + + absl::optional<DynamicParameterBinding::DynamicParameter> param = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/0}); + + EXPECT_TRUE(param); + EXPECT_EQ(param->parameter_num, 0); + EXPECT_EQ(param->parameter_index, ShapeIndex({0})); + + absl::optional<DynamicParameterBinding::DynamicParameter> param2 = + binding.GetBinding( + DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0, + /*parameter_index=*/{1}, + /*dimension=*/1}); + EXPECT_TRUE(param2); + EXPECT_EQ(param2->parameter_num, 0); + EXPECT_EQ(param2->parameter_index, ShapeIndex({0})); + + TF_EXPECT_OK(binding.Verify(*module)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 515267edd7caf42e04ebe638b99006db8967ea30..6f1f95f2e9082649b6ca9cc0da5c238e15b77c10 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -22,6 +22,7 @@ limitations under the License.
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" @@ -1671,26 +1672,66 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( b_->SetInsertPoint(init_block); + // Assign a unique id for each *different* operand, and count how often each + // operand is used. If all operands are different, the usage count will be 1 + // for each operand. + absl::flat_hash_map to_unique_operand_id; + std::vector operand_usage_count; + for (const auto* operand : hlo->operands()) { + if (to_unique_operand_id.contains(operand)) { + ++operand_usage_count[to_unique_operand_id[operand]]; + } else { + int64 unique_operand_id = to_unique_operand_id.size(); + to_unique_operand_id[operand] = unique_operand_id; + operand_usage_count.push_back(1); + } + } + + // To avoid that we emit the same operand more than once, we create one basic + // block for each *different* operand with a PHI node for the different source + // index inputs. 
+ std::vector emit_operand_blocks( + to_unique_operand_id.size(), nullptr); + std::vector source_index_phis(to_unique_operand_id.size(), + nullptr); + for (const auto* operand : hlo->operands()) { + int64 operand_id = to_unique_operand_id[operand]; + if (emit_operand_blocks[operand_id] != nullptr) { + continue; + } + + emit_operand_blocks[operand_id] = llvm_ir::CreateBasicBlock( + exit_block, StrCat("concat_index_from_operand_id", operand_id), b_); + auto saved_insert_point = b_->GetInsertPoint(); + llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_); + source_index_phis[operand_id] = + PHI(source_index.GetType(), operand_usage_count[operand_id]); + auto operand_index = source_index; + operand_index[concat_dim] = source_index_phis[operand_id]; + + // Create the terminator of the block before calling operand generators, + // because they require non-degenerate basic blocks. + b_->SetInsertPoint(llvm::BranchInst::Create( + exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id])); + TF_ASSIGN_OR_RETURN(llvm::Value * value, + operand_to_generator.at(operand)(operand_index)); + output->addIncoming(value, b_->GetInsertBlock()); + b_->SetInsertPoint(init_block, saved_insert_point); + } + for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); ++operand_idx) { const HloInstruction* operand = hlo->operand(operand_idx); - auto true_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_from_operand", operand_idx), b_); auto false_block = llvm_ir::CreateBasicBlock( exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_); auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), operand->shape().dimensions(concat_dim)); - CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), true_block, - false_block); - - // Create the terminator of the true block before calling operand - // generators, because they require non-degenerate basic blocks. 
- b_->SetInsertPoint( - llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); - TF_ASSIGN_OR_RETURN(llvm::Value * value, - operand_to_generator.at(operand)(source_index)); - output->addIncoming(value, b_->GetInsertBlock()); + int64 operand_id = to_unique_operand_id[operand]; + source_index_phis[operand_id]->addIncoming(source_index[concat_dim], + b_->GetInsertBlock()); + CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), + emit_operand_blocks[operand_id], false_block); // Subtract the size of the concat dimension of the current operand // from the source index. @@ -1815,8 +1856,6 @@ StatusOr ElementalIrEmitter::EmitElementalGather( // Clamp the gather index so that the gather region fits in the operand. // gather_dim_component_extended_inbound = // clamp(gather_dim_component_extended, 0, largest_valid_start_index); - - // TODO(b/111078873): This is implementation defined behavior. bool is_signed = ShapeUtil::ElementIsSigned(indices_shape); auto gather_dim_component_extended_inbound = EmitIntegralMin( index.GetConstantWithIndexType(largest_valid_start_index), @@ -2206,13 +2245,15 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( : iota->shape(); PrimitiveType component_element_type = component_shape.element_type(); llvm::Value* iota_result; - if (ShapeUtil::ElementIsIntegral(component_shape)) { + if (primitive_util::IsIntegralType(component_element_type) || + component_element_type == PRED) { iota_result = b_->CreateIntCast( elem_index_linear, llvm_ir::PrimitiveTypeToIrType(component_element_type, module_), /*isSigned=*/false); } else { - TF_RET_CHECK(ShapeUtil::ElementIsFloating(component_shape)) + TF_RET_CHECK( + primitive_util::IsFloatingPointType(component_element_type)) << component_element_type; llvm::Type* float_ir_type; if (component_element_type == BF16) { diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 
47c56e2f7fbd9f53be6a2b189c5c36cf4fdcdccb..10b8c01ff1383658fcfb2271c177ba54347f985a 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -17,7 +17,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_format.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 3a6780f2a67f230cae626ea00cfbf93b4e60d968..b34bca55a48b113c325dbf28c03f7a0f5b71f658 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -22,7 +22,7 @@ limitations under the License. #include "absl/types/span.h" #include "absl/types/variant.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -61,7 +61,7 @@ struct ExecutionOutput { class Executable { public: explicit Executable( - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) : hlo_module_(std::move(hlo_module)), @@ -162,7 +162,7 @@ class Executable { return hlo_profile_printer_data_ != nullptr; } - const HloModule& module() const { return *hlo_module_; } + HloModule& module() const { return *hlo_module_; } const bool has_module() const { return hlo_module_ != nullptr; } @@ -199,7 +199,7 @@ class Executable { // HloModule this was compiled from. 
BufferAssignment keeps pointers to // HloInstructions owned by the HloModule so we need to keep the HloModule // around. - const std::unique_ptr hlo_module_; + const std::unique_ptr hlo_module_; // HloSnapshot this was compiled from. Null if not dumping executions. std::unique_ptr hlo_snapshot_; diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index 5fbd73a5363b4cdbcaafedbe6f4e7bd6bb2a92d8..8eeb930b48165a2e3c622581e05cb5f7063fa1fa 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -30,7 +30,7 @@ limitations under the License. namespace xla { namespace { -class FlattenCallGraphTest : public HloVerifiedTestBase { +class FlattenCallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation() { @@ -108,7 +108,7 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { // c // // Calls are made via kCall, kWhile, and kMap instructions. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation = module->AddEmbeddedComputation(MakeConditionComputation()); HloComputation* c_computation = @@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr flat_call_graph = CallGraph::Build(module); + std::unique_ptr flat_call_graph = CallGraph::Build(module.get()); const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); EXPECT_EQ(1, c_node.caller_callsites().size()); } @@ -149,7 +149,7 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { // Test corner case of a computation used as a body and a loop condition. TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* cond_computation; { HloComputation::Builder builder(TestName() + ".cond"); @@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { } { - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(2, cond_node.caller_callsites().size()); } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(1, cond_node.caller_callsites().size()); } @@ -201,7 +201,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { // C // TEST_F(FlattenCallGraphTest, FlattenCalls) { - auto module = CreateNewModule(); + auto module = 
CreateNewVerifiedModule(); HloComputation* c_computation = module->AddEmbeddedComputation(MakeScalarComputation()); @@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { module->AddEntryComputation( MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry")); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(7, module->computation_count()); const CallGraphNode& c_node = call_graph->GetNode(c_computation); @@ -224,7 +224,7 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { } TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* sub_computation = module->AddEmbeddedComputation(MakeScalarComputation()); @@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, module->computation_count()); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module.get()); // The true and false computations must now be different. 
EXPECT_EQ(3, module->computation_count()); diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index cb86c9857936f21d9d2ac6bc22c725b89cca6482..01cef499665c050d4453382289168276028e1d26 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -346,7 +346,8 @@ StatusOr GatherExpander::ExpandGather( [&](HloInstruction* indvar, const std::vector& loop_state) { return GatherLoopBody(*gather_instr, indvar, loop_state); - }); + }, + gather_instr->metadata()); TF_ASSIGN_OR_RETURN(std::vector gather_loop_result, gather_loop_result_or_error); diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index 2b39359aae9fc01f1a88a2594108b2772788e826..8af9c6b71fbc391bf7c0e9809e979b65135a6df3 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -28,7 +28,7 @@ class GatherExpander : public HloModulePass { absl::string_view name() const override { return "gather_expander"; } StatusOr Run(HloModule* module) override; - private: + protected: StatusOr ExpandGather(HloInstruction* gather_instr); }; diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index 141dd4d6f10272ce749edc4e91153c365ed322e6..a3102368cb1dba15da7422337666d278cef775ab 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -104,5 +104,44 @@ ENTRY main { ShapeUtil::MakeShape(S32, {2, 3}), ShapeUtil::GetTupleElementShape(while_shape, 3))); } + +TEST(GatherExpanderTest, CheckOpMetadata) { + const string hlo_text = R"( +HloModule TensorFlowGatherV2 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[3,2] gather(operand, indices), + offset_dims={0}, + 
collapsed_slice_dims={1}, + start_index_map={1}, + index_vector_dim=1, + slice_sizes={3, 1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + OpMetadata metadata; + metadata.set_op_name("Gather"); + module->entry_computation()->root_instruction()->set_metadata(metadata); + TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get())); + ASSERT_TRUE(changed); + + HloInstruction* while_instr = nullptr; + for (auto* instr : module->entry_computation()->instructions()) { + if (instr->opcode() == HloOpcode::kWhile) { + ASSERT_EQ(while_instr, nullptr) + << "Expected exactly one while instruction in the entry computation " + "after gather expansion"; + while_instr = instr; + } + } + + ASSERT_NE(while_instr, nullptr) + << "Expected exactly one while instruction in the entry computation " + "after gather expansion"; + EXPECT_EQ(while_instr->metadata().op_name(), "Gather"); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 522e9f5948da2206f144ede4fdd95350474146d9..bfd1b6cb1492f5cb709e2ecefe73782094e26f5e 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -25,6 +25,10 @@ filegroup( ) load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) xla_proto_library( name = "backend_configs", @@ -107,7 +111,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -154,7 +157,7 @@ cc_library( deps = [ ":backend_configs", ":buffer_allocations", - ":cudnn_convolution_runner", + ":cudnn_conv_runner", ":elemental_ir_emitter", 
":gpu_constants", ":gpu_executable", @@ -323,7 +326,7 @@ cc_library( ], deps = [ ":buffer_allocations", - ":cudnn_convolution_runner", + ":cudnn_conv_runner", ":hlo_execution_profiler", ":infeed_manager", ":ir_emission_utils", @@ -385,13 +388,13 @@ cc_library( ) cc_library( - name = "cudnn_convolution_algorithm_picker", - srcs = ["cudnn_convolution_algorithm_picker.cc"], - hdrs = ["cudnn_convolution_algorithm_picker.h"], + name = "cudnn_conv_algorithm_picker", + srcs = ["cudnn_conv_algorithm_picker.cc"], + hdrs = ["cudnn_conv_algorithm_picker.h"], deps = [ ":backend_configs", ":buffer_comparator", - ":cudnn_convolution_runner", + ":cudnn_conv_runner", ":gpu_executable", ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", @@ -404,14 +407,15 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", ], ) cc_library( - name = "cudnn_convolution_runner", - srcs = ["cudnn_convolution_runner.cc"], - hdrs = ["cudnn_convolution_runner.h"], + name = "cudnn_conv_runner", + srcs = ["cudnn_conv_runner.cc"], + hdrs = ["cudnn_conv_runner.h"], deps = [ ":backend_configs", ":ir_emission_utils", @@ -431,9 +435,9 @@ cc_library( ) cc_library( - name = "cudnn_convolution_rewriter", - srcs = ["cudnn_convolution_rewriter.cc"], - hdrs = ["cudnn_convolution_rewriter.h"], + name = "cudnn_conv_rewriter", + srcs = ["cudnn_conv_rewriter.cc"], + hdrs = ["cudnn_conv_rewriter.h"], deps = [ ":backend_configs", ":ir_emission_utils", @@ -448,17 +452,17 @@ cc_library( ) tf_cc_test( - name = "cudnn_convolution_rewriter_test", - srcs = ["cudnn_convolution_rewriter_test.cc"], + name = "cudnn_conv_rewriter_test", + srcs = ["cudnn_conv_rewriter_test.cc"], deps = [ - ":cudnn_convolution_rewriter", + ":cudnn_conv_rewriter", ":ir_emission_utils", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", 
"//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", ], @@ -580,9 +584,9 @@ tf_cc_test( ) cc_library( - name = "pad_insertion", - srcs = ["pad_insertion.cc"], - hdrs = ["pad_insertion.h"], + name = "cudnn_conv_padding_legalization", + srcs = ["cudnn_conv_padding_legalization.cc"], + hdrs = ["cudnn_conv_padding_legalization.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:literal", @@ -590,6 +594,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:shape_inference", @@ -598,9 +603,9 @@ cc_library( ) cc_library( - name = "pad_for_tensor_cores", - srcs = ["pad_for_tensor_cores.cc"], - hdrs = ["pad_for_tensor_cores.h"], + name = "cudnn_conv_pad_for_tensor_cores", + srcs = ["cudnn_conv_pad_for_tensor_cores.cc"], + hdrs = ["cudnn_conv_pad_for_tensor_cores.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", @@ -612,16 +617,16 @@ cc_library( ) tf_cc_test( - name = "pad_for_tensor_cores_test", - srcs = ["pad_for_tensor_cores_test.cc"], + name = "cudnn_conv_pad_for_tensor_cores_test", + srcs = ["cudnn_conv_pad_for_tensor_cores_test.cc"], deps = [ + ":cudnn_conv_pad_for_tensor_cores", ":ir_emission_utils", - ":pad_for_tensor_cores", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + 
"//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep ], ) @@ -658,9 +663,11 @@ cc_library( srcs = ["nvptx_compiler.cc"], hdrs = ["nvptx_compiler.h"], deps = [ - ":cudnn_convolution_algorithm_picker", - ":cudnn_convolution_rewriter", - ":cudnn_fused_convolution_rewriter", + ":cudnn_conv_algorithm_picker", + ":cudnn_conv_pad_for_tensor_cores", + ":cudnn_conv_padding_legalization", + ":cudnn_conv_rewriter", + ":cudnn_fused_conv_rewriter", ":fusion_merger", ":gpu_constants", ":gpu_copy_insertion", @@ -672,11 +679,10 @@ cc_library( ":ir_emission_utils", ":ir_emitter", ":multi_output_fusion", - ":pad_for_tensor_cores", - ":pad_insertion", ":partition_assignment", ":stream_assignment", ":stream_executor_util", + ":variadic_op_splitter", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -695,6 +701,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", "//tensorflow/compiler/xla/service:hlo_element_type_converter", + "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", @@ -704,7 +711,6 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", - "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", @@ -780,7 +786,6 @@ cc_library( srcs = ["gpu_layout_assignment.cc"], hdrs = ["gpu_layout_assignment.h"], deps = [ - ":gpu_options", ":ir_emission_utils", ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", @@ 
-844,7 +849,6 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", @@ -881,16 +885,6 @@ cc_library( ], ) -cc_library( - name = "gpu_options", - srcs = ["gpu_options.cc"], - hdrs = ["gpu_options.h"], - deps = [ - "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/core:lib_internal", - ], -) - cc_library( name = "stream_executor_util", srcs = ["stream_executor_util.cc"], @@ -914,7 +908,6 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -976,9 +969,9 @@ tf_cc_test( ) cc_library( - name = "cudnn_fused_convolution_rewriter", - srcs = ["cudnn_fused_convolution_rewriter.cc"], - hdrs = ["cudnn_fused_convolution_rewriter.h"], + name = "cudnn_fused_conv_rewriter", + srcs = ["cudnn_fused_conv_rewriter.cc"], + hdrs = ["cudnn_fused_conv_rewriter.h"], deps = [ ":backend_configs", ":ir_emission_utils", @@ -990,3 +983,57 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", ], ) + +tf_cc_test( + name = "cudnn_fused_conv_rewriter_test", + srcs = ["cudnn_fused_conv_rewriter_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", + "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "variadic_op_splitter", + srcs = 
["variadic_op_splitter.cc"], + hdrs = ["variadic_op_splitter.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "variadic_op_splitter_test", + srcs = ["variadic_op_splitter_test.cc"], + deps = [ + ":ir_emission_utils", + ":variadic_op_splitter", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/compiler/xla/tests:hlo_test_base", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 4effea637d01bf23b54d341b77306b20b1b133c8..e1dffad3045808c4f316ccafdda39a174e1560c8 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" @@ -56,9 +56,9 @@ Status ConvolutionThunk::ExecuteOnStream( buffer_allocations.GetDeviceAddress(scratch_buffer_); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); - TF_RETURN_IF_ERROR(RunCudnnConvolution(cudnn_call_, - absl::MakeSpan(operand_se_buffers), - result_buffer, scratch, stream)); + TF_RETURN_IF_ERROR(RunCudnnConv(cudnn_call_, + absl::MakeSpan(operand_se_buffers), + result_buffer, scratch, stream)); void* ptrs[] = {result_buffer.opaque(), scratch.opaque()}; se::DeviceMemory tuple_addr( diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index f53bc541983378819dba36489dd69c348f50af32..c71515490c94ef54baad9005509d1813de630159 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -19,7 +19,7 @@ limitations under the License. 
#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc new file mode 100644 index 0000000000000000000000000000000000000000..6d6780fa1c7b0c636eb771c40e74f074cd8c4c4b --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -0,0 +1,407 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" +#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/mutex.h" + +namespace xla { +namespace gpu { +namespace { + +using absl::optional; +using se::DeviceMemoryBase; +using se::dnn::AlgorithmConfig; +using se::dnn::AlgorithmDesc; + +class ScratchAllocator : public se::ScratchAllocator { + public: + ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator) + : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} + + int64 GetMemoryLimitInBytes(se::Stream* stream) override { + return 1LL << 32; // 4GB. TODO(jlebar): Tune this? 
+ } + int64 TotalAllocatedBytes() { return total_allocated_bytes_; } + + StatusOr> AllocateBytes(se::Stream* stream, + int64 byte_size) override; + + private: + const int device_ordinal_; + DeviceMemoryAllocator* memory_allocator_; + std::vector allocated_buffers_; + int64 total_allocated_bytes_ = 0; +}; + +StatusOr> ScratchAllocator::AllocateBytes( + se::Stream* stream, int64 byte_size) { + CHECK_GE(byte_size, 0) << "byte_size must be positive."; + if (byte_size > GetMemoryLimitInBytes(stream)) { + return se::port::Status( + se::port::error::RESOURCE_EXHAUSTED, + absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", + byte_size, GetMemoryLimitInBytes(stream))); + } + + TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false)); + total_allocated_bytes_ += byte_size; + + se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); + allocated_buffers_.push_back(std::move(allocated_buffer)); + return se::DeviceMemory(buffer_addr); +} + +std::vector GetAlgorithms(CudnnConvKind kind, + se::StreamExecutor* stream_exec) { + std::vector algorithms; + bool succ = false; + switch (kind) { + case CudnnConvKind::kBackwardFilter: + succ = + stream_exec->GetConvolveBackwardFilterAlgorithms(true, &algorithms); + break; + case CudnnConvKind::kBackwardInput: + succ = stream_exec->GetConvolveBackwardDataAlgorithms(true, &algorithms); + break; + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + succ = stream_exec->GetConvolveAlgorithms(true, &algorithms); + break; + } + DCHECK(succ); + + return algorithms; +} + +string AlgorithmToString(const AlgorithmDesc& algo) { + if (algo.tensor_ops_enabled()) { + return absl::StrCat(algo.algo_id(), "+TC"); + } + return absl::StrCat(algo.algo_id()); +} + +string NumBytesToString(int64 bytes) { + return absl::StrCat(tensorflow::strings::HumanReadableNumBytes(bytes), " (", + bytes, "B)"); +} + +// 
Acquires a process-global lock on the device pointed to by the given +// StreamExecutor. +// +// This is used to prevent other XLA instances from trying to autotune on this +// device while we're using it. +tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + // se::Platform*s are global singletons guaranteed to live forever. + static auto* mutexes = + new std::map, + tensorflow::mutex>(); + + tensorflow::mutex_lock global_lock(mu); + auto it = mutexes + ->emplace(std::piecewise_construct, + std::make_tuple(stream_exec->platform(), + stream_exec->device_ordinal()), + std::make_tuple()) + .first; + return tensorflow::mutex_lock{it->second}; +} + +} // anonymous namespace + +// We could have caching here so that we don't redo this work for two identical +// convolutions. Unfortunately our cache key would have to be a tuple +// containing the protos passed to this function, and we have no utility for +// hashing protos. We could write our own hash functions, but they'd silently +// break if we ever added a field to one of the protos. Perhaps we could hack +// using the binary-encoded proto as the hash key, on the assumption that two +// protos being binary-equal is a sufficient, if not necessary, condition for +// proper equality. But that would still leave us open to having unnecessary +// cache misses and doing extra work. Overall, caching doesn't seem worth the +// trouble, but we may want to revisit this if we ever find a model where +// caching would speed up compilation a lot. +StatusOr +CudnnConvAlgorithmPicker::PickBestAlgorithm(HloCustomCallInstruction* instr) { + // TODO(timshen): for now only check fp16. It can be expanded to other types, + // with some work on the HLO routines. + const bool cross_check_enabled = + instr->shape().tuple_shapes(0).element_type() == xla::F16; + + // Don't run this function concurrently on the same GPU. 
+ // + // This is a bit of a hack and doesn't protect us against arbitrary concurrent + // use of a GPU, but it's sufficient to let us compile two HLO modules + // concurrently and then run them sequentially. + tensorflow::mutex_lock lock = LockGpu(stream_exec_); + + // Make sure any previous activity on this executor is done. We don't want to + // interfere with programs that are still running on the GPU. + if (!stream_exec_->SynchronizeAllActivity()) { + return InternalError("Failed to synchronize GPU for autotuning."); + } + + // Create a stream for us to do our work on. + se::Stream stream{stream_exec_}; + stream.Init(); + const auto device_ordinal = stream_exec_->device_ordinal(); + + // allocator either points to this->allocator_ or, if that's null, to a + // StreamExecutorMemoryAllocator for stream_exec_. + DeviceMemoryAllocator* allocator; + optional se_allocator; + if (allocator_ != nullptr) { + allocator = allocator_; + } else { + se_allocator.emplace(stream_exec_->platform(), + absl::Span({stream_exec_})); + allocator = &*se_allocator; + } + + const auto initialize_buffer = [&stream, cross_check_enabled]( + DeviceMemoryBase buffer) { + if (cross_check_enabled) { + // Broadcast a constant to the buffer, instead of zeroing the buffer. A + // non-zero constant is useful for the cross checking, because zero-inputs + // may not always reveal the bugs. 
+ CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4); + size_t left_over_bytes = buffer.size() % 4; + CHECK_EQ(0, left_over_bytes % 2); + + constexpr float kBroadcastedConstant = 0.1f; + static const Eigen::half halfs[2] = {Eigen::half(kBroadcastedConstant), + Eigen::half(kBroadcastedConstant)}; + uint32 bits; + static_assert(sizeof(bits) == sizeof(halfs), ""); + memcpy(&bits, halfs, sizeof(bits)); + + size_t aligned_size = buffer.size() / 4 * 4; + stream.ThenMemset32(&buffer, bits, aligned_size); + + DeviceMemoryBase left_over( + static_cast(buffer.opaque()) + aligned_size, left_over_bytes); + stream.ThenMemcpy(&left_over, halfs, left_over_bytes); + } else { + // Although we don't have evidence this matters, zero out the buffers + // before autotuning. It's conceivable that using uninitialized memory as + // the inputs might affect performance if e.g. the inputs contain + // denormals, and this is easy enough. + stream.ThenMemZero(&buffer, buffer.size()); + } + }; + + // Allocate space for the input, filter, and output of the convolution. We + // use a ScratchAllocator for this instead of calling allocator_ directly so + // that our allocations don't leak. + ScratchAllocator input_output_allocator(device_ordinal, allocator); + std::vector operand_buffers; + for (const auto* operand : instr->operands()) { + TF_ASSIGN_OR_RETURN(auto buffer, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(operand->shape()))); + initialize_buffer(buffer); + operand_buffers.push_back(buffer); + } + TF_ASSIGN_OR_RETURN( + auto result_buffer, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0)))); + initialize_buffer(result_buffer); + + se::dnn::ProfileResult best_result; + int64 best_result_bytes_used = 0; + TF_ASSIGN_OR_RETURN(auto backend_config, + instr->backend_config()); + + optional comparator; + // Use the first algorithm that's supported as reference. 
There isn't a + // particular reason to use it, as any algorithm sufficies. It doesn't make + // this algorithm considered correct, though. + optional first_algorithm; + TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr)); + for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) { + ScratchAllocator scratch_allocator(device_ordinal, allocator); + se::dnn::ProfileResult profile_result; + VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " + << instr->ToString(); + + backend_config.set_algorithm(alg.algo_id()); + backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled()); + TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config)); + bool launch_ok = + RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer, + &scratch_allocator, &stream, &profile_result) + .ok(); + + if (launch_ok && profile_result.is_valid()) { + const bool crash_on_checking_failure = + instr->GetModule() + ->config() + .debug_options() + .xla_gpu_crash_on_verification_failures(); + if (comparator.has_value()) { + StatusOr result = comparator->CompareEqual( + se::DeviceMemory(result_buffer)); + if (!result.ok()) { + LOG(ERROR) << "Unable to compare " + << AlgorithmToString(*first_algorithm) << " against " + << AlgorithmToString(alg) << " for " << instr->ToString() + << ": " << result.status(); + CHECK(!crash_on_checking_failure); + } else if (!result.ValueOrDie()) { + LOG(ERROR) << "Results mismatch between different convolution " + "algorithms. This is likely a bug in convolution, or " + "an excessive loss of precision in convolution. 
" + << instr->ToString() << " for " + << AlgorithmToString(*first_algorithm) << " vs " + << AlgorithmToString(alg); + CHECK(!crash_on_checking_failure); + } + } else if (cross_check_enabled) { + auto comp = F16BufferComparator::Create( + se::DeviceMemory(result_buffer), compiler_, allocator, + &stream); + if (comp.ok()) { + comparator.emplace(comp.ConsumeValueOrDie()); + first_algorithm.emplace(alg); + } else { + LOG(ERROR) << "Fail to initialize buffer comparator: " + << comp.status() << ", instruction: " << instr->ToString(); + CHECK(!crash_on_checking_failure); + } + } + int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); + VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) + << " succeeded, taking " << profile_result.elapsed_time_in_ms() + << "ms and using " << NumBytesToString(scratch_bytes_used) + << " of scratch (Best result: " + << best_result.elapsed_time_in_ms() << "ms, " + << NumBytesToString(best_result_bytes_used) << " of scratch)"; + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + best_result_bytes_used = scratch_bytes_used; + } + } else { + VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) << " failed."; + } + } + if (best_result.is_valid()) { + VLOG(2) << "Best algorithm for " << instr->ToString() << ": " + << AlgorithmToString(best_result.algorithm()) << ", takes " + << best_result.elapsed_time_in_ms() << "ms, and uses " + << best_result_bytes_used << "B of scratch memory."; + return AutotuneResult{best_result.algorithm().algo_id(), + best_result.algorithm().tensor_ops_enabled(), + best_result_bytes_used, + absl::Milliseconds(best_result.elapsed_time_in_ms())}; + } + + return InternalError( + "All algorithms tried for convolution %s failed. 
Falling back to " + "default algorithm.", + instr->ToString()); +} + +StatusOr CudnnConvAlgorithmPicker::RunOnInstruction( + HloInstruction* instr) { + CHECK(IsCustomCallToDnnConvolution(*instr)); + + StatusOr best_algo_or = + PickBestAlgorithm(Cast(instr)); + if (!best_algo_or.ok()) { + LOG(ERROR) << best_algo_or.status(); + return false; + } + + auto best_algo = std::move(best_algo_or).ValueOrDie(); + VLOG(1) << "Setting cudnn conv to use algorithm " << best_algo.algorithm + << " and " << NumBytesToString(best_algo.scratch_bytes) + << " of scratch memory: " << instr->ToString() + << " tensor_ops_enabled: " << best_algo.tensor_ops_enabled; + + // Replace instr with a new CustomCall which has the correct algorithm, and + // whose output shape has the appropriate amount of scratch memory. + HloComputation* computation = instr->parent(); + Shape new_call_shape = ShapeUtil::MakeTupleShape( + {instr->shape().tuple_shapes(0), + ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes})}); + + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + instr->backend_config()); + backend_config.set_algorithm(best_algo.algorithm); + backend_config.set_tensor_ops_enabled(best_algo.tensor_ops_enabled); + + HloInstruction* new_call = computation->AddInstruction( + instr->CloneWithNewOperands(new_call_shape, instr->operands())); + + VLOG(1) << "Replacing convolution " << instr->ToString() << " with " + << new_call->ToString(); + + TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); + + // Repackage new_call so it has the same shape as the original call, namely + // (conv_result, u8[0]). 
+ HloInstruction* new_tuple = + computation->AddInstruction(HloInstruction::CreateTuple( + {computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_call_shape.tuple_shapes(0), new_call, 0)), + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({})))})); + + TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple)); + return true; +} + +StatusOr CudnnConvAlgorithmPicker::RunOnComputation( + HloComputation* computation) { + std::vector convs; + for (auto* instr : computation->instructions()) { + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(instr); + } + } + + bool changed = false; + for (auto* instr : convs) { + TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr)); + changed |= result; + } + return changed; +} + +StatusOr CudnnConvAlgorithmPicker::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h new file mode 100644 index 0000000000000000000000000000000000000000..642af787afc71586d722ecc7e529ed8b3fa64d33 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_ + +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for +// each and adding explicit scratch space to the CustomCalls. +class CudnnConvAlgorithmPicker : public HloModulePass { + public: + // If the `allocator` parameter is not null, we will use it to allocate temp + // memory while timing the various convolution algorithms. If it's null, + // we'll use the default allocator on the StreamExecutor. 
+ CudnnConvAlgorithmPicker(se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* allocator, Compiler* compiler) + : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {} + + absl::string_view name() const override { + return "cudnn-conv-algorithm-picker"; + } + + StatusOr Run(HloModule* module) override; + + private: + struct AutotuneResult { + int64 algorithm; + bool tensor_ops_enabled; + int64 scratch_bytes; + absl::Duration runtime; + }; + + StatusOr RunOnComputation(HloComputation* computation); + StatusOr RunOnInstruction(HloInstruction* instr); + StatusOr PickBestAlgorithm(HloCustomCallInstruction* instr); + + se::StreamExecutor* stream_exec_; // never null + DeviceMemoryAllocator* allocator_; // may be null + Compiler* compiler_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc new file mode 100644 index 0000000000000000000000000000000000000000..5aa4f839f4be5f1060480fea98775f8ffada0bdd --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.cc @@ -0,0 +1,243 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" + +namespace xla { +namespace gpu { + +// We won't pad a conv if doing so increases the total number of bytes in the +// lhs, rhs, or result by more than this amount. +// +// TODO(jlebar): This number was tuned experimentally. It represents a +// compromise on our current benchmarks; it speeds some up significantly, and +// doesn't slow any down. But we can observe by changing this value that +// there's additional room for speedups. Achieving those speedups without +// also slowing other things down will likely require a more sophisticated +// heuristic, possibly some form of auto-tuning. +static constexpr double kMaxBytesTouchedIncrease = 1.35; + +// Creates and returns an HLO that zero-pads one or more dimensions in the given +// instruction so that its shape is equal to the given shape. +// +// Padding is added to the end of each relevant dimension. +// +// If the instruction already has the given shape, simply returns it without an +// intervening pad. 
+static HloInstruction* PadInstruction(HloInstruction* instr, + const Shape& new_shape) { + HloComputation* comp = instr->parent(); + + const Shape& shape = instr->shape(); + auto* zero = comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); + + PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); + + bool added_padding = false; + for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) { + if (shape.dimensions(dim) == new_shape.dimensions(dim)) { + continue; + } + CHECK_GT(new_shape.dimensions(dim), shape.dimensions(dim)); + pad_config.mutable_dimensions(dim)->set_edge_padding_high( + new_shape.dimensions(dim) - shape.dimensions(dim)); + added_padding = true; + } + + if (!added_padding) { + return instr; + } + return comp->AddInstruction( + HloInstruction::CreatePad(new_shape, instr, zero, pad_config)); +} + +// Modifies the given convolution to have the given LHS/RHS/result shapes. +static Status PadConv(HloCustomCallInstruction* conv, + const Shape& new_lhs_shape, const Shape& new_rhs_shape, + const Shape& new_result_shape) { + CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0)) + << "conv must use 0 scratch bytes, i.e. 
this pass must be run " + "before CudnnConvAlgorithmPicker."; + + auto* lhs = conv->mutable_operand(0); + auto* rhs = conv->mutable_operand(1); + auto* new_lhs = PadInstruction(lhs, new_lhs_shape); + auto* new_rhs = PadInstruction(rhs, new_rhs_shape); + const Shape& result_shape = conv->shape().tuple_shapes(0); + CHECK(new_lhs != lhs || new_rhs != rhs) + << "We should have had to pad either LHS or RHS."; + + auto add = [&](std::unique_ptr new_instr) { + return conv->parent()->AddInstruction(std::move(new_instr)); + }; + + Shape new_conv_shape = ShapeUtil::MakeTupleShape( + {new_result_shape, ShapeUtil::MakeShape(U8, {0})}); + auto* new_conv = + add(conv->CloneWithNewOperands(new_conv_shape, {new_lhs, new_rhs})); + + // Slice the new conv result if necessary, keeping in mind that new_conv has + // tuple shape (new_result_shape, u8[0]). + if (!ShapeUtil::Equal(result_shape, new_result_shape)) { + std::vector start_indices(result_shape.dimensions_size(), 0); + std::vector end_indices(result_shape.dimensions().begin(), + result_shape.dimensions().end()); + std::vector strides(result_shape.dimensions_size(), 1); + + auto* new_conv_result = add( + HloInstruction::CreateGetTupleElement(new_result_shape, new_conv, 0)); + auto* empty_temp_buffer = + add(HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + auto* sliced_result = add(HloInstruction::CreateSlice( + result_shape, new_conv_result, start_indices, end_indices, strides)); + new_conv = + add(HloInstruction::CreateTuple({sliced_result, empty_temp_buffer})); + } + + VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with " + << new_conv->ToString(); + return conv->parent()->ReplaceInstruction(conv, new_conv); +} + +static StatusOr PadForTensorCores(HloCustomCallInstruction* conv) { + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv)); + const auto& dnums = conv->convolution_dimension_numbers(); + auto* lhs = conv->mutable_operand(0); + auto* rhs = conv->mutable_operand(1); + const Shape& 
result_shape = conv->shape().tuple_shapes(0); + + // Nothing to do on non-f16 convolutions. + if (result_shape.element_type() != PrimitiveType::F16) { + return false; + } + + // TODO(timshen): Don't skip forward-activation convs if we find a benchmark + // where there's a speedup. + if (kind == CudnnConvKind::kForwardActivation) { + return false; + } + + Shape new_lhs_shape = lhs->shape(); + Shape new_rhs_shape = rhs->shape(); + Shape new_result_shape = conv->shape().tuple_shapes(0); + + // new_{input,filter_output}_shape points to the appropriate one of + // new_{lhs,rhs,result}_shape. + Shape* new_input_shape; + Shape* new_filter_shape; + Shape* new_output_shape; + std::tie(new_input_shape, new_filter_shape, new_output_shape) = [&] { + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + return std::make_tuple(&new_lhs_shape, &new_rhs_shape, + &new_result_shape); + case CudnnConvKind::kBackwardInput: + return std::make_tuple(&new_result_shape, &new_rhs_shape, + &new_lhs_shape); + case CudnnConvKind::kBackwardFilter: + return std::make_tuple(&new_lhs_shape, &new_result_shape, + &new_rhs_shape); + } + }(); + + // If there are 3 input features and 32 or 64 output features, pad the input + // features to 4. Otherwise, try padding to multiples of 8 and check that + // this doesn't make any of the conv buffers too much larger. 
+ auto input_features = + new_input_shape->dimensions(dnums.input_feature_dimension()); + auto output_features = + new_output_shape->dimensions(dnums.output_feature_dimension()); + if (input_features == 3 && (output_features == 32 || output_features == 64)) { + new_input_shape->set_dimensions(dnums.input_feature_dimension(), 4); + new_filter_shape->set_dimensions(dnums.kernel_input_feature_dimension(), 4); + } else { + auto pad_dim = [](Shape* s, int64 dim) { + s->set_dimensions(dim, RoundUpToNearest(s->dimensions(dim), 8)); + }; + pad_dim(new_input_shape, dnums.input_feature_dimension()); + pad_dim(new_filter_shape, dnums.kernel_input_feature_dimension()); + pad_dim(new_filter_shape, dnums.kernel_output_feature_dimension()); + pad_dim(new_output_shape, dnums.output_feature_dimension()); + + // Check that padding wouldn't increase the total bytes read/written by this + // operation too much. + auto check_size_increase = [&](const Shape& old_shape, + const Shape& new_shape) { + int64 old_bytes = ShapeUtil::ByteSizeOf(old_shape); + int64 new_bytes = ShapeUtil::ByteSizeOf(new_shape); + if (new_bytes <= old_bytes * kMaxBytesTouchedIncrease) { + return true; + } + VLOG(3) + << "Not padding convolution; doing so would change input / result " + "shape from " + << ShapeUtil::HumanString(old_shape) << " to " + << ShapeUtil::HumanString(new_shape) << ", a size increase of " + << new_bytes / static_cast(old_bytes) << "x > " + << kMaxBytesTouchedIncrease << "x: " << conv->ToString(); + return false; + }; + + if (!check_size_increase(lhs->shape(), new_lhs_shape) || + !check_size_increase(rhs->shape(), new_rhs_shape) || + !check_size_increase(result_shape, new_result_shape)) { + return false; + } + } + + if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) && + ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) { + VLOG(3) << "No need to pad features of " << conv->ToString(); + return false; + } + + // OK, let's do the transformation! 
+ TF_RETURN_IF_ERROR( + PadConv(conv, new_lhs_shape, new_rhs_shape, new_result_shape)); + return true; +} + +static std::vector GetRelevantConvs( + HloComputation* comp) { + std::vector convs; + for (HloInstruction* instr : comp->instructions()) { + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(Cast(instr)); + } + } + return convs; +} + +StatusOr CudnnConvPadForTensorCores::Run(HloModule* module) { + bool changed = false; + for (HloComputation* comp : module->MakeNonfusionComputations()) { + for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) { + TF_ASSIGN_OR_RETURN(bool result, PadForTensorCores(conv)); + changed |= result; + } + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h new file mode 100644 index 0000000000000000000000000000000000000000..d4e51e86c1bf2c1f9aef2eed642604092033a538 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_TENSOR_CORES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_TENSOR_CORES_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Adds padding to cudnn convolutions to make them run faster on GPUs with +// tensor cores. +// +// - f16 convolutions are padded to have input/output channel dimensions that +// are multiples of 8, so that we can use tensor cores. +// +// - f16 convolutions with 3 input channels and 32 or 64 output channels are +// padded to 4 input channels. There's a special-cased cudnn algorithm just +// for this. +// +// Don't run this pass on GPUs without tensor cores -- it will make them slower! +// +// TODO(jlebar): Also pad dots. +class CudnnConvPadForTensorCores : public HloModulePass { + public: + absl::string_view name() const override { return "cudnn-conv-pad-for-speed"; } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_TENSOR_CORES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..af9303a5b761b99705945f1c02303156e3f874de --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores_test.cc @@ -0,0 +1,195 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h" + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; + +class CudnnConvPadForTensorCoresTest : public HloTestBase {}; + +TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvInputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,41] parameter(0) + filter = f16[2,2,41,40] parameter(1) + ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); + + SCOPED_TRACE(module->ToString()); + EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget, + op::Pad(op::Parameter(0), _), + op::Pad(op::Parameter(1), _))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), + ShapeUtil::MakeShape(F16, {10, 20, 30, 48}))); + 
EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), + ShapeUtil::MakeShape(F16, {2, 2, 48, 40}))); +} + +TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + output = f16[10,20,30,41] parameter(0) + filter = f16[2,2,40,41] parameter(1) + ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardInput" + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::CustomCall(kCudnnConvBackwardInputCallTarget, + op::Pad(op::Parameter(0), _), + op::Pad(op::Parameter(1), _))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), + ShapeUtil::MakeShape(F16, {10, 20, 30, 48}))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), + ShapeUtil::MakeShape(F16, {2, 2, 40, 48}))); +} + +TEST_F(CudnnConvPadForTensorCoresTest, PadF16ForwardConvOutputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,40] parameter(0) + filter = f16[2,2,40,41] parameter(1) + ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvForwardCallTarget, op::Parameter(0), + op::Pad(op::Parameter(1), _)))), + _)); +} + +TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + output = 
f16[10,20,30,40] parameter(0) + filter = f16[2,2,41,40] parameter(1) + result = (f16[10,20,30,41], u8[0]) custom-call(output, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardInput" + ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0 + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardInputCallTarget, op::Parameter(0), + op::Pad(op::Parameter(1), _)))), + _))); +} + +TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,41] parameter(0) + output = f16[10,20,30,40] parameter(1) + result = (f16[2,2,41,40], u8[0]) custom-call(input, output), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardFilter" + ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0 + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardFilterCallTarget, + op::Pad(op::Parameter(0), _), op::Parameter(1)))), + _))); +} + +TEST_F(CudnnConvPadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,40] parameter(0) + output = f16[10,20,30,41] parameter(1) + result = (f16[2,2,40,41], u8[0]) custom-call(input, output), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convBackwardFilter" + ROOT gte = f16[2,2,40,41] get-tuple-element(result), 
index=0 + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::GetTupleElement(op::Tuple( + op::Slice(op::GetTupleElement(op::CustomCall( + kCudnnConvBackwardFilterCallTarget, + op::Parameter(0), op::Pad(op::Parameter(1), _)))), + _))); +} + +TEST_F(CudnnConvPadForTensorCoresTest, PadInputFeatures3To4) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + input = f16[10,20,30,3] parameter(0) + filter = f16[2,2,3,32] parameter(1) + ROOT result = (f16[10,20,30,32], u8[0]) custom-call(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f, + custom_call_target="__cudnn$convForward" + })") + .ValueOrDie(); + EXPECT_TRUE(CudnnConvPadForTensorCores().Run(module.get()).ValueOrDie()); + auto* root = module->entry_computation()->root_instruction(); + + SCOPED_TRACE(module->ToString()); + EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget, + op::Pad(op::Parameter(0), _), + op::Pad(op::Parameter(1), _))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), + ShapeUtil::MakeShape(F16, {10, 20, 30, 4}))); + EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), + ShapeUtil::MakeShape(F16, {2, 2, 4, 32}))); +} + +} // anonymous namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc new file mode 100644 index 0000000000000000000000000000000000000000..3a09d4d4716950a09d65dd093272482d55ac5c27 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.cc @@ -0,0 +1,428 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h" + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +namespace { +bool IsForwardConvolutionCanonical(const HloInstruction& conv) { + CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget || + conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget); + return window_util::HasSymmetricPadding(conv.window()) && + !window_util::HasNegativePadding(conv.window()) && + !window_util::HasDilation(conv.window()); +} + +// If the (positive and negative) padding on the input operand of a convolution +// can't be folded into a cuDNN convolution libcall (e.g. uneven padding and +// dilation), returns kPad and/or kSlice instructions that explicitly apply the +// padding; otherwise returns the original input operand. When there is both +// positive padding (including dilation) and negative padding, we insert both +// kPad and kSlice. 
Modifies 'conv_window' accordingly if any padding was moved +// into a kPad or kSlice op. +HloInstruction* MaybePaddedAndSlicedInput( + Window* conv_window, const ConvolutionDimensionNumbers& conv_dnums, + HloInstruction* input) { + HloComputation* computation = input->parent(); + if (!window_util::HasSymmetricPadding(*conv_window) || + window_util::HasBaseDilation(*conv_window)) { + // If padding is uneven or has dilation, we insert a kPad instruction that + // applies positive padding and dilation. + // + // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of + // moving all the padding into an explicit pad op, we should keep as much + // padding inside of cudnn as possible, on the assumption that padding + // within cudnn is basically free, whereas a kPad's cost increases as the + // amount of padding increases. + PaddingConfig padding_config = + MakeNoPaddingConfig(input->shape().dimensions_size()); + for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.input_spatial_dimensions(i); + if (conv_window->dimensions(i).padding_low() > 0) { + padding_config.mutable_dimensions(dim)->set_edge_padding_low( + conv_window->dimensions(i).padding_low()); + conv_window->mutable_dimensions(i)->set_padding_low(0); + } + if (conv_window->dimensions(i).padding_high() > 0) { + padding_config.mutable_dimensions(dim)->set_edge_padding_high( + conv_window->dimensions(i).padding_high()); + conv_window->mutable_dimensions(i)->set_padding_high(0); + } + if (conv_window->dimensions(i).base_dilation() != 1) { + padding_config.mutable_dimensions(dim)->set_interior_padding( + conv_window->dimensions(i).base_dilation() - 1); + conv_window->mutable_dimensions(i)->set_base_dilation(1); + } + } + PrimitiveType element_type = input->shape().element_type(); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); + input = MakePadHlo(input, padding, 
padding_config).ValueOrDie(); + } + + if (window_util::HasNegativePadding(*conv_window)) { + // If the window has negative padding, insert a kSlice that explicitly + // applies negative padding. + // + // For each dimension, initialize the start index to 0 and the limit index + // to the size of that dimension. + std::vector start_indices(input->shape().dimensions_size(), 0); + std::vector limit_indices(input->shape().dimensions().begin(), + input->shape().dimensions().end()); + std::vector strides(input->shape().dimensions_size(), 1); + for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.input_spatial_dimensions(i); + // If dimension "dim" has negative padding, increase the start index or + // decrement the limit index by the amount of negative padding. + if (conv_window->dimensions(i).padding_low() < 0) { + start_indices[dim] += -conv_window->dimensions(i).padding_low(); + conv_window->mutable_dimensions(i)->set_padding_low(0); + } + if (conv_window->dimensions(i).padding_high() < 0) { + limit_indices[dim] -= -conv_window->dimensions(i).padding_high(); + conv_window->mutable_dimensions(i)->set_padding_high(0); + } + } + + input = + MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie(); + } + + return input; +} + +// If the padding on the kernel operand of a convolution can't be folded into a +// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that +// explicitly applies the padding; otherwise returns the original kernel +// operand. +HloInstruction* MaybePaddedKernel(const Window& conv_window, + const ConvolutionDimensionNumbers& conv_dnums, + HloInstruction* kernel) { + if (!window_util::HasWindowDilation(conv_window)) { + return kernel; + } + + // Compute the shape and padding config of the pad to be inserted. 
+ PaddingConfig padding_config; + for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) { + padding_config.add_dimensions(); + } + for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.kernel_spatial_dimensions(i); + padding_config.mutable_dimensions(dim)->set_interior_padding( + conv_window.dimensions(i).window_dilation() - 1); + } + + HloComputation* computation = kernel->parent(); + PrimitiveType element_type = kernel->shape().element_type(); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); + return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); +} +} // namespace + +bool CudnnConvPaddingLegalization::CanonicalizeForwardConvolution( + HloInstruction* conv) { + if (IsForwardConvolutionCanonical(*conv)) { + return false; + } + + // Insert slices and/or pads between the convolution and its input and/or + // kernel operand. + Window new_conv_window = conv->window(); + HloInstruction* new_input = MaybePaddedAndSlicedInput( + &new_conv_window, conv->convolution_dimension_numbers(), + conv->mutable_operand(0)); + HloInstruction* new_kernel = + MaybePaddedKernel(new_conv_window, conv->convolution_dimension_numbers(), + conv->mutable_operand(1)); + + // Remove the window dilation from convolution's window field. These paddings + // are made explicit with the pads inserted by MaybePaddedKernel(). + for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) { + WindowDimension* dim = new_conv_window.mutable_dimensions(i); + + // The size of the kernel may have changed so update the Window to match. + dim->set_size(new_kernel->shape().dimensions( + conv->convolution_dimension_numbers().kernel_spatial_dimensions(i))); + dim->set_window_dilation(1); + } + + // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract + // out the shape of conv_result. 
+ VLOG(1) << "Canonicalizing forward conv"; + std::vector operands(conv->operands().begin(), + conv->operands().end()); + operands[0] = new_input; + operands[1] = new_kernel; + auto new_conv = conv->parent()->AddInstruction( + conv->CloneWithNewOperands(conv->shape(), operands)); + new_conv->set_window(new_conv_window); + VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " + << new_conv->ToString(); + TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv)); + return true; +} + +namespace { +void IncreasePaddingLowBy(int64 delta, WindowDimension* window_dim) { + window_dim->set_padding_low(window_dim->padding_low() + delta); +} + +void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) { + window_dim->set_padding_high(window_dim->padding_high() + delta); +} +} // namespace + +bool CudnnConvPaddingLegalization::CanonicalizeBackwardFilterConvolution( + HloInstruction* backward_conv) { + CHECK_EQ(backward_conv->custom_call_target(), + kCudnnConvBackwardFilterCallTarget); + if (window_util::HasSymmetricPadding(backward_conv->window())) { + return false; + } + + // A backward filter convolution with uneven padding can be canonicalized to + // one with even padding by padding the activations (input) beforehand. For + // example, + // BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2) + // is equivalent to + // ABCD0 = Pad(ABCD, padding_high=1) + // BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1) + // We choose the lesser of padding_low and padding_high as the new padding. + HloInstruction* input = backward_conv->mutable_operand(0); + Window new_backward_conv_window = backward_conv->window(); + // input_padding_config is the config of the kPad to be inserted. 
+ PaddingConfig input_padding_config = + MakeNoPaddingConfig(ShapeUtil::Rank(input->shape())); + ConvolutionDimensionNumbers backward_conv_dnums = + backward_conv->convolution_dimension_numbers(); + for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { + int64 padding_low = backward_conv->window().dimensions(i).padding_low(); + int64 padding_high = backward_conv->window().dimensions(i).padding_high(); + if (padding_low < 0 || padding_high < 0) { + // TODO(b/32744257): The following canonicalization wouldn't remove + // negative padding in a backward convolution, and would therefore cause + // cuDNN convolution (which doesn't support negative padding) to fail. + return false; + } + // Compute the new, even padding for the backward conv operation. + int64 new_conv_padding = std::min(padding_low, padding_high); + int64 dim = backward_conv_dnums.input_spatial_dimensions(i); + input_padding_config.mutable_dimensions(dim)->set_edge_padding_low( + padding_low - new_conv_padding); + input_padding_config.mutable_dimensions(dim)->set_edge_padding_high( + padding_high - new_conv_padding); + + // Since we move some padding from the backward convolution to the kPad, we + // need to accordingly reduce the padding amount of the backward convolution + // and its inner forward convolution. + auto* new_dim = new_backward_conv_window.mutable_dimensions(i); + new_dim->set_padding_low(new_conv_padding); + new_dim->set_padding_high(new_conv_padding); + } + + // Create a new backward convolution replacing the old one. + HloComputation* computation = backward_conv->parent(); + HloInstruction* output = backward_conv->mutable_operand(1); + HloInstruction* padding = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(input->shape().element_type()))); + HloInstruction* padded_input = + MakePadHlo(input, padding, input_padding_config).ValueOrDie(); + + // The shape of the backward_conv CustomCall is a tuple (conv_result, + // scratch_buffer). 
Extract out the shape of conv_result. + HloInstruction* new_backward_conv = + computation->AddInstruction(backward_conv->CloneWithNewOperands( + backward_conv->shape(), {padded_input, output})); + new_backward_conv->set_window(new_backward_conv_window); + + VLOG(1) << "Canonicalizing backward filter conv"; + VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " + << new_backward_conv->ToString(); + + TF_CHECK_OK( + computation->ReplaceInstruction(backward_conv, new_backward_conv)); + return true; +} + +bool CudnnConvPaddingLegalization::CanonicalizeBackwardInputConvolution( + HloInstruction* backward_conv) { + if (window_util::HasSymmetricPadding(backward_conv->window())) { + return false; + } + + Window new_backward_conv_window = backward_conv->window(); + ConvolutionDimensionNumbers backward_conv_dnums = + backward_conv->convolution_dimension_numbers(); + + // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory). + // Get the shape of conv_result. + Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); + + Shape new_backward_conv_shape = backward_conv_shape; + for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { + int64 padding_low = backward_conv->window().dimensions(i).padding_low(); + int64 padding_high = backward_conv->window().dimensions(i).padding_high(); + if (padding_low < 0 || padding_high < 0) { + // TODO(b/32744257): The following canonicalization wouldn't remove + // negative padding in a backward convolution, and would therefore cause + // cuDNN convolution (which doesn't support negative padding) to fail. + return false; + } + // If the backward convolution has uneven padding on the activations, we + // move some padding on the larger end to "internal" padding, so that the + // backward convolution produces larger activations which get sliced later. 
+ // + // For example, suppose we have a non-canonical HLO + // [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1)) + // where the amount of padding low is larger, we can canonicalize it to + // [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1)) + // [A] = Slice([B A]) + if (padding_low > padding_high) { + IncreasePaddingLowBy(padding_high - padding_low, + new_backward_conv_window.mutable_dimensions(i)); + } else if (padding_low < padding_high) { + IncreasePaddingHighBy(padding_low - padding_high, + new_backward_conv_window.mutable_dimensions(i)); + } + // Decreasing the padding by X *increases* the size of our output by X. + int64 dim = backward_conv_dnums.output_spatial_dimensions(i); + new_backward_conv_shape.set_dimensions( + dim, new_backward_conv_shape.dimensions(dim) + + std::abs(padding_low - padding_high)); + } + + // Create a new backward convolution replacing the old one. + HloComputation* computation = backward_conv->parent(); + HloInstruction* output = backward_conv->mutable_operand(0); + HloInstruction* filter = backward_conv->mutable_operand(1); + + HloInstruction* new_backward_conv_call = + computation->AddInstruction(backward_conv->CloneWithNewOperands( + ShapeUtil::MakeTupleShape( + {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}), + {output, filter})); + new_backward_conv_call->set_window(new_backward_conv_window); + + // The CustomCall created above returns a tuple (conv_result, scratch_memory). + // Extract out the two elements. + HloInstruction* new_backward_conv = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_backward_conv_shape, new_backward_conv_call, 0)); + HloInstruction* new_backward_conv_scratch = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_backward_conv_call->shape().tuple_shapes(1), + new_backward_conv_call, 1)); + + // Slice the new backward convolution. + // + // Initialize start_indices and limit_indices as no slicing. 
+ std::vector start_indices(new_backward_conv->shape().dimensions_size(), + 0LL); + std::vector limit_indices( + new_backward_conv->shape().dimensions().begin(), + new_backward_conv->shape().dimensions().end()); + std::vector strides(new_backward_conv->shape().dimensions_size(), 1LL); + for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { + int64 padding_low = backward_conv->window().dimensions(i).padding_low(); + int64 padding_high = backward_conv->window().dimensions(i).padding_high(); + int64 dim = backward_conv_dnums.output_spatial_dimensions(i); + if (padding_low > padding_high) { + // If the amount of low padding (of the old backward convolution) is + // larger, we internally pad the low end of the activations and slice + // internal padding out here. + start_indices[dim] += padding_low - padding_high; + } else if (padding_low < padding_high) { + // If the amount of high padding is larger, we slice out the internal + // padding on the high end. + limit_indices[dim] -= padding_high - padding_low; + } + } + + // Replace the old backward convolution with the slice. 
+ Shape slice_shape = + ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices, + limit_indices, strides) + .ConsumeValueOrDie(); + CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape)) + << ShapeUtil::HumanString(slice_shape) << " vs " + << ShapeUtil::HumanString(backward_conv_shape); + + HloInstruction* slice = computation->AddInstruction( + HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv, + start_indices, limit_indices, strides)); + HloInstruction* new_tuple = computation->AddInstruction( + HloInstruction::CreateTuple({slice, new_backward_conv_scratch})); + + VLOG(1) << "Canonicalizing backward input conv"; + VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " + << new_tuple->ToString(); + + TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple)); + return true; +} + +StatusOr CudnnConvPaddingLegalization::RunOnComputation( + HloComputation* computation) { + bool changed = false; + std::vector convs; + for (auto* instr : computation->instructions()) { + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(Cast(instr)); + } + } + for (HloCustomCallInstruction* instruction : convs) { + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instruction)); + changed |= [&] { + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + return CanonicalizeForwardConvolution(instruction); + case CudnnConvKind::kBackwardInput: + return CanonicalizeBackwardInputConvolution(instruction); + case CudnnConvKind::kBackwardFilter: + return CanonicalizeBackwardFilterConvolution(instruction); + } + }(); + } + return changed; +} + +StatusOr CudnnConvPaddingLegalization::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git 
a/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h new file mode 100644 index 0000000000000000000000000000000000000000..7d1b075517fb285222506e0420984906579e681f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PADDING_LEGALIZATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PADDING_LEGALIZATION_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// An HLO pass that canonicalizes convolution instructions for GPU codegen. It +// inserts Pad instructions before Convolution instructions with uncanonicalized +// padding, so that they can be lowered to cuDNN convolution. +class CudnnConvPaddingLegalization : public HloModulePass { + public: + absl::string_view name() const override { + return "cudnn-conv-padding-legalization"; + } + + StatusOr Run(HloModule* module) override; + + private: + StatusOr RunOnComputation(HloComputation* computation); + // Returns if any changes are made to the parent computation. 
+ bool CanonicalizeForwardConvolution(HloInstruction* conv); + bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv); + bool CanonicalizeBackwardInputConvolution(HloInstruction* backward_conv); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PADDING_LEGALIZATION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..e81850db69edced29ea31bb2a526b0503bf8a453 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc @@ -0,0 +1,581 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace gpu { + +namespace { + +HloInstruction* CreateCudnnConv(const char* call_target, const Shape& shape, + HloInstruction* lhs, HloInstruction* rhs, + const Window& window, + const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count, + const OpMetadata& metadata) { + HloComputation* computation = lhs->parent(); + + // This call returns a tuple of (conv_result, scratch_memory), where + // conv_result is the actual result of the convolution, and scratch_memory is + // temporary memory used by cudnn. + // + // At the moment, we don't know how much scratch memory this conv is going to + // use, so we put u8[0] in this place. Later on another pass will choose + // which conv algorithm to use, and at that point we'll modify the shape of + // this second tuple element. 
+ Shape call_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); + + HloInstruction* custom_call = computation->AddInstruction( + HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); + custom_call->set_window(window); + custom_call->set_convolution_dimension_numbers(dnums); + custom_call->set_feature_group_count(feature_group_count); + custom_call->set_metadata(metadata); + return custom_call; +} + +bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { + const ConvolutionDimensionNumbers& dnums = + conv->convolution_dimension_numbers(); + if (dnums.input_spatial_dimensions_size() > 3) { + return false; + } + + // CuDNN does not accept zero-element arguments + if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) || + ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) { + return false; + } + + // CuDNN can perform either cross correlation (no reversal), + // or convolution (all dimensions reversed). + if (dnums.input_spatial_dimensions_size() == 2 + ? !window_util::AllOrNoneReversed(conv->window()) + : window_util::HasWindowReversal(conv->window())) { + return false; + } + return true; +} + +// Try to match a backward filter pattern that contains "conv". +// Precondition: "conv" is a kConvolution. +std::tuple MatchBackwardFilter( + HloInstruction* conv) { + const auto no_match_result = + std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); + if (conv->feature_group_count() > 1) { + return no_match_result; + } + // Step 1: match the instruction pattern without considering the paddings and + // dimension numbers just yet. We may need some generic pattern matcher + // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h + // + // Backward filter convolution is implemented in XLA as the forward + // convolution of padded activations and dilated gradients. Padding on + // activations and dilation on gradients are specified in the "window" field + // of the forward convolution. 
+ // + // activations gradients + // \ / + // v v + // Convolution + // conv + CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); + + // Step 2: match paddings and dimension numbers of the forward convolution. + const ConvolutionDimensionNumbers& conv_dnums = + conv->convolution_dimension_numbers(); + auto input_batch_dim = conv_dnums.input_batch_dimension(); + auto input_feature_dim = conv_dnums.input_feature_dimension(); + auto input_spatial_dims = conv_dnums.input_spatial_dimensions(); + auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension(); + auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension(); + auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions(); + auto output_batch_dim = conv_dnums.output_batch_dimension(); + auto output_feature_dim = conv_dnums.output_feature_dimension(); + auto output_spatial_dims = conv_dnums.output_spatial_dimensions(); + + for (const WindowDimension& window_dim : conv->window().dimensions()) { + if (window_dim.stride() != 1) { + VLOG(1) << "Forward convolution's window " + << conv->window().ShortDebugString() + << " should have stride of 1."; + return no_match_result; + } + if (window_dim.base_dilation() != 1) { + VLOG(1) << "Forward convolution's window " + << conv->window().ShortDebugString() + << " should have no base (LHS) dilation."; + return no_match_result; + } + if (window_dim.padding_low() < 0) { + VLOG(1) << "Padding low should be non-negative."; + return no_match_result; + } + if (window_dim.window_reversal()) { + VLOG(1) << "Window reversal field not supported"; + return no_match_result; + } + // Padding high will be checked in Step 3. + } + if (input_batch_dim == output_batch_dim && + !window_util::HasWindowDilation(conv->window())) { + VLOG(1) << conv->ToString() + << " is a regular forward convolution. 
No need " + "to fold it to a backward filter convolution."; + return no_match_result; + } + + // Step 3: fuse the matched HLOs into a backward convolution instruction. + // + // Compute the window of the backward convolution. + Window backward_conv_window; + for (int i = 0; i < input_spatial_dims.size(); ++i) { + WindowDimension* dim = backward_conv_window.add_dimensions(); + // The window size of the backward convolution equals the output size of the + // forward convolution. + int64 filter_size = conv->shape().dimensions(output_spatial_dims[i]); + dim->set_size(filter_size); + // The window stride equals the window dilation of the forward convolution. + dim->set_stride(conv->window().dimensions(i).window_dilation()); + // The window's low padding is the same as the low padding of the + // activations. + dim->set_padding_low(conv->window().dimensions(i).padding_low()); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + + int64 input_size = + conv->operand(0)->shape().dimensions(input_spatial_dims[i]); + int64 output_size = conv->window().dimensions(i).size(); + // Compute the range of the amount of valid high padding. We first compute + // min_padding_high, the amount of padding on the right/bottom to ensure the + // last patch ends at the border, i.e., + // + // input_size + dim->padding_low() + min_padding_high + // = (output_size - 1) * stride + filter_size + // + // Because convolution ignores trailing incomplete windows, any amount of + // padding high from min_padding_high to min_padding_high+stride-1 + // (max_padding_high) has the same effect. 
+ int64 padded_input_size = filter_size + (output_size - 1) * dim->stride(); + int64 min_padding_high = + padded_input_size - input_size - dim->padding_low(); + int64 max_padding_high = min_padding_high + dim->stride() - 1; + CHECK_GE(dim->padding_low(), 0); + // In practice, since cuDNN convolution only supports even padding, we make + // the amount of high padding the same as the amount of low padding as long + // as it is between min_padding_high and max_padding_high. If it is not in + // that range, we pick the one that's closest to dim->padding_low() and let + // CudnnConvPaddingLegalization canonicalize the resultant backward + // convolution later. Picking the closest one minimizes the cost of the kPad + // instruction to be inserted by CudnnConvPaddingLegalization. + if (dim->padding_low() >= min_padding_high && + dim->padding_low() <= max_padding_high) { + dim->set_padding_high(dim->padding_low()); + } else { + if (dim->padding_low() < min_padding_high) { + dim->set_padding_high(min_padding_high); + } else { + dim->set_padding_high(max_padding_high); + } + } + if (dim->padding_high() < 0) { + LOG(ERROR) + << "Fusing this pattern to backward filter convolution would cause " + "negative padding (" + << dim->padding_high() + << ") on right/bottom of the weight gradients, which is not " + "supported by CudnnConvPaddingLegalization (b/32744257). " + "Falling back to " + "unfused convolution for instruction: " + << conv->ToString(); + return no_match_result; + } + } + + // Restore the dimension numbers of the backward convolution from the forward + // convolution. The two activation dimensions are reversed (batch and + // feature). 
+ ConvolutionDimensionNumbers backward_conv_dnums; + backward_conv_dnums.set_input_batch_dimension(input_feature_dim); + backward_conv_dnums.set_input_feature_dimension(input_batch_dim); + for (int i = 0; i < input_spatial_dims.size(); ++i) { + backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]); + } + backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim); + backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim); + for (int i = 0; i < kernel_spatial_dims.size(); ++i) { + backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]); + } + // The dimension numbering of the output of the forward convolution (before + // transposition) is the same as that of the activations (according to the + // semantics of kConvolution). The batch dimension of the activations should + // be treated as the input feature dimension, and the feature dimension should + // be treated as the output feature. + backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim); + backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim); + for (int i = 0; i < output_spatial_dims.size(); ++i) { + backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]); + } + + return std::make_tuple(true, backward_conv_window, backward_conv_dnums); +} + +// Try to match a backward input pattern that contains "conv". +// Precondition: "conv" is a kConvolution. +std::tuple +MatchBackwardInput(HloInstruction* conv) { + const auto no_match_result = + std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); + + // TODO(b/119479517): Theoretically cuDNN supports grouped convolutions also + // for the backward input convolution, but at least for now with version 7.1.4 + // it is slower. This needs to be re-evaluated for future cuDNN versions. + // Note that we already have the necessary code down below, the only thing to + // enable it is to remove the following early return. 
+ if (conv->feature_group_count() > 1) { + return no_match_result; + } + + // Match instruction pattern. + CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); + HloInstruction* reverse_filter = conv->mutable_operand(1); + ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers(); + + // We pattern-match to a backwards input conv if: + // + // - all spatial dims of the filter are reversed + // + // OR + // + // - filter is 1x1 or a constant AND + // - conv has base dilation (otherwise this is just a regular forward conv). + // + // The final criterion above is just for canonicalization; cudnn seems to run + // just as fast if we canonicalize 1x1/constant filters without base dilation + // to forward or backward convs. We canonicalize to forward conv because (a) + // it's more natural (constant filters usually show up when doing inference, + // and having backwards convolutions in inference graphs would be weird), and + // (b) cudnn has special fusions for forward conv plus bias and activation, + // and we want to pattern-match to that after running this pass. + bool is_reversed_filter = + reverse_filter->opcode() == HloOpcode::kReverse && + absl::c_is_permutation(dnums.kernel_spatial_dimensions(), + reverse_filter->dimensions()); + bool is_1x1_filter = + absl::c_all_of(conv->window().dimensions(), + [](const WindowDimension& d) { return d.size() == 1; }); + if (!is_reversed_filter && + !(window_util::HasBaseDilation(conv->window()) && + (reverse_filter->IsConstant() || is_1x1_filter))) { + VLOG(1) << "Can't match to backwards convolution. Either filter is not " + "kReverse, or it's not a base-dilated conv with a 1x1 or " + "constant filter."; + return no_match_result; + } + + // Match padding and dilation of the forward convolution. 
+ for (const WindowDimension& window_dim : conv->window().dimensions()) { + if (window_dim.stride() != 1) { + VLOG(1) << "Forward convolution's window " + << conv->window().ShortDebugString() + << " should have stride of 1."; + return no_match_result; + } + if (window_dim.window_dilation() != 1) { + VLOG(1) << "Forward convolution's window " + << conv->window().ShortDebugString() + << " should have no window dilation."; + return no_match_result; + } + if (window_dim.window_reversal()) { + VLOG(1) << "Window reversal field not supported"; + return no_match_result; + } + } + + const auto& input_spatial_dims = dnums.input_spatial_dimensions(); + const auto& output_spatial_dims = dnums.output_spatial_dimensions(); + CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size()); + CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size()); + + const Window& old_window = conv->window(); + Window new_window = old_window; + for (size_t i = 0; i < input_spatial_dims.size(); ++i) { + // Restore backward convolution's padding config from the matched pattern. + // See the comment in tensorflow/core/kernels/conv_grad_tuple_ops.cc + // for how we convert backward input convolution to a variant of forward + // convolution. + // + // The stride of the backward convolution + // = the base dilation factor of the forward convolution + auto dim = new_window.mutable_dimensions(i); + dim->set_stride(old_window.dimensions(i).base_dilation()); + dim->set_base_dilation(1); + + // The low padding = kernel_size - 1 - low padding on the gradients + // Make sure the low padding is not negative. 
+ auto kernel_size = old_window.dimensions(i).size(); + auto backward_padding_low = + kernel_size - 1 - old_window.dimensions(i).padding_low(); + if (backward_padding_low < 0) { + LOG(ERROR) + << "The low padding of the backward convolution would be negative (" + << backward_padding_low + << "), which isn't supported by CudnnConvPaddingLegalization " + "for now (b/32744257)."; + return no_match_result; + } + dim->set_padding_low(backward_padding_low); + + // Compute the range of the amount of padding on the right/bottom of the + // activations. XLA's convolution requires all patches to be within the + // padded base. This gives us flexiblity to choose the amount of high + // padding from a set of values without changing the result of the backward + // convolution. The minimum amount (min_padding_high) makes the last patch + // end at the border. The maximum amount (max_padding_high) equals + // min_padding_high+stride-1 -- max_padding_high+1 would cause the output + // size to change. + auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]); + auto output_size = + conv->operand(0)->shape().dimensions(input_spatial_dims[i]); + auto padded_input_size = kernel_size + dim->stride() * (output_size - 1); + auto total_pad_size = padded_input_size - unpadded_input_size; + auto min_padding_high = total_pad_size - backward_padding_low; + auto max_padding_high = min_padding_high + dim->stride() - 1; + + if (backward_padding_low >= min_padding_high && + backward_padding_low <= max_padding_high) { + // In the best case (most likely), if backward_padding_low is in the range + // of the amounts of valid high padding, we choose backward_padding_low + // because cuDNN supports even padding only. + dim->set_padding_high(backward_padding_low); + } else { + // Otherwise, we choose the amount that's closest to backward_padding_low, + // and CudnnConvPaddingLegalization will later insert kSlice + // instructions to enforce even padding. 
+ // + // For example, consider the backward convolution pattern + // + // ab xy + // | pad | reverse + // .a.b yx + // \ / + // ABC + // + // The amount of low padding on activations (in backward convolution) is + // backward_padding_low = kernel_size - 1 - forward_padding_low + // = 2 - 1 - 1 = 0 + // + // The amount of padding high must be between 1 and 2, in order to make + // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in + // the range of [1,2], so we pick the closest valid amount of padding + // high, which is 1 in this case. Therefore, we fuse the above pattern to + // + // ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1) + if (backward_padding_low < min_padding_high) { + dim->set_padding_high(min_padding_high); + } else { + dim->set_padding_high(max_padding_high); + } + } + // CudnnConvPaddingLegalization doesn't handle backward input + // convolution with negative padding for now. So fall back to unfused + // convolution in case of negative padding. For example, + // ABCD = Conv(abc, reverse(xy), padding_high=2) + // could be fused to + // ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1) + // with positive padding low but negative padding high. + if (dim->padding_high() < 0) { + LOG(ERROR) << "Fusing this pattern to backward convolution would cause " + "negative padding (" + << dim->padding_high() + << ") on right/bottom of the activations, which is not " + "supported by CudnnConvPaddingLegalization (b/32744257). " + "Falling back to unfused convolution for instruction: " + << conv->ToString(); + return no_match_result; + } + } + + // OK, it's a match! Switch the input feature dimension with the output + // feature dimension. This is the way cuDNN expects it to be. 
+ dnums.set_kernel_input_feature_dimension( + conv->convolution_dimension_numbers().kernel_output_feature_dimension()); + dnums.set_kernel_output_feature_dimension( + conv->convolution_dimension_numbers().kernel_input_feature_dimension()); + + // If we matched against a constant, we need to add a reverse op that can be + // subsumed by the cuDNN call. algebraic-simplifier will later remove any + // unnecessary reverses. + if (reverse_filter->opcode() != HloOpcode::kReverse && + reverse_filter->IsConstant()) { + // Create a double-reverse, which is a nop. + HloComputation* c = conv->parent(); + reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( + reverse_filter->shape(), reverse_filter, + AsInt64Slice(dnums.kernel_spatial_dimensions()))); + reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( + reverse_filter->shape(), reverse_filter, + AsInt64Slice(dnums.kernel_spatial_dimensions()))); + TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter)); + } + + // Calculate the 'rhs' that goes into the backward input convolution. + HloInstruction* rhs = reverse_filter; + // One reverse is subsumed by the cuDNN call. + if (rhs->opcode() == HloOpcode::kReverse) { + rhs = rhs->mutable_operand(0); + } + if (conv->feature_group_count() == 1) { + return std::make_tuple(true, new_window, dnums, rhs); + } + + // Handle grouped convolutions. Because we swapped the input feature dimension + // with the output feature dimension, we need to also reshape the kernel so + // that the 'feature_group_count' parameter still makes sense. The + // 'feature_group_count' parameter essentially specifies how often the + // 'kernel_input_feature_dimension' is repeated. So when we swap these + // dimensions, we need to divide the new 'kernel_input_feature_dimension' by + // 'feature_group_count' and multiply the new + // 'kernel_output_feature_dimension' by 'feature_group_count'. 
+ Shape new_shape = rhs->shape(); + int64 input_feature_dimension = dnums.kernel_input_feature_dimension(); + int64 output_feature_dimension = dnums.kernel_output_feature_dimension(); + + // In the backward convolution case, the spatial dimensions become the + // feature dimensions, and we are guaranteed that the spatial dimensions are + // adjacent. + CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL); + int64 input_features = new_shape.dimensions(input_feature_dimension); + int64 output_features = new_shape.dimensions(output_feature_dimension); + new_shape.set_dimensions(input_feature_dimension, + input_features / conv->feature_group_count()); + new_shape.set_dimensions(output_feature_dimension, + output_features * conv->feature_group_count()); + HloComputation* c = conv->parent(); + rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs)); + return std::make_tuple(true, new_window, dnums, rhs); +} + +CudnnConvBackendConfig GetDefaultBackendConfig() { + CudnnConvBackendConfig config; + config.set_conv_result_scale(1); + return config; +} + +// Tries to rewrite a single convolution into a call to cudnn. 
+StatusOr RunOnInstruction(HloInstruction* conv) { + CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); + + HloInstruction* custom_call = [&]() -> HloInstruction* { + bool match; + Window window; + ConvolutionDimensionNumbers dnums; + HloInstruction* rhs; + + std::tie(match, window, dnums) = MatchBackwardFilter(conv); + if (match) { + return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), + conv->mutable_operand(0), conv->mutable_operand(1), + window, dnums, conv->feature_group_count(), + conv->metadata()); + } + + std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); + if (match) { + return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(), + conv->mutable_operand(0), rhs, window, dnums, + conv->feature_group_count(), conv->metadata()); + } + + // If all else fails, try a forward convolution. + if (CanImplementAsCudnnForwardConv(conv)) { + return CreateCudnnConv(kCudnnConvForwardCallTarget, conv->shape(), + conv->mutable_operand(0), conv->mutable_operand(1), + conv->window(), + conv->convolution_dimension_numbers(), + conv->feature_group_count(), conv->metadata()); + } + + return nullptr; + }(); + + if (custom_call == nullptr) { + return false; + } + + TF_RETURN_IF_ERROR( + custom_call->set_backend_config(GetDefaultBackendConfig())); + + VLOG(1) << "Replacing convolution " << conv->ToString() << " with " + << custom_call->ToString(); + + // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out + // the conv result and replace `conv` with it. + TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( + conv, + HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0))); + return true; +} + +// Rewrites the convolutions in the given computation into calls to cudnn. +// Returns true if it made any changes. 
+StatusOr<bool> RunOnComputation(HloComputation* computation) {
+  std::vector<HloInstruction*> convs;
+  for (auto* hlo : computation->instructions()) {
+    if (hlo->opcode() == HloOpcode::kConvolution) {
+      convs.push_back(hlo);
+    }
+  }
+
+  bool changed = false;
+  for (HloInstruction* conv : convs) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv));
+    changed |= result;
+  }
+  return changed;
+}
+}  // namespace
+
+StatusOr<bool> CudnnConvRewriter::Run(HloModule* module) {
+  bool changed = false;
+  for (HloComputation* computation : module->MakeNonfusionComputations()) {
+    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
+    changed |= result;
+  }
+  return changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h
new file mode 100644
index 0000000000000000000000000000000000000000..d8ec72c27bab8912d0dc2aeead114eb010b87b78
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_REWRITER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_REWRITER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Rewrites plain convolutions, backwards-filter convolutions, and
+// backwards-input convolutions into CustomCall HLOs that call into cuDNN.
+class CudnnConvRewriter : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "cudnn-conv-rewriter"; }
+
+  StatusOr<bool> Run(HloModule* module) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_REWRITER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..443883a89f66a747def1049bc5afb53fec3c2409
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc
@@ -0,0 +1,627 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h" + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; + +class CudnnConvRewriterTest : public HloTestBase { + public: + CudnnConvRewriterTest() + : HloTestBase(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false) { + for (int i = 0; i < 2; ++i) { + WindowDimension* window_dim = default_conv_window_.add_dimensions(); + window_dim->set_size(1); + window_dim->set_stride(1); + window_dim->set_padding_low(0); + window_dim->set_padding_high(0); + window_dim->set_window_dilation(1); + window_dim->set_base_dilation(1); + } + // TF data shapes are by default in the NHWC order, and filter shape is by + // default in HWIO order. For backward filter convolution, we need to swap + // the batch and feature dimension in the activations, and treat the batch + // dimension in gradients as the input feature dimension in the filter. + // + // TODO(jingyue): Add more tests on NCHW input order, which TF also + // supports. 
+ tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); + tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); + tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1); + tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2); + tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0); + tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension( + 3); + tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1); + tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2); + tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0); + tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1); + tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2); + tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3); + + tf_default_dnums_for_backward_input_.set_input_batch_dimension(0); + tf_default_dnums_for_backward_input_.set_output_batch_dimension(0); + tf_default_dnums_for_backward_input_.set_input_feature_dimension(3); + tf_default_dnums_for_backward_input_.set_output_feature_dimension(3); + tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1); + tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1); + tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2); + tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2); + tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3); + tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2); + tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0); + tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(1); + } + + protected: + bool RunPass(HloModule* module) { + return CudnnConvRewriter().Run(module).ValueOrDie(); + } + + // A convolution window with stride 1 and zero padding. The size fields are + // not set. 
+ Window default_conv_window_; + ConvolutionDimensionNumbers tf_default_dnums_for_backward_filter_; + ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; +}; + +TEST_F(CudnnConvRewriterTest, BackwardFilterConvolve) { + HloComputation::Builder builder(TestName()); + HloInstruction* activations = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations")); + HloInstruction* gradients = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 1, 2, 1}), "gradients")); + Window conv_window = default_conv_window_; + conv_window.mutable_dimensions(1)->set_size(2); + conv_window.mutable_dimensions(1)->set_window_dilation(2); + auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape( + activations->shape(), gradients->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_filter_) + .ConsumeValueOrDie(), + activations, gradients, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); + + OpMetadata metadata; + metadata.set_op_name("foo"); + conv->set_metadata(metadata); + + auto module = CreateNewVerifiedModule(); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); + + // Check that metadata was preserved. 
+ const auto& md_after_opt = + entry_computation->root_instruction()->operand(0)->metadata(); + EXPECT_TRUE(protobuf_util::ProtobufEquals(md_after_opt, metadata)) + << md_after_opt.DebugString() << " vs " << metadata.DebugString(); +} + +TEST_F(CudnnConvRewriterTest, + BackwardFilterConvolveEquivalentToForwardConvolution) { + HloComputation::Builder builder(TestName()); + HloInstruction* activations = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations")); + HloInstruction* gradients = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "gradients")); + Window conv_window = default_conv_window_; + conv_window.mutable_dimensions(1)->set_size(3); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape( + activations->shape(), gradients->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_filter_) + .ConsumeValueOrDie(), + activations, gradients, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); + + auto module = CreateNewVerifiedModule(); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); +} + +// Extracted from block35 training. 
+TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* activations = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations")); + HloInstruction* gradients = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients")); + + Window conv_window = default_conv_window_; + for (int i = 0; i < 2; ++i) { + conv_window.mutable_dimensions(i)->set_size(35); + conv_window.mutable_dimensions(i)->set_padding_low(1); + conv_window.mutable_dimensions(i)->set_padding_high(1); + } + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); + + auto module = CreateNewVerifiedModule(); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); +} + +// Extracted from inception v3 training. 
+TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* activations = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), "activations")); + HloInstruction* gradients = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "gradients")); + + Window conv_window = default_conv_window_; + for (int i = 0; i < 2; ++i) { + conv_window.mutable_dimensions(i)->set_size(4); + conv_window.mutable_dimensions(i)->set_padding_high(-1); + conv_window.mutable_dimensions(i)->set_window_dilation(2); + } + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); + + auto module = CreateNewVerifiedModule(); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); +} + +TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* activations = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations")); + HloInstruction* gradients = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients")); + + Window conv_window = default_conv_window_; + for (int i = 0; i < 2; ++i) { + conv_window.mutable_dimensions(i)->set_size(35); + // Uneven padding: padding_low=0, padding_high=1 + conv_window.mutable_dimensions(i)->set_padding_high(1); + } + builder.AddInstruction(HloInstruction::CreateConvolve( + 
ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); + + auto module = CreateNewVerifiedModule(); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); +} + +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveEvenPadding) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* output = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {4, 5, 16, 16}), "output")); + HloInstruction* kernel = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {5, 3, 7, 7}), "kernel")); + HloInstruction* reverse_kernel = builder.AddInstruction( + HloInstruction::CreateReverse(kernel->shape(), kernel, {2, 3})); + + Window conv_window = default_conv_window_; + for (int i = 0; i < 2; ++i) { + conv_window.mutable_dimensions(i)->set_size(7); + conv_window.mutable_dimensions(i)->set_padding_low(3); + conv_window.mutable_dimensions(i)->set_padding_high(3); + } + ConvolutionDimensionNumbers conv_dnums; + conv_dnums.set_input_batch_dimension(0); + conv_dnums.set_output_batch_dimension(0); + conv_dnums.set_input_feature_dimension(1); + conv_dnums.set_output_feature_dimension(1); + conv_dnums.add_input_spatial_dimensions(2); + conv_dnums.add_output_spatial_dimensions(2); + conv_dnums.add_input_spatial_dimensions(3); + conv_dnums.add_output_spatial_dimensions(3); + conv_dnums.set_kernel_input_feature_dimension(0); + conv_dnums.set_kernel_output_feature_dimension(1); + conv_dnums.add_kernel_spatial_dimensions(2); + conv_dnums.add_kernel_spatial_dimensions(3); + + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {4, 3, 16, 
16}), /*lhs=*/output, + /*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window, + conv_dnums, DefaultPrecisionConfig(2))); + // Verify the convolution's shape is consistent with ShapeInference. + CHECK(ShapeUtil::Compatible( + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, conv_window, conv_dnums) + .ValueOrDie())); + + auto module = CreateNewVerifiedModule(); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); + const HloInstruction* custom_call = + entry_computation->root_instruction()->operand(0); + for (int i = 0; i < 2; ++i) { + const WindowDimension& window_dim = custom_call->window().dimensions(i); + // Low padding of the backward input convolution + // = kernel_size - 1 - low padding on gradients. + EXPECT_EQ(3, window_dim.padding_low()); + EXPECT_EQ(3, window_dim.padding_high()); + EXPECT_EQ(1, window_dim.stride()); + EXPECT_EQ(1, window_dim.base_dilation()); + } +} + +// Convolve([abc], [x], base_dilation=2) +// = Convolve([abc], Reverse([x]), base_dilation=2) +// = BackwardInputConvolve([abc], [x], stride=2) +TEST_F(CudnnConvRewriterTest, BackwardInputConvolve1x1Filter) { + auto builder = HloComputation::Builder(TestName()); + // NHWC dimension order. + HloInstruction* output = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output")); + // HWOI dimension order. 
+ HloInstruction* kernel = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel")); + + Window conv_window = default_conv_window_; + conv_window.mutable_dimensions(1)->set_base_dilation(2); + + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_) + .ConsumeValueOrDie(), + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); + + auto module = CreateNewVerifiedModule(); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); +} + +// BackwardInputConvolve([abc], [x], stride=1) is equivalent to +// ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input +// convolution. +TEST_F(CudnnConvRewriterTest, + BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) { + auto builder = HloComputation::Builder(TestName()); + // NHWC dimension order. + HloInstruction* output = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output")); + // HWOI dimension order. 
+ HloInstruction* kernel = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel")); + + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape( + output->shape(), kernel->shape(), /*feature_group_count=*/1, + default_conv_window_, tf_default_dnums_for_backward_input_) + .ConsumeValueOrDie(), + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, + default_conv_window_, tf_default_dnums_for_backward_input_, + DefaultPrecisionConfig(2))); + + auto module = CreateNewVerifiedModule(); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); +} + +// Extracted from Inception V3 training. +// +// filter(HWIO) +// 3x3x192x320 +// | +// v +// gradients(NHWC) reverse +// 20x4x4x320 3x3x192x320 +// \ / +// \ / +// conv (NHWC) with padding (low=2,high=3,interior=1) +// 20x10x10x192 +// +// Gradients are padded unevenly. +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* output = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output")); + HloInstruction* kernel = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel")); + HloInstruction* reverse_kernel = builder.AddInstruction( + HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); + + Window conv_window = default_conv_window_; + for (int i = 0; i < 2; ++i) { + conv_window.mutable_dimensions(i)->set_size(3); + conv_window.mutable_dimensions(i)->set_padding_low(2); + conv_window.mutable_dimensions(i)->set_padding_high(3); + // Interior padding = 1. 
+ conv_window.mutable_dimensions(i)->set_base_dilation(2); + } + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); + // Verify the convolution's shape is consistent with ShapeInference. + CHECK(ShapeUtil::Compatible( + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); + + auto module = CreateNewVerifiedModule(); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); + const HloInstruction* custom_call = + entry_computation->root_instruction()->operand(0); + for (int i = 0; i < 2; ++i) { + const WindowDimension& window_dim = custom_call->window().dimensions(i); + EXPECT_EQ(0, window_dim.padding_low()); + EXPECT_EQ(0, window_dim.padding_high()); + EXPECT_EQ(2, window_dim.stride()); + EXPECT_EQ(1, window_dim.base_dilation()); + } +} + +// Similar to BackwardInputConvolveUnevenPadding, but the low padding of the +// gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused. 
+TEST_F(CudnnConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* output = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output")); + HloInstruction* kernel = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel")); + HloInstruction* reverse_kernel = builder.AddInstruction( + HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); + + Window conv_window = default_conv_window_; + for (int i = 0; i < 2; ++i) { + conv_window.mutable_dimensions(i)->set_size(3); + conv_window.mutable_dimensions(i)->set_padding_low(3); + conv_window.mutable_dimensions(i)->set_padding_high(2); + conv_window.mutable_dimensions(i)->set_base_dilation(2); + } + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); + // Verify the convolution's shape is consistent with ShapeInference. + CHECK(ShapeUtil::Compatible( + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); + + auto module = CreateNewVerifiedModule(); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); +} + +// Extracted from Resnet-50. +// +// For simplicity, we focus on the column dimension and ignore other dimensions. +// We use [?] to represent the shape instead of the content. 
+// +// Suppose operator FC does +// [4] = conv([14], [3], stride=2, padding_high=1) // Padding::kSame +// +// BC = BackwardInput(FC) does: +// [14] = conv([7], reverse([3]), +// padding_low=2, padding_high=1, base_dilation=2) +// +// We should fuse BC even though padding on activations is uneven, because +// CudnnConvPaddingLegalization will canonicalize the fusion HLO. +TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { + auto builder = HloComputation::Builder(TestName()); + // The gradients are in NCHW layout. + HloInstruction* output = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1, 7, 1}), "output")); + // The kernel is in HWIO layout. + HloInstruction* kernel = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 3, 1, 1}), "kernel")); + HloInstruction* reverse_kernel = builder.AddInstruction( + HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); + + Window conv_window = default_conv_window_; + WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1); + forward_conv_col_dim->set_size(3); + forward_conv_col_dim->set_padding_low(2); + forward_conv_col_dim->set_padding_high(1); + forward_conv_col_dim->set_base_dilation(2); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel, + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); + // Verify the convolution's shape is consistent with ShapeInference. 
+ CHECK(ShapeUtil::Compatible( + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); + + auto module = CreateNewVerifiedModule(); + const HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); + const WindowDimension& backward_conv_col_dim = + entry_computation->root_instruction()->operand(0)->window().dimensions(1); + EXPECT_EQ(0, backward_conv_col_dim.padding_low()); + EXPECT_EQ(1, backward_conv_col_dim.padding_high()); +} + +// For simplicity, we focus on the column dimension and ignore other dimensions. +// We use [?] to represent the shape instead of the content. +// +// Suppose operator FC does +// [3] = conv([4], [2], padding_low=1, padding_high=-1) +// +// BC = BackwardInput(FC) does: +// [4] = conv([3], reverse([2]), padding_high=2) +// +// We currently don't fuse BC because CudnnConvPaddingLegalization +// doesn't support negative padding on the gradients of backward convolution +// (b/32744257). +TEST_F(CudnnConvRewriterTest, + BackwardInputConvolveNegativePaddingHighOnActivations) { + auto builder = HloComputation::Builder(TestName()); + // The gradients are in NCHW layout. + HloInstruction* output = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output")); + // The kernel is in HWIO layout. 
+ HloInstruction* kernel = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 2, 1, 1}), "kernel")); + HloInstruction* reverse_kernel = builder.AddInstruction( + HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); + + Window conv_window = default_conv_window_; + WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1); + forward_conv_col_dim->set_size(2); + forward_conv_col_dim->set_padding_high(2); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel, + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); + // Verify the convolution's shape is consistent with ShapeInference. + CHECK(ShapeUtil::Compatible( + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); + + auto module = CreateNewVerifiedModule(); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); +} + +// Check that we will materialize a reversed version of a constant in order to +// pattern-match a backwards input convolution. 
+TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) {
+  Array4D<float> constant_arr(4, 4, 2, 2);
+  constant_arr.FillIota(0);
+  string constant_str =
+      LiteralUtil::CreateR4FromArray4D(constant_arr).ToString();
+
+  const string module_str = absl::StrFormat(R"(
+    HloModule test
+
+    ENTRY entry_computation {
+      param0 = f32[128,2,16,16]{3,2,1,0} parameter(0)
+      constant = f32[4,4,2,2]{3,2,1,0} constant(%s)
+      ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant),
+          window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2},
+          dim_labels=bf01_01oi->bf01, feature_group_count=1
+    })",
+                                            constant_str);
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
+
+  EXPECT_TRUE(RunPass(m.get()));
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      op::GetTupleElement(op::CustomCall(kCudnnConvBackwardInputCallTarget, _,
+                                         op::Reverse(op::Constant())),
+                          0));
+}
+
+}  // anonymous namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3425e1b4942aaf1011ba1bf1c50dd7e79c1f9807
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc
@@ -0,0 +1,430 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace gpu { +namespace { + +using se::DeviceMemory; +using se::DeviceMemoryBase; +using se::Stream; +using se::dnn::AlgorithmConfig; +using se::dnn::BatchDescriptor; +using se::dnn::ConvolutionDescriptor; +using se::dnn::DataLayout; +using se::dnn::DimIndex; +using se::dnn::FilterDescriptor; +using se::dnn::FilterLayout; +using se::dnn::ProfileResult; + +struct CudnnConvParams { + // Here are the fields related to cuDNN's fused convolution. The result thus + // is defined as: + // activation(conv_result_scale * conv(x, w) + + // side_input_scale * side_input + broadcast(bias)) + // + // The most common fused conv is conv forward + relu/identity, for example. + // + // bias_buf is a single-dimensional array, with the length equal to the number + // of output features. It'll be broadcasted to the output shape in order to be + // added to the final results. + // + // side_input_buf, if valid, must have the same shape as the output buffer. 
+ struct FusionParams { + se::dnn::ActivationMode mode; + double side_input_scale; + se::DeviceMemoryBase bias_buf; + se::DeviceMemoryBase side_input_buf; // nullable + }; + + CudnnConvKind kind; + const Shape* input_shape; + const Shape* filter_shape; + const Shape* output_shape; + se::DeviceMemoryBase input_buf; + se::DeviceMemoryBase filter_buf; + se::DeviceMemoryBase output_buf; + const Window* window; + const ConvolutionDimensionNumbers* dnums; + int64 feature_group_count; + se::dnn::AlgorithmConfig algorithm; + double conv_result_scale; + + absl::optional fusion; +}; + +// A StreamExecutor ScratchAllocator that wraps a single XLA allocation, +// returning it (in its entirety) the first time Allocate() is called. +class ScratchBufAllocator : public se::ScratchAllocator { + public: + explicit ScratchBufAllocator(se::DeviceMemoryBase scratch) + : scratch_(scratch) {} + + ~ScratchBufAllocator() override = default; + + int64 GetMemoryLimitInBytes(se::Stream* /*stream*/) override { + return scratch_.size(); + } + + se::port::StatusOr> AllocateBytes( + se::Stream* stream, int64 byte_size) override { + if (allocated_) { + return se::port::InternalError( + "Can't allocate twice from a ScratchBufAllocator."); + } + if (byte_size > scratch_.size()) { + return se::port::InternalError(absl::StrCat( + "Can't allocate ", byte_size, + " bytes from a ScratchBufAllocator of size ", scratch_.size())); + } + + allocated_ = true; + return se::DeviceMemory(scratch_); + } + + private: + se::DeviceMemoryBase scratch_; + bool allocated_ = false; +}; + +template +Status RunCudnnConvImpl(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + CudnnConvKind kind = params.kind; + const Shape& input_shape = *params.input_shape; + const Shape& filter_shape = *params.filter_shape; + const Shape& output_shape = *params.output_shape; + DeviceMemory input_buf(params.input_buf); + DeviceMemory 
filter_buf(params.filter_buf); + DeviceMemory output_buf(params.output_buf); + const Window& window = *params.window; + const ConvolutionDimensionNumbers& dnums = *params.dnums; + int64 feature_group_count = params.feature_group_count; + AlgorithmConfig algorithm = params.algorithm; + + VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm()->algo_id(); + VLOG(3) << "tensor_ops_enabled: " + << algorithm.algorithm()->tensor_ops_enabled(); + VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind); + VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(input_shape); + VLOG(3) << "filter shape: " << ShapeUtil::HumanStringWithLayout(filter_shape); + VLOG(3) << "Output shape: " << ShapeUtil::HumanStringWithLayout(output_shape); + VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; + VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; + + const int num_dimensions = window.dimensions_size(); + CHECK_LE(num_dimensions, 3); + CHECK_GE(num_dimensions, 1); + // cuDNN does not support 1D convolutions. We therefore express 1D + // convolutions as 2D convolutions where the first spatial dimension is 1. + // This matches the behavior of TF (see definition of conv1d in + // tensorflow/python/ops/nn_ops.py). + const int effective_num_dimensions = std::max(2, num_dimensions); + + CHECK_EQ(primitive_util::NativeToPrimitiveType(), + output_shape.element_type()) + << ShapeUtil::HumanString(output_shape); + + // If one dimension is reversed, we need to have all dimensions reversed (so + // we're doing convolution not cross correlation). 
+ const bool dims_reversed = window.dimensions()[0].window_reversal(); + + CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()); + for (const WindowDimension& dim : window.dimensions()) { + CHECK_EQ(dims_reversed, dim.window_reversal()); + CHECK_EQ(dim.padding_low(), dim.padding_high()); + CHECK_EQ(dim.base_dilation(), 1) + << "cudnn does not support base dilation; it " + "must be made explicit with a kPad"; + CHECK_EQ(dim.window_dilation(), 1) + << "XLA does not support window dilation (although cudnn does); it " + "must be made explicit with a kPad"; + } + + // cuDNN's convolution APIs support the BDYX layout for activations/output and + // the OIYX layout for weights. + DataLayout input_dl; + FilterLayout filter_dl; + DataLayout output_dl; + + TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), + XlaConvLayoutsToStreamExecutorLayouts( + dnums, input_shape.layout(), filter_shape.layout(), + output_shape.layout())); + + BatchDescriptor input_descriptor(effective_num_dimensions); + input_descriptor.set_layout(input_dl) + .set_feature_map_count( + input_shape.dimensions(dnums.input_feature_dimension())) + .set_count(input_shape.dimensions(dnums.input_batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + // Note that the dimensions are reversed. The same holds below. 
+ input_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + input_shape.dimensions(dnums.input_spatial_dimensions(dim))); + } + + FilterDescriptor filter_descriptor(effective_num_dimensions); + filter_descriptor.set_layout(filter_dl) + .set_input_feature_map_count( + filter_shape.dimensions(dnums.kernel_input_feature_dimension())) + .set_output_feature_map_count( + filter_shape.dimensions(dnums.kernel_output_feature_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + filter_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim))); + } + + ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); + convolution_descriptor.set_group_count(feature_group_count); + convolution_descriptor.set_convolution_not_crosscorr(dims_reversed); + for (int dim = 0; dim < num_dimensions; ++dim) { + convolution_descriptor + .set_zero_padding( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).padding_low()) + .set_filter_stride( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).stride()); + } + + BatchDescriptor output_descriptor(effective_num_dimensions); + output_descriptor.set_layout(output_dl) + .set_feature_map_count( + output_shape.dimensions(dnums.output_feature_dimension())) + .set_count(output_shape.dimensions(dnums.output_batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + output_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + output_shape.dimensions(dnums.output_spatial_dimensions(dim))); + } + + // Add a singleton dimension in the 1D convolution case. 
+ if (num_dimensions == 1) { + input_descriptor.set_spatial_dim(static_cast(0), 1); + output_descriptor.set_spatial_dim(static_cast(0), 1); + filter_descriptor.set_spatial_dim(static_cast(0), 1); + convolution_descriptor.set_zero_padding(static_cast(0), 0) + .set_filter_stride(static_cast(0), 1); + } + + switch (kind) { + case CudnnConvKind::kForward: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } + stream->ThenConvolveWithAlgorithm( + input_descriptor, input_buf, filter_descriptor, filter_buf, + convolution_descriptor, output_descriptor, &output_buf, + scratch_allocator, algorithm, profile_result); + break; + case CudnnConvKind::kBackwardInput: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } + stream->ThenConvolveBackwardDataWithAlgorithm( + filter_descriptor, filter_buf, output_descriptor, output_buf, + convolution_descriptor, input_descriptor, &input_buf, + scratch_allocator, algorithm, profile_result); + break; + case CudnnConvKind::kBackwardFilter: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } + stream->ThenConvolveBackwardFilterWithAlgorithm( + input_descriptor, input_buf, output_descriptor, output_buf, + convolution_descriptor, filter_descriptor, &filter_buf, + scratch_allocator, algorithm, profile_result); + break; + case CudnnConvKind::kForwardActivation: { + BatchDescriptor bias_desc; + bias_desc.set_count(1) + .set_height(1) + .set_width(1) + .set_feature_map_count( + output_shape.dimensions(dnums.output_feature_dimension())) + .set_layout(output_dl); + + se::DeviceMemory side_input(params.fusion->side_input_buf); + // If there is no side input, use output as the side input. 
+ if (side_input.is_null()) { + if (params.fusion->side_input_scale != 0) { + return InternalError( + "Side input scale is not 0, yet no side input buffer is " + "provided"); + } + // Since side-input scale is 0, the values in the side input don't + // matter. The simplest thing to do would be to pass in a null buffer + // for the side input, but cudnn doesn't allow this. cudnn does promise + // that if side-input-scale is 0 the side input won't be read, so we + // just pass in the output buffer, since it's handy and has the correct + // size. + side_input = output_buf; + } + + stream->ThenFusedConvolveWithAlgorithm( + input_descriptor, input_buf, params.conv_result_scale, + filter_descriptor, filter_buf, convolution_descriptor, side_input, + params.fusion->side_input_scale, bias_desc, + DeviceMemory(params.fusion->bias_buf), params.fusion->mode, + output_descriptor, &output_buf, scratch_allocator, algorithm, + profile_result); + break; + } + } + + if (!stream->ok()) { + return InternalError( + "Unable to launch convolution with type %s and algorithm (%d, %d)", + CudnnConvKindToString(kind), algorithm.algorithm()->algo_id(), + algorithm.algorithm_no_scratch()->algo_id()); + } + return Status::OK(); +} + +// Returns the cudnn convolution parameters generated from conv, which must be a +// custom-call to a cudnn convolution. 
+StatusOr GetCudnnConvParams( + const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer) { + CudnnConvParams params; + + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + conv->backend_config()); + TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(conv)); + const auto& lhs_shape = conv->operand(0)->shape(); + const auto& rhs_shape = conv->operand(1)->shape(); + const auto& conv_result_shape = conv->shape().tuple_shapes(0); + + params.kind = kind; + params.window = &conv->window(); + params.dnums = &conv->convolution_dimension_numbers(); + params.feature_group_count = conv->feature_group_count(); + params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc( + backend_config.algorithm(), backend_config.tensor_ops_enabled())); + params.conv_result_scale = backend_config.conv_result_scale(); + + switch (kind) { + case CudnnConvKind::kForward: + params.input_shape = &lhs_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + break; + case CudnnConvKind::kBackwardInput: + params.input_shape = &conv_result_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &lhs_shape; + params.input_buf = result_buffer; + params.filter_buf = operand_buffers[1]; + params.output_buf = operand_buffers[0]; + break; + case CudnnConvKind::kBackwardFilter: + params.input_shape = &lhs_shape; + params.filter_shape = &conv_result_shape; + params.output_shape = &rhs_shape; + params.input_buf = operand_buffers[0]; + params.filter_buf = result_buffer; + params.output_buf = operand_buffers[1]; + break; + case CudnnConvKind::kForwardActivation: { + params.kind = CudnnConvKind::kForwardActivation; + params.input_shape = &lhs_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.fusion.emplace(); + auto& fusion = 
*params.fusion; + if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) { + return InternalError("Bad activation mode: %s", + backend_config.ShortDebugString()); + } + fusion.mode = static_cast( + backend_config.activation_mode()); + fusion.side_input_scale = backend_config.side_input_scale(); + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + params.fusion->bias_buf = operand_buffers[2]; + if (operand_buffers.size() >= 4) { + params.fusion->side_input_buf = operand_buffers[3]; + } + } + } + return params; +} + +} // anonymous namespace + +Status RunCudnnConv(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + ScratchBufAllocator scratch_allocator(scratch_buf); + return RunCudnnConv(conv, operand_buffers, result_buffer, &scratch_allocator, + stream, profile_result); +} + +Status RunCudnnConv(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::ScratchAllocator* scratch_allocator, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + TF_ASSIGN_OR_RETURN(CudnnConvParams params, + GetCudnnConvParams(conv, operand_buffers, result_buffer)); + + PrimitiveType output_primitive_type = + conv->shape().tuple_shapes(0).element_type(); + switch (output_primitive_type) { + case F16: + return RunCudnnConvImpl(params, scratch_allocator, stream, + profile_result); + case F32: + return RunCudnnConvImpl(params, scratch_allocator, stream, + profile_result); + case F64: + return RunCudnnConvImpl(params, scratch_allocator, stream, + profile_result); + default: + LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); + } +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h 
b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h new file mode 100644 index 0000000000000000000000000000000000000000..edbc75a94a1238540390b93f0fa5217852c7781f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_RUNNER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_RUNNER_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// This file contains low-level routines for running cudnn convolutions. + +// Calls into cudnn to run the specified convolution. +// +// We provide one overload which takes a scratch buffer, and another which takes +// an allocator which is responsible for allocating the scratch space. In +// theory the second one shouldn't be necessary -- users of this function could +// just ask cudnn how much scratch space it needs for a particular convolution. 
+// But in practice, StreamExecutor does not expose such an API, and in the name +// of parsimony, perhaps it's better not to add it. Instead, the first time you +// call a convolution, you should call the version that takes a scratch +// allocator and take note of how much memory is used. The next time you call +// the same conv, you can provide an explicitly preallocated scratch buffer of +// that size, if you like. +Status RunCudnnConv(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); + +Status RunCudnnConv(const HloCustomCallInstruction* conv, + absl::Span operand_buffers, + se::DeviceMemoryBase result_buffer, + se::ScratchAllocator* scratch_allocator, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc deleted file mode 100644 index 7125673887d28729287d67577bcfa06423f85611..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ /dev/null @@ -1,411 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/types/optional.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" -#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" -#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/platform/mutex.h" - -namespace xla { -namespace gpu { -namespace { - -using absl::optional; -using se::DeviceMemoryBase; -using se::dnn::AlgorithmConfig; -using se::dnn::AlgorithmDesc; - -class ScratchAllocator : public se::ScratchAllocator { - public: - ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator) - : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} - - int64 GetMemoryLimitInBytes(se::Stream* stream) override { - return 1LL << 32; // 4GB. TODO(jlebar): Tune this? 
- } - int64 TotalAllocatedBytes() { return total_allocated_bytes_; } - - StatusOr> AllocateBytes(se::Stream* stream, - int64 byte_size) override; - - private: - const int device_ordinal_; - DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; - int64 total_allocated_bytes_ = 0; -}; - -StatusOr> ScratchAllocator::AllocateBytes( - se::Stream* stream, int64 byte_size) { - CHECK_GE(byte_size, 0) << "byte_size must be positive."; - if (byte_size > GetMemoryLimitInBytes(stream)) { - return se::port::Status( - se::port::error::RESOURCE_EXHAUSTED, - absl::StrFormat( - "Allocating %d bytes exceeds the memory limit of %d bytes.", - byte_size, GetMemoryLimitInBytes(stream))); - } - - TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false)); - total_allocated_bytes_ += byte_size; - - se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); - allocated_buffers_.push_back(std::move(allocated_buffer)); - return se::DeviceMemory(buffer_addr); -} - -std::vector GetAlgorithms(CudnnConvKind kind, - se::StreamExecutor* stream_exec) { - std::vector algorithms; - bool succ = false; - switch (kind) { - case CudnnConvKind::kBackwardFilter: - succ = - stream_exec->GetConvolveBackwardFilterAlgorithms(true, &algorithms); - break; - case CudnnConvKind::kBackwardInput: - succ = stream_exec->GetConvolveBackwardDataAlgorithms(true, &algorithms); - break; - case CudnnConvKind::kForward: - case CudnnConvKind::kForwardActivation: - succ = stream_exec->GetConvolveAlgorithms(true, &algorithms); - break; - } - DCHECK(succ); - - return algorithms; -} - -string AlgorithmToString(const AlgorithmDesc& algo) { - if (algo.tensor_ops_enabled()) { - return absl::StrCat(algo.algo_id(), "+TC"); - } - return absl::StrCat(algo.algo_id()); -} - -string NumBytesToString(int64 bytes) { - return absl::StrCat(tensorflow::strings::HumanReadableNumBytes(bytes), " (", - bytes, "B)"); -} - -// 
Acquires a process-global lock on the device pointed to by the given -// StreamExecutor. -// -// This is used to prevent other XLA instances from trying to autotune on this -// device while we're using it. -tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { - static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); - // se::Platform*s are global singletons guaranteed to live forever. - static auto* mutexes = - new std::map, - tensorflow::mutex>(); - - tensorflow::mutex_lock global_lock(mu); - auto it = mutexes - ->emplace(std::piecewise_construct, - std::make_tuple(stream_exec->platform(), - stream_exec->device_ordinal()), - std::make_tuple()) - .first; - return tensorflow::mutex_lock{it->second}; -} - -} // anonymous namespace - -// We could have caching here so that we don't redo this work for two identical -// convolutions. Unfortunately our cache key would have to be a tuple -// containing the protos passed to this function, and we have no utility for -// hashing protos. We could write our own hash functions, but they'd silently -// break if we ever added a field to one of the protos. Perhaps we could hack -// using the binary-encoded proto as the hash key, on the assumption that two -// protos being binary-equal is a sufficient, if not necessary, condition for -// proper equality. But that would still leave us open to having unnecessary -// cache misses and doing extra work. Overall, caching doesn't seem worth the -// trouble, but we may want to revisit this if we ever find a model where -// caching would speed up compilation a lot. -StatusOr> -CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( - HloCustomCallInstruction* instr) { - // TODO(timshen): for now only check fp16. It can be expanded to other types, - // with some work on the HLO routines. - const bool cross_check_enabled = - instr->shape().tuple_shapes(0).element_type() == xla::F16; - - // Don't run this function concurrently on the same GPU. 
- // - // This is a bit of a hack and doesn't protect us against arbitrary concurrent - // use of a GPU, but it's sufficient to let us compile two HLO modules - // concurrently and then run them sequentially. - tensorflow::mutex_lock lock = LockGpu(stream_exec_); - - // Make sure any previous activity on this executor is done. We don't want to - // interfere with programs that are still running on the GPU. - if (!stream_exec_->SynchronizeAllActivity()) { - return InternalError("Failed to synchronize GPU for autotuning."); - } - - // Create a stream for us to do our work on. - se::Stream stream{stream_exec_}; - stream.Init(); - const auto device_ordinal = stream_exec_->device_ordinal(); - - // allocator either points to this->allocator_ or, if that's null, to a - // StreamExecutorMemoryAllocator for stream_exec_. - DeviceMemoryAllocator* allocator; - optional se_allocator; - if (allocator_ != nullptr) { - allocator = allocator_; - } else { - se_allocator.emplace(stream_exec_->platform(), - absl::Span({stream_exec_})); - allocator = &*se_allocator; - } - - const auto initialize_buffer = [&stream, cross_check_enabled]( - DeviceMemoryBase buffer) { - if (cross_check_enabled) { - // Broadcast a constant to the buffer, instead of zeroing the buffer. A - // non-zero constant is useful for the cross checking, because zero-inputs - // may not always reveal the bugs. 
- CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4); - size_t left_over_bytes = buffer.size() % 4; - CHECK_EQ(0, left_over_bytes % 2); - - constexpr float kBroadcastedConstant = 0.1f; - static const Eigen::half halfs[2] = {Eigen::half(kBroadcastedConstant), - Eigen::half(kBroadcastedConstant)}; - uint32 bits; - static_assert(sizeof(bits) == sizeof(halfs), ""); - memcpy(&bits, halfs, sizeof(bits)); - - size_t aligned_size = buffer.size() / 4 * 4; - stream.ThenMemset32(&buffer, bits, aligned_size); - - DeviceMemoryBase left_over( - static_cast(buffer.opaque()) + aligned_size, left_over_bytes); - stream.ThenMemcpy(&left_over, halfs, left_over_bytes); - } else { - // Although we don't have evidence this matters, zero out the buffers - // before autotuning. It's conceivable that using uninitialized memory as - // the inputs might affect performance if e.g. the inputs contain - // denormals, and this is easy enough. - stream.ThenMemZero(&buffer, buffer.size()); - } - }; - - // Allocate space for the input, filter, and output of the convolution. We - // use a ScratchAllocator for this instead of calling allocator_ directly so - // that our allocations don't leak. - ScratchAllocator input_output_allocator(device_ordinal, allocator); - std::vector operand_buffers; - for (const auto* operand : instr->operands()) { - TF_ASSIGN_OR_RETURN(auto buffer, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(operand->shape()))); - initialize_buffer(buffer); - operand_buffers.push_back(buffer); - } - TF_ASSIGN_OR_RETURN( - auto result_buffer, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0)))); - initialize_buffer(result_buffer); - - se::dnn::ProfileResult best_result; - int64 best_result_bytes_used = 0; - TF_ASSIGN_OR_RETURN(auto backend_config, - instr->backend_config()); - - optional comparator; - // Use the first algorithm that's supported as reference. 
There isn't a - // particular reason to use it, as any algorithm sufficies. It doesn't make - // this algorithm considered correct, though. - optional first_algorithm; - TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr)); - for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) { - ScratchAllocator scratch_allocator(device_ordinal, allocator); - se::dnn::ProfileResult profile_result; - VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " - << instr->ToString(); - - backend_config.set_algorithm(alg.algo_id()); - backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled()); - TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config)); - bool launch_ok = RunCudnnConvolution(instr, absl::MakeSpan(operand_buffers), - result_buffer, &scratch_allocator, - &stream, &profile_result) - .ok(); - - if (launch_ok && profile_result.is_valid()) { - const bool crash_on_checking_failure = - instr->GetModule() - ->config() - .debug_options() - .xla_gpu_crash_on_verification_failures(); - if (comparator.has_value()) { - StatusOr result = comparator->CompareEqual( - se::DeviceMemory(result_buffer)); - if (!result.ok()) { - LOG(ERROR) << "Unable to compare " - << AlgorithmToString(*first_algorithm) << " against " - << AlgorithmToString(alg) << " for " << instr->ToString() - << ": " << result.status(); - CHECK(!crash_on_checking_failure); - } else if (!result.ValueOrDie()) { - LOG(ERROR) << "Results mismatch between different convolution " - "algorithms. This is likely a bug in convolution, or " - "an excessive loss of precision in convolution. 
" - << instr->ToString() << " for " - << AlgorithmToString(*first_algorithm) << " vs " - << AlgorithmToString(alg); - CHECK(!crash_on_checking_failure); - } - } else if (cross_check_enabled) { - auto comp = F16BufferComparator::Create( - se::DeviceMemory(result_buffer), compiler_, allocator, - &stream); - if (comp.ok()) { - comparator.emplace(comp.ConsumeValueOrDie()); - first_algorithm.emplace(alg); - } else { - LOG(ERROR) << "Fail to initialize buffer comparator: " - << comp.status() << ", instruction: " << instr->ToString(); - CHECK(!crash_on_checking_failure); - } - } - int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); - VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) - << " succeeded, taking " << profile_result.elapsed_time_in_ms() - << "ms and using " << NumBytesToString(scratch_bytes_used) - << " of scratch (Best result: " - << best_result.elapsed_time_in_ms() << "ms, " - << NumBytesToString(best_result_bytes_used) << " of scratch)"; - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - best_result_bytes_used = scratch_bytes_used; - } - } else { - VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) << " failed."; - } - } - if (best_result.is_valid()) { - VLOG(2) << "Best algorithm for " << instr->ToString() << ": " - << AlgorithmToString(best_result.algorithm()) << ", takes " - << best_result.elapsed_time_in_ms() << "ms, and uses " - << best_result_bytes_used << "B of scratch memory."; - return std::make_tuple(best_result.algorithm().algo_id(), - best_result.algorithm().tensor_ops_enabled(), - best_result_bytes_used); - } - - return InternalError( - "All algorithms tried for convolution %s failed. 
Falling back to " - "default algorithm.", - instr->ToString()); -} - -StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( - HloInstruction* instr) { - CHECK(IsCustomCallToDnnConvolution(*instr)); - - StatusOr> alg_scratch_and_tc = - PickBestAlgorithm(Cast(instr)); - - if (!alg_scratch_and_tc.ok()) { - LOG(ERROR) << alg_scratch_and_tc.status(); - return false; - } - - int64 algorithm; - bool tensor_ops_enabled; - int64 scratch_bytes; - - std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = - alg_scratch_and_tc.ConsumeValueOrDie(); - - VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " - << NumBytesToString(scratch_bytes) - << " of scratch memory: " << instr->ToString() - << " tensor_ops_enabled: " << tensor_ops_enabled; - - // Replace instr with a new CustomCall which has the correct algorithm, and - // whose output shape has the appropriate amount of scratch memory. - HloComputation* computation = instr->parent(); - Shape new_call_shape = - ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), - ShapeUtil::MakeShape(U8, {scratch_bytes})}); - - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - instr->backend_config()); - backend_config.set_algorithm(algorithm); - backend_config.set_tensor_ops_enabled(tensor_ops_enabled); - - HloInstruction* new_call = computation->AddInstruction( - instr->CloneWithNewOperands(new_call_shape, instr->operands())); - - TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); - - // Repackage new_call so it has the same shape as the original call, namely - // (conv_result, u8[0]). 
- HloInstruction* new_tuple = - computation->AddInstruction(HloInstruction::CreateTuple( - {computation->AddInstruction(HloInstruction::CreateGetTupleElement( - new_call_shape.tuple_shapes(0), new_call, 0)), - computation->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({})))})); - - TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple)); - return true; -} - -StatusOr CudnnConvolutionAlgorithmPicker::RunOnComputation( - HloComputation* computation) { - std::vector convs; - for (auto* instr : computation->instructions()) { - if (IsCustomCallToDnnConvolution(*instr)) { - convs.push_back(instr); - } - } - - bool changed = false; - for (auto* instr : convs) { - TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr)); - changed |= result; - } - return changed; -} - -StatusOr CudnnConvolutionAlgorithmPicker::Run(HloModule* module) { - bool changed = false; - for (HloComputation* computation : module->MakeNonfusionComputations()) { - TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); - changed |= result; - } - return changed; -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h deleted file mode 100644 index aeda2fc7f8b4d6169fc2baa8975119ba7bf68dd2..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ - -#include "absl/types/optional.h" -#include "tensorflow/compiler/xla/service/compiler.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" -#include "tensorflow/compiler/xla/service/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" - -namespace xla { -namespace gpu { - -// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for -// each and adding explicit scratch space to the CustomCalls. -class CudnnConvolutionAlgorithmPicker : public HloModulePass { - public: - // If the `allocator` parameter is not null, we will use it to allocate temp - // memory while timing the various convolution algorithms. If it's null, - // we'll use the default allocator on the StreamExecutor. 
- CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* allocator, - Compiler* compiler) - : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {} - - absl::string_view name() const override { - return "cudnn-convolution-algorithm-picker"; - } - - StatusOr Run(HloModule* module) override; - - private: - StatusOr RunOnComputation(HloComputation* computation); - StatusOr RunOnInstruction(HloInstruction* instr); - StatusOr> PickBestAlgorithm( - HloCustomCallInstruction* instr); - - se::StreamExecutor* stream_exec_; // never null - DeviceMemoryAllocator* allocator_; // may be null - Compiler* compiler_; -}; - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc deleted file mode 100644 index ef292373018295f5100b91c343df274b626c2fa1..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ /dev/null @@ -1,565 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" - -#include -#include -#include - -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" -#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/window_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { -namespace gpu { - -namespace { - -HloInstruction* CreateCudnnConv(const char* call_target, const Shape& shape, - HloInstruction* lhs, HloInstruction* rhs, - const Window& window, - const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count) { - HloComputation* computation = lhs->parent(); - - // This call returns a tuple of (conv_result, scratch_memory), where - // conv_result is the actual result of the convolution, and scratch_memory is - // temporary memory used by cudnn. - // - // At the moment, we don't know how much scratch memory this conv is going to - // use, so we put u8[0] in this place. Later on another pass will choose - // which conv algorithm to use, and at that point we'll modify the shape of - // this second tuple element. 
- Shape call_shape = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); - - HloInstruction* custom_call = computation->AddInstruction( - HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); - custom_call->set_window(window); - custom_call->set_convolution_dimension_numbers(dnums); - custom_call->set_feature_group_count(feature_group_count); - return custom_call; -} - -bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { - const ConvolutionDimensionNumbers& dnums = - conv->convolution_dimension_numbers(); - if (dnums.input_spatial_dimensions_size() > 3) { - return false; - } - - // CuDNN does not accept zero-element arguments - if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) || - ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) { - return false; - } - - if (window_util::HasWindowReversal(conv->window())) { - return false; - } - return true; -} - -// Try to match a backward filter pattern that contains "conv". -// Precondition: "conv" is a kConvolution. -std::tuple MatchBackwardFilter( - HloInstruction* conv) { - const auto no_match_result = - std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); - if (conv->feature_group_count() > 1) { - return no_match_result; - } - // Step 1: match the instruction pattern without considering the paddings and - // dimension numbers just yet. We may need some generic pattern matcher - // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h - // - // Backward filter convolution is implemented in XLA as the forward - // convolution of padded activations and dilated gradients. Padding on - // activations and dilation on gradients are specified in the "window" field - // of the forward convolution. - // - // activations gradients - // \ / - // v v - // Convolution - // conv - CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); - - // Step 2: match paddings and dimension numbers of the forward convolution. 
- const ConvolutionDimensionNumbers& conv_dnums = - conv->convolution_dimension_numbers(); - auto input_batch_dim = conv_dnums.input_batch_dimension(); - auto input_feature_dim = conv_dnums.input_feature_dimension(); - auto input_spatial_dims = conv_dnums.input_spatial_dimensions(); - auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension(); - auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension(); - auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions(); - auto output_batch_dim = conv_dnums.output_batch_dimension(); - auto output_feature_dim = conv_dnums.output_feature_dimension(); - auto output_spatial_dims = conv_dnums.output_spatial_dimensions(); - - for (const WindowDimension& window_dim : conv->window().dimensions()) { - if (window_dim.stride() != 1) { - VLOG(1) << "Forward convolution's window " - << conv->window().ShortDebugString() - << " should have stride of 1."; - return no_match_result; - } - if (window_dim.base_dilation() != 1) { - VLOG(1) << "Forward convolution's window " - << conv->window().ShortDebugString() - << " should have no base (LHS) dilation."; - return no_match_result; - } - if (window_dim.padding_low() < 0) { - VLOG(1) << "Padding low should be non-negative."; - return no_match_result; - } - if (window_dim.window_reversal()) { - VLOG(1) << "Window reversal field not supported"; - return no_match_result; - } - // Padding high will be checked in Step 3. - } - if (input_batch_dim == output_batch_dim && - !window_util::HasWindowDilation(conv->window())) { - VLOG(1) << conv->ToString() - << " is a regular forward convolution. No need " - "to fold it to a backward filter convolution."; - return no_match_result; - } - - // Step 3: fuse the matched HLOs into a backward convolution instruction. - // - // Compute the window of the backward convolution. 
- Window backward_conv_window; - for (int i = 0; i < input_spatial_dims.size(); ++i) { - WindowDimension* dim = backward_conv_window.add_dimensions(); - // The window size of the backward convolution equals the output size of the - // forward convolution. - int64 filter_size = conv->shape().dimensions(output_spatial_dims[i]); - dim->set_size(filter_size); - // The window stride equals the window dilation of the forward convolution. - dim->set_stride(conv->window().dimensions(i).window_dilation()); - // The window's low padding is the same as the low padding of the - // activations. - dim->set_padding_low(conv->window().dimensions(i).padding_low()); - - int64 input_size = - conv->operand(0)->shape().dimensions(input_spatial_dims[i]); - int64 output_size = conv->window().dimensions(i).size(); - // Compute the range of the amount of valid high padding. We first compute - // min_padding_high, the amount of padding on the right/bottom to ensure the - // last patch ends at the border, i.e., - // - // input_size + dim->padding_low() + min_padding_high - // = (output_size - 1) * stride + filter_size - // - // Because convolution ignores trailing incomplete windows, any amount of - // padding high from min_padding_high to min_padding_high+stride-1 - // (max_padding_high) has the same effect. - int64 padded_input_size = filter_size + (output_size - 1) * dim->stride(); - int64 min_padding_high = - padded_input_size - input_size - dim->padding_low(); - int64 max_padding_high = min_padding_high + dim->stride() - 1; - CHECK_GE(dim->padding_low(), 0); - // In practice, since cuDNN convolution only supports even padding, we make - // the amount of high padding the same as the amount of low padding as long - // as it is between min_padding_high and max_padding_high. If it is not in - // that range, we pick the one that's closest to dim->padding_low() and let - // PadInsertion canonicalize the resultant backward convolution later. 
- // Picking the closest one minimizes the cost of the kPad instruction to be - // inserted by PadInsertion. - if (dim->padding_low() >= min_padding_high && - dim->padding_low() <= max_padding_high) { - dim->set_padding_high(dim->padding_low()); - } else { - if (dim->padding_low() < min_padding_high) { - dim->set_padding_high(min_padding_high); - } else { - dim->set_padding_high(max_padding_high); - } - } - if (dim->padding_high() < 0) { - LOG(ERROR) - << "Fusing this pattern to backward filter convolution would cause " - "negative padding (" - << dim->padding_high() - << ") on right/bottom of the weight gradients, which is not " - "supported by PadInsertion (b/32744257). Falling back to " - "unfused convolution for instruction: " - << conv->ToString(); - return no_match_result; - } - } - - // Restore the dimension numbers of the backward convolution from the forward - // convolution. The two activation dimensions are reversed (batch and - // feature). - ConvolutionDimensionNumbers backward_conv_dnums; - backward_conv_dnums.set_input_batch_dimension(input_feature_dim); - backward_conv_dnums.set_input_feature_dimension(input_batch_dim); - for (int i = 0; i < input_spatial_dims.size(); ++i) { - backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]); - } - backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim); - backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim); - for (int i = 0; i < kernel_spatial_dims.size(); ++i) { - backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]); - } - // The dimension numbering of the output of the forward convolution (before - // transposition) is the same as that of the activations (according to the - // semantics of kConvolution). The batch dimension of the activations should - // be treated as the input feature dimension, and the feature dimension should - // be treated as the output feature. 
- backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim); - backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim); - for (int i = 0; i < output_spatial_dims.size(); ++i) { - backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]); - } - - return std::make_tuple(true, backward_conv_window, backward_conv_dnums); -} - -// Try to match a backward input pattern that contains "conv". -// Precondition: "conv" is a kConvolution. -std::tuple -MatchBackwardInput(HloInstruction* conv) { - const auto no_match_result = - std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); - - // TODO(b/31709653): Theoretically cuDNN supports grouped convolutions also - // for the backward input convolution, but at least for now with version 7.1.4 - // it is slower. This needs to be re-evaluated for future cuDNN versions. - // Note that we already have the necessary code down below, the only thing to - // enable it is to remove the following early return. - if (conv->feature_group_count() > 1) { - return no_match_result; - } - - // Match instruction pattern. - CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); - HloInstruction* reverse_filter = conv->mutable_operand(1); - ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers(); - - // We pattern-match to a backwards input conv if: - // - // - all spatial dims of the filter are reversed - // - // OR - // - // - filter is 1x1 or a constant AND - // - conv has base dilation (otherwise this is just a regular forward conv). - // - // The final criterion above is just for canonicalization; cudnn seems to run - // just as fast if we canonicalize 1x1/constant filters without base dilation - // to forward or backward convs. 
We canonicalize to forward conv because (a) - // it's more natural (constant filters usually show up when doing inference, - // and having backwards convolutions in inference graphs would be weird), and - // (b) cudnn has special fusions for forward conv plus bias and activation, - // and we want to pattern-match to that after running this pass. - bool is_reversed_filter = - reverse_filter->opcode() == HloOpcode::kReverse && - absl::c_is_permutation(dnums.kernel_spatial_dimensions(), - reverse_filter->dimensions()); - bool is_1x1_filter = - absl::c_all_of(conv->window().dimensions(), - [](const WindowDimension& d) { return d.size() == 1; }); - if (!is_reversed_filter && - !(window_util::HasBaseDilation(conv->window()) && - (reverse_filter->IsConstant() || is_1x1_filter))) { - VLOG(1) << "Can't match to backwards convolution. Either filter is not " - "kReverse, or it's not a base-dilated conv with a 1x1 or " - "constant filter."; - return no_match_result; - } - - // Match padding and dilation of the forward convolution. 
- for (const WindowDimension& window_dim : conv->window().dimensions()) { - if (window_dim.stride() != 1) { - VLOG(1) << "Forward convolution's window " - << conv->window().ShortDebugString() - << " should have stride of 1."; - return no_match_result; - } - if (window_dim.window_dilation() != 1) { - VLOG(1) << "Forward convolution's window " - << conv->window().ShortDebugString() - << " should have no window dilation."; - return no_match_result; - } - if (window_dim.window_reversal()) { - VLOG(1) << "Window reversal field not supported"; - return no_match_result; - } - } - - const auto& input_spatial_dims = dnums.input_spatial_dimensions(); - const auto& output_spatial_dims = dnums.output_spatial_dimensions(); - CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size()); - CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size()); - - const Window& old_window = conv->window(); - Window new_window = old_window; - for (size_t i = 0; i < input_spatial_dims.size(); ++i) { - // Restore backward convolution's padding config from the matched pattern. - // See the comment in tensorflow/core/kernels/conv_grad_tuple_ops.cc - // for how we convert backward input convolution to a variant of forward - // convolution. - // - // The stride of the backward convolution - // = the base dilation factor of the forward convolution - auto dim = new_window.mutable_dimensions(i); - dim->set_stride(old_window.dimensions(i).base_dilation()); - - // The low padding = kernel_size - 1 - low padding on the gradients - // Make sure the low padding is not negative. 
- auto kernel_size = old_window.dimensions(i).size(); - auto backward_padding_low = - kernel_size - 1 - old_window.dimensions(i).padding_low(); - if (backward_padding_low < 0) { - LOG(ERROR) - << "The low padding of the backward convolution would be negative (" - << backward_padding_low - << "), which isn't supported by PadInsertion for now (b/32744257)."; - return no_match_result; - } - dim->set_padding_low(backward_padding_low); - - // Compute the range of the amount of padding on the right/bottom of the - // activations. XLA's convolution requires all patches to be within the - // padded base. This gives us flexiblity to choose the amount of high - // padding from a set of values without changing the result of the backward - // convolution. The minimum amount (min_padding_high) makes the last patch - // end at the border. The maximum amount (max_padding_high) equals - // min_padding_high+stride-1 -- max_padding_high+1 would cause the output - // size to change. - auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]); - auto output_size = - conv->operand(0)->shape().dimensions(input_spatial_dims[i]); - auto padded_input_size = kernel_size + dim->stride() * (output_size - 1); - auto total_pad_size = padded_input_size - unpadded_input_size; - auto min_padding_high = total_pad_size - backward_padding_low; - auto max_padding_high = min_padding_high + dim->stride() - 1; - - if (backward_padding_low >= min_padding_high && - backward_padding_low <= max_padding_high) { - // In the best case (most likely), if backward_padding_low is in the range - // of the amounts of valid high padding, we choose backward_padding_low - // because cuDNN supports even padding only. - dim->set_padding_high(backward_padding_low); - } else { - // Otherwise, we choose the amount that's closest to backward_padding_low, - // and PadInsertion will later insert kSlice instructions to enforce even - // padding. 
- // - // For example, consider the backward convolution pattern - // - // ab xy - // | pad | reverse - // .a.b yx - // \ / - // ABC - // - // The amount of low padding on activations (in backward convolution) is - // backward_padding_low = kernel_size - 1 - forward_padding_low - // = 2 - 1 - 1 = 0 - // - // The amount of padding high must be between 1 and 2, in order to make - // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in - // the range of [1,2], so we pick the closest valid amount of padding - // high, which is 1 in this case. Therefore, we fuse the above pattern to - // - // ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1) - if (backward_padding_low < min_padding_high) { - dim->set_padding_high(min_padding_high); - } else { - dim->set_padding_high(max_padding_high); - } - } - // PadInsertion doesn't handle backward input convolution with negative - // padding for now. So fall back to unfused convolution in case of negative - // padding. For example, - // ABCD = Conv(abc, reverse(xy), padding_high=2) - // could be fused to - // ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1) - // with positive padding low but negative padding high. - if (dim->padding_high() < 0) { - LOG(ERROR) << "Fusing this pattern to backward convolution would cause " - "negative padding (" - << dim->padding_high() - << ") on right/bottom of the activations, which is not " - "supported by PadInsertion (b/32744257). Falling back to " - "unfused convolution for instruction: " - << conv->ToString(); - return no_match_result; - } - } - - // OK, it's a match! Switch the input feature dimension with the output - // feature dimension. This is the way cuDNN expects it to be. 
- dnums.set_kernel_input_feature_dimension( - conv->convolution_dimension_numbers().kernel_output_feature_dimension()); - dnums.set_kernel_output_feature_dimension( - conv->convolution_dimension_numbers().kernel_input_feature_dimension()); - - // If we matched against a constant, we need to add a reverse op that can be - // subsumed by the cuDNN call. algebraic-simplifier will later remove any - // unnecessary reverses. - if (reverse_filter->opcode() != HloOpcode::kReverse && - reverse_filter->IsConstant()) { - // Create a double-reverse, which is a nop. - HloComputation* c = conv->parent(); - reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( - reverse_filter->shape(), reverse_filter, - AsInt64Slice(dnums.kernel_spatial_dimensions()))); - reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( - reverse_filter->shape(), reverse_filter, - AsInt64Slice(dnums.kernel_spatial_dimensions()))); - TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter)); - } - - // Calculate the 'rhs' that goes into the backward input convolution. - HloInstruction* rhs = reverse_filter; - // One reverse is subsumed by the cuDNN call. - if (rhs->opcode() == HloOpcode::kReverse) { - rhs = rhs->mutable_operand(0); - } - if (conv->feature_group_count() == 1) { - return std::make_tuple(true, new_window, dnums, rhs); - } - - // Handle grouped convolutions. Because we swapped the input feature dimension - // with the output feature dimension, we need to also reshape the kernel so - // that the 'feature_group_count' parameter still makes sense. The - // 'feature_group_count' parameter essentially specifies how often the - // 'kernel_input_feature_dimension' is repeated. So when we swap these - // dimensions, we need to divide the new 'kernel_input_feature_dimension' by - // 'feature_group_count' and multiply the new - // 'kernel_output_feature_dimension' by 'feature_group_count'. 
- Shape new_shape = rhs->shape(); - int64 input_feature_dimension = dnums.kernel_input_feature_dimension(); - int64 output_feature_dimension = dnums.kernel_output_feature_dimension(); - - // In the backward convolution case, the spatial dimensions become the - // feature dimensions, and we are guaranteed that the spatial dimensions are - // adjacent. - CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL); - int64 input_features = new_shape.dimensions(input_feature_dimension); - int64 output_features = new_shape.dimensions(output_feature_dimension); - new_shape.set_dimensions(input_feature_dimension, - input_features / conv->feature_group_count()); - new_shape.set_dimensions(output_feature_dimension, - output_features * conv->feature_group_count()); - HloComputation* c = conv->parent(); - rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs)); - return std::make_tuple(true, new_window, dnums, rhs); -} - -CudnnConvBackendConfig GetDefaultBackendConfig() { - CudnnConvBackendConfig config; - config.set_conv_result_scale(1); - return config; -} - -// Tries to rewrite a single convolution into a call to cudnn. -StatusOr RunOnInstruction(HloInstruction* conv) { - CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); - - HloInstruction* custom_call = [&]() -> HloInstruction* { - bool match; - Window window; - ConvolutionDimensionNumbers dnums; - HloInstruction* rhs; - - std::tie(match, window, dnums) = MatchBackwardFilter(conv); - if (match) { - return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), - conv->mutable_operand(0), conv->mutable_operand(1), - window, dnums, conv->feature_group_count()); - } - - std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); - if (match) { - return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(), - conv->mutable_operand(0), rhs, window, dnums, - conv->feature_group_count()); - } - - // If all else fails, try a forward convolution. 
- if (CanImplementAsCudnnForwardConv(conv)) { - return CreateCudnnConv( - kCudnnConvForwardCallTarget, conv->shape(), conv->mutable_operand(0), - conv->mutable_operand(1), conv->window(), - conv->convolution_dimension_numbers(), conv->feature_group_count()); - } - - return nullptr; - }(); - - if (custom_call == nullptr) { - return false; - } - - TF_RETURN_IF_ERROR( - custom_call->set_backend_config(GetDefaultBackendConfig())); - - // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out - // the conv result and replace `conv` with it. - TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( - conv, - HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0))); - return true; -} - -// Rewrites the convolutions in the given computation into calls to cudnn. -// Returns true if it made any changes. -StatusOr RunOnComputation(HloComputation* computation) { - std::vector convs; - for (auto* hlo : computation->instructions()) { - if (hlo->opcode() == HloOpcode::kConvolution) { - convs.push_back(hlo); - } - } - - bool changed = false; - for (HloInstruction* conv : convs) { - TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv)); - changed |= result; - } - return changed; -} -} // namespace - -StatusOr CudnnConvolutionRewriter::Run(HloModule* module) { - bool changed = false; - for (HloComputation* computation : module->MakeNonfusionComputations()) { - TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); - changed |= result; - } - return changed; -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h deleted file mode 100644 index 8d7c6fdab510407428a115579a90e8cf85e9fad2..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ - -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" - -namespace xla { -namespace gpu { - -// Rewrites plain convolutions, backwards-filter convolutions, and -// backwards-input convolutions into CustomCall HLOs that call into cuDNN. -class CudnnConvolutionRewriter : public HloModulePass { - public: - absl::string_view name() const override { - return "cudnn-convolution-rewriter"; - } - - StatusOr Run(HloModule* module) override; -}; - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc deleted file mode 100644 index d237f8930b74d460ad3d4602670a5afb19b496a2..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc +++ /dev/null @@ -1,615 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" - -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace gpu { -namespace { - -namespace op = xla::testing::opcode_matchers; -using ::testing::_; - -class CudnnConvolutionRewriterTest : public HloVerifiedTestBase { - public: - CudnnConvolutionRewriterTest() - : HloVerifiedTestBase(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false) { - for (int i = 0; i < 2; ++i) { - WindowDimension* window_dim = default_conv_window_.add_dimensions(); - window_dim->set_size(1); - window_dim->set_stride(1); - window_dim->set_padding_low(0); - window_dim->set_padding_high(0); - window_dim->set_window_dilation(1); - window_dim->set_base_dilation(1); - } - // TF data shapes are by default in the NHWC order, and filter shape is by - // default in HWIO order. 
For backward filter convolution, we need to swap - // the batch and feature dimension in the activations, and treat the batch - // dimension in gradients as the input feature dimension in the filter. - // - // TODO(jingyue): Add more tests on NCHW input order, which TF also - // supports. - tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); - tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); - tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1); - tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2); - tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0); - tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension( - 3); - tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1); - tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2); - tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0); - tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1); - tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2); - tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3); - - tf_default_dnums_for_backward_input_.set_input_batch_dimension(0); - tf_default_dnums_for_backward_input_.set_output_batch_dimension(0); - tf_default_dnums_for_backward_input_.set_input_feature_dimension(3); - tf_default_dnums_for_backward_input_.set_output_feature_dimension(3); - tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1); - tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1); - tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2); - tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2); - tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3); - tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2); - tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0); - 
tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(1); - } - - protected: - bool RunPass(HloModule* module) { - return CudnnConvolutionRewriter().Run(module).ValueOrDie(); - } - - // A convolution window with stride 1 and zero padding. The size fields are - // not set. - Window default_conv_window_; - ConvolutionDimensionNumbers tf_default_dnums_for_backward_filter_; - ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; -}; - -TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { - HloComputation::Builder builder(TestName()); - HloInstruction* activations = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations")); - HloInstruction* gradients = - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 1, 2, 1}), "gradients")); - Window conv_window = default_conv_window_; - conv_window.mutable_dimensions(1)->set_size(2); - conv_window.mutable_dimensions(1)->set_window_dilation(2); - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape( - activations->shape(), gradients->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_filter_) - .ConsumeValueOrDie(), - activations, gradients, /*feature_group_count=*/1, conv_window, - tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - - auto module = CreateNewModule(); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); - EXPECT_THAT(entry_computation->root_instruction(), - op::GetTupleElement( - op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); -} - -TEST_F(CudnnConvolutionRewriterTest, - BackwardFilterConvolveEquivalentToForwardConvolution) { - HloComputation::Builder builder(TestName()); - HloInstruction* activations = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations")); 
- HloInstruction* gradients = - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "gradients")); - Window conv_window = default_conv_window_; - conv_window.mutable_dimensions(1)->set_size(3); - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape( - activations->shape(), gradients->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_filter_) - .ConsumeValueOrDie(), - activations, gradients, /*feature_group_count=*/1, conv_window, - tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - - auto module = CreateNewModule(); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); - EXPECT_THAT(entry_computation->root_instruction(), - op::GetTupleElement( - op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); -} - -// Extracted from block35 training. -TEST_F(CudnnConvolutionRewriterTest, - BackwardFilterConvolveWithPaddedActivations) { - auto builder = HloComputation::Builder(TestName()); - HloInstruction* activations = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations")); - HloInstruction* gradients = - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients")); - - Window conv_window = default_conv_window_; - for (int i = 0; i < 2; ++i) { - conv_window.mutable_dimensions(i)->set_size(35); - conv_window.mutable_dimensions(i)->set_padding_low(1); - conv_window.mutable_dimensions(i)->set_padding_high(1); - } - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - /*feature_group_count=*/1, conv_window, - tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - - auto module = CreateNewModule(); - HloComputation* entry_computation = - 
module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); - EXPECT_THAT(entry_computation->root_instruction(), - op::GetTupleElement( - op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); -} - -// Extracted from inception v3 training. -TEST_F(CudnnConvolutionRewriterTest, - BackwardFilterConvolveWithPaddedGradients) { - auto builder = HloComputation::Builder(TestName()); - HloInstruction* activations = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), "activations")); - HloInstruction* gradients = - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "gradients")); - - Window conv_window = default_conv_window_; - for (int i = 0; i < 2; ++i) { - conv_window.mutable_dimensions(i)->set_size(4); - conv_window.mutable_dimensions(i)->set_padding_high(-1); - conv_window.mutable_dimensions(i)->set_window_dilation(2); - } - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - /*feature_group_count=*/1, conv_window, - tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - - auto module = CreateNewModule(); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); - EXPECT_THAT(entry_computation->root_instruction(), - op::GetTupleElement( - op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); -} - -TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { - auto builder = HloComputation::Builder(TestName()); - HloInstruction* activations = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations")); - HloInstruction* gradients = - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients")); - - Window conv_window = default_conv_window_; - for (int 
i = 0; i < 2; ++i) { - conv_window.mutable_dimensions(i)->set_size(35); - // Uneven padding: padding_low=0, padding_high=1 - conv_window.mutable_dimensions(i)->set_padding_high(1); - } - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, - /*feature_group_count=*/1, conv_window, - tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); - - auto module = CreateNewModule(); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); - EXPECT_THAT(entry_computation->root_instruction(), - op::GetTupleElement( - op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); -} - -TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { - auto builder = HloComputation::Builder(TestName()); - HloInstruction* output = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {4, 5, 16, 16}), "output")); - HloInstruction* kernel = - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {5, 3, 7, 7}), "kernel")); - HloInstruction* reverse_kernel = builder.AddInstruction( - HloInstruction::CreateReverse(kernel->shape(), kernel, {2, 3})); - - Window conv_window = default_conv_window_; - for (int i = 0; i < 2; ++i) { - conv_window.mutable_dimensions(i)->set_size(7); - conv_window.mutable_dimensions(i)->set_padding_low(3); - conv_window.mutable_dimensions(i)->set_padding_high(3); - } - ConvolutionDimensionNumbers conv_dnums; - conv_dnums.set_input_batch_dimension(0); - conv_dnums.set_output_batch_dimension(0); - conv_dnums.set_input_feature_dimension(1); - conv_dnums.set_output_feature_dimension(1); - conv_dnums.add_input_spatial_dimensions(2); - conv_dnums.add_output_spatial_dimensions(2); - conv_dnums.add_input_spatial_dimensions(3); - conv_dnums.add_output_spatial_dimensions(3); - conv_dnums.set_kernel_input_feature_dimension(0); - 
conv_dnums.set_kernel_output_feature_dimension(1); - conv_dnums.add_kernel_spatial_dimensions(2); - conv_dnums.add_kernel_spatial_dimensions(3); - - HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output, - /*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window, - conv_dnums, DefaultPrecisionConfig(2))); - // Verify the convolution's shape is consistent with ShapeInference. - CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), - /*feature_group_count=*/1, conv_window, conv_dnums) - .ValueOrDie())); - - auto module = CreateNewModule(); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); - - ASSERT_THAT(entry_computation->root_instruction(), - op::GetTupleElement( - op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); - const HloInstruction* custom_call = - entry_computation->root_instruction()->operand(0); - for (int i = 0; i < 2; ++i) { - const WindowDimension& window_dim = custom_call->window().dimensions(i); - // Low padding of the backward input convolution - // = kernel_size - 1 - low padding on gradients. - EXPECT_EQ(3, window_dim.padding_low()); - EXPECT_EQ(3, window_dim.padding_high()); - EXPECT_EQ(1, window_dim.stride()); - } -} - -// Convolve([abc], [x], base_dilation=2) -// = Convolve([abc], Reverse([x]), base_dilation=2) -// = BackwardInputConvolve([abc], [x], stride=2) -TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { - auto builder = HloComputation::Builder(TestName()); - // NHWC dimension order. - HloInstruction* output = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output")); - // HWOI dimension order. 
- HloInstruction* kernel = - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel")); - - Window conv_window = default_conv_window_; - conv_window.mutable_dimensions(1)->set_base_dilation(2); - - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - /*feature_group_count=*/1, conv_window, - tf_default_dnums_for_backward_input_) - .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window, - tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); - - auto module = CreateNewModule(); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); - EXPECT_THAT(entry_computation->root_instruction(), - op::GetTupleElement( - op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); -} - -// BackwardInputConvolve([abc], [x], stride=1) is equivalent to -// ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input -// convolution. -TEST_F(CudnnConvolutionRewriterTest, - BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) { - auto builder = HloComputation::Builder(TestName()); - // NHWC dimension order. - HloInstruction* output = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output")); - // HWOI dimension order. 
- HloInstruction* kernel = - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel")); - - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape( - output->shape(), kernel->shape(), /*feature_group_count=*/1, - default_conv_window_, tf_default_dnums_for_backward_input_) - .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, - default_conv_window_, tf_default_dnums_for_backward_input_, - DefaultPrecisionConfig(2))); - - auto module = CreateNewModule(); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); - EXPECT_THAT( - entry_computation->root_instruction(), - op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); -} - -// Extracted from Inception V3 training. -// -// filter(HWIO) -// 3x3x192x320 -// | -// v -// gradients(NHWC) reverse -// 20x4x4x320 3x3x192x320 -// \ / -// \ / -// conv (NHWC) with padding (low=2,high=3,interior=1) -// 20x10x10x192 -// -// Gradients are padded unevenly. -TEST_F(CudnnConvolutionRewriterTest, - BackwardInputConvolveUnevenPaddingOnGradients) { - auto builder = HloComputation::Builder(TestName()); - HloInstruction* output = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output")); - HloInstruction* kernel = - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel")); - HloInstruction* reverse_kernel = builder.AddInstruction( - HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); - - Window conv_window = default_conv_window_; - for (int i = 0; i < 2; ++i) { - conv_window.mutable_dimensions(i)->set_size(3); - conv_window.mutable_dimensions(i)->set_padding_low(2); - conv_window.mutable_dimensions(i)->set_padding_high(3); - // Interior padding = 1. 
- conv_window.mutable_dimensions(i)->set_base_dilation(2); - } - HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, - tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); - // Verify the convolution's shape is consistent with ShapeInference. - CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); - - auto module = CreateNewModule(); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); - ASSERT_THAT(entry_computation->root_instruction(), - op::GetTupleElement( - op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); - const HloInstruction* custom_call = - entry_computation->root_instruction()->operand(0); - for (int i = 0; i < 2; ++i) { - const WindowDimension& window_dim = custom_call->window().dimensions(i); - EXPECT_EQ(0, window_dim.padding_low()); - EXPECT_EQ(0, window_dim.padding_high()); - EXPECT_EQ(2, window_dim.stride()); - } -} - -// Similar to BackwardInputConvolveUnevenPadding, but the low padding of the -// gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused. 
-TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { - auto builder = HloComputation::Builder(TestName()); - HloInstruction* output = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output")); - HloInstruction* kernel = - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel")); - HloInstruction* reverse_kernel = builder.AddInstruction( - HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); - - Window conv_window = default_conv_window_; - for (int i = 0; i < 2; ++i) { - conv_window.mutable_dimensions(i)->set_size(3); - conv_window.mutable_dimensions(i)->set_padding_low(3); - conv_window.mutable_dimensions(i)->set_padding_high(2); - conv_window.mutable_dimensions(i)->set_base_dilation(2); - } - HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, - tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); - // Verify the convolution's shape is consistent with ShapeInference. - CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); - - auto module = CreateNewModule(); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); - EXPECT_THAT( - entry_computation->root_instruction(), - op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); -} - -// Extracted from //learning/brain/google/xla/benchmarks/resnet.py -// -// For simplicity, we focus on the column dimension and ignore other dimensions. -// We use [?] to represent the shape instead of the content. 
-// -// Suppose operator FC does -// [4] = conv([14], [3], stride=2, padding_high=1) // Padding::kSame -// -// BC = BackwardInput(FC) does: -// [14] = conv([7], reverse([3]), -// padding_low=2, padding_high=1, base_dilation=2) -// -// We should fuse BC even though padding on activations is uneven, because -// PadInsertion will canonicalize the fusion HLO. -TEST_F(CudnnConvolutionRewriterTest, - BackwardInputConvolveUnevenPaddingOnActivations) { - auto builder = HloComputation::Builder(TestName()); - // The gradients are in NCHW layout. - HloInstruction* output = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 1, 7, 1}), "output")); - // The kernel is in HWIO layout. - HloInstruction* kernel = - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 3, 1, 1}), "kernel")); - HloInstruction* reverse_kernel = builder.AddInstruction( - HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); - - Window conv_window = default_conv_window_; - WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1); - forward_conv_col_dim->set_size(3); - forward_conv_col_dim->set_padding_low(2); - forward_conv_col_dim->set_padding_high(1); - forward_conv_col_dim->set_base_dilation(2); - HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, - tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); - // Verify the convolution's shape is consistent with ShapeInference. 
- CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); - - auto module = CreateNewModule(); - const HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); - ASSERT_THAT(entry_computation->root_instruction(), - op::GetTupleElement( - op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); - const WindowDimension& backward_conv_col_dim = - entry_computation->root_instruction()->operand(0)->window().dimensions(1); - EXPECT_EQ(0, backward_conv_col_dim.padding_low()); - EXPECT_EQ(1, backward_conv_col_dim.padding_high()); -} - -// For simplicity, we focus on the column dimension and ignore other dimensions. -// We use [?] to represent the shape instead of the content. -// -// Suppose operator FC does -// [3] = conv([4], [2], padding_low=1, padding_high=-1) -// -// BC = BackwardInput(FC) does: -// [4] = conv([3], reverse([2]), padding_high=2) -// -// We currently don't fuse BC because PadInsertion doesn't support negative -// padding on the gradients of backward convolution (b/32744257). -TEST_F(CudnnConvolutionRewriterTest, - BackwardInputConvolveNegativePaddingHighOnActivations) { - auto builder = HloComputation::Builder(TestName()); - // The gradients are in NCHW layout. - HloInstruction* output = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output")); - // The kernel is in HWIO layout. 
- HloInstruction* kernel = - builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 2, 1, 1}), "kernel")); - HloInstruction* reverse_kernel = builder.AddInstruction( - HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1})); - - Window conv_window = default_conv_window_; - WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1); - forward_conv_col_dim->set_size(2); - forward_conv_col_dim->set_padding_high(2); - HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel, - /*feature_group_count=*/1, conv_window, - tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); - // Verify the convolution's shape is consistent with ShapeInference. - CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, - conv_window, tf_default_dnums_for_backward_input_) - .ValueOrDie())); - - auto module = CreateNewModule(); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module)); - EXPECT_THAT( - entry_computation->root_instruction(), - op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); -} - -// Check that we will materialize a reversed version of a constant in order to -// pattern-match a backwards input convolution. 
-TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) { - Array4D constant_arr(4, 4, 2, 2); - constant_arr.FillIota(0); - string constant_str = - LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); - ParseAndVerifyModule(absl::StrFormat(R"( - HloModule test - - ENTRY entry_computation { - param0 = f32[128,2,16,16]{3,2,1,0} parameter(0) - constant = f32[4,4,2,2]{3,2,1,0} constant(%s) - ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant), - window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2}, - dim_labels=bf01_01oi->bf01, feature_group_count=1 - })", - constant_str)); - EXPECT_TRUE(RunPass(&module())); - EXPECT_THAT( - module().entry_computation()->root_instruction(), - op::GetTupleElement(op::CustomCall(kCudnnConvBackwardInputCallTarget, _, - op::Reverse(op::Constant())), - 0)); -} - -} // anonymous namespace -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc deleted file mode 100644 index 89dd1bb272663ac1f6eecbaae070d201d38e44c8..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ /dev/null @@ -1,419 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/util.h" - -namespace xla { -namespace gpu { -namespace { - -using se::DeviceMemory; -using se::DeviceMemoryBase; -using se::Stream; -using se::dnn::AlgorithmConfig; -using se::dnn::BatchDescriptor; -using se::dnn::ConvolutionDescriptor; -using se::dnn::DataLayout; -using se::dnn::DimIndex; -using se::dnn::FilterDescriptor; -using se::dnn::FilterLayout; -using se::dnn::ProfileResult; - -struct CudnnConvParams { - // Here are the fields related to cuDNN's fused convolution. The result thus - // is defined as: - // activation(conv_result_scale * conv(x, w) + - // side_input_scale * side_input + broadcast(bias)) - // - // The most common fused conv is conv forward + relu/identity, for example. - // - // bias_buf is a single-dimensional array, with the length equal to the number - // of output features. It'll be broadcasted to the output shape in order to be - // added to the final results. - // - // side_input_buf, if valid, must have the same shape as the output buffer. 
- struct FusionParams { - se::dnn::ActivationMode mode; - double side_input_scale; - se::DeviceMemoryBase bias_buf; - se::DeviceMemoryBase side_input_buf; // nullable - }; - - CudnnConvKind kind; - const Shape* input_shape; - const Shape* filter_shape; - const Shape* output_shape; - se::DeviceMemoryBase input_buf; - se::DeviceMemoryBase filter_buf; - se::DeviceMemoryBase output_buf; - const Window* window; - const ConvolutionDimensionNumbers* dnums; - int64 feature_group_count; - se::dnn::AlgorithmConfig algorithm; - double conv_result_scale; - - absl::optional fusion; -}; - -// A StreamExecutor ScratchAllocator that wraps a single XLA allocation, -// returning it (in its entirety) the first time Allocate() is called. -class ScratchBufAllocator : public se::ScratchAllocator { - public: - explicit ScratchBufAllocator(se::DeviceMemoryBase scratch) - : scratch_(scratch) {} - - ~ScratchBufAllocator() override = default; - - int64 GetMemoryLimitInBytes(se::Stream* /*stream*/) override { - return scratch_.size(); - } - - se::port::StatusOr> AllocateBytes( - se::Stream* stream, int64 byte_size) override { - if (allocated_) { - return se::port::InternalError( - "Can't allocate twice from a ScratchBufAllocator."); - } - if (byte_size > scratch_.size()) { - return se::port::InternalError(absl::StrCat( - "Can't allocate ", byte_size, - " bytes from a ScratchBufAllocator of size ", scratch_.size())); - } - - allocated_ = true; - return se::DeviceMemory(scratch_); - } - - private: - se::DeviceMemoryBase scratch_; - bool allocated_ = false; -}; - -template -Status RunCudnnConvolutionImpl(CudnnConvParams params, - se::ScratchAllocator* scratch_allocator, - se::Stream* stream, - se::dnn::ProfileResult* profile_result) { - CudnnConvKind kind = params.kind; - const Shape& input_shape = *params.input_shape; - const Shape& filter_shape = *params.filter_shape; - const Shape& output_shape = *params.output_shape; - DeviceMemory input_buf(params.input_buf); - DeviceMemory 
filter_buf(params.filter_buf); - DeviceMemory output_buf(params.output_buf); - const Window& window = *params.window; - const ConvolutionDimensionNumbers& dnums = *params.dnums; - int64 feature_group_count = params.feature_group_count; - AlgorithmConfig algorithm = params.algorithm; - - VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); - VLOG(3) << "tensor_ops_enabled: " - << algorithm.algorithm().tensor_ops_enabled(); - VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind); - VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(input_shape); - VLOG(3) << "filter shape: " << ShapeUtil::HumanStringWithLayout(filter_shape); - VLOG(3) << "Output shape: " << ShapeUtil::HumanStringWithLayout(output_shape); - VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; - VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; - - const int num_dimensions = window.dimensions_size(); - CHECK_LE(num_dimensions, 3); - // cuDNN does not support 1D convolutions. We therefore express 1D - // convolutions as 2D convolutions where the first spatial dimension is 1. - // This matches the behavior of TF (see definition of conv1d in - // tensorflow/python/ops/nn_ops.py). - const int effective_num_dimensions = std::max(2, num_dimensions); - - CHECK_EQ(primitive_util::NativeToPrimitiveType(), - output_shape.element_type()) - << ShapeUtil::HumanString(output_shape); - - CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); - CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); - CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()); - for (const WindowDimension& dim : window.dimensions()) { - CHECK_EQ(dim.padding_low(), dim.padding_high()); - } - - // cuDNN's convolution APIs support the BDYX layout for activations/output and - // the OIYX layout for weights. 
- DataLayout input_dl; - FilterLayout filter_dl; - DataLayout output_dl; - - TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), - XlaConvLayoutsToStreamExecutorLayouts( - dnums, input_shape.layout(), filter_shape.layout(), - output_shape.layout())); - - BatchDescriptor input_descriptor(effective_num_dimensions); - input_descriptor.set_layout(input_dl) - .set_feature_map_count( - input_shape.dimensions(dnums.input_feature_dimension())) - .set_count(input_shape.dimensions(dnums.input_batch_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - // Note that the dimensions are reversed. The same holds below. - input_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - input_shape.dimensions(dnums.input_spatial_dimensions(dim))); - } - - FilterDescriptor filter_descriptor(effective_num_dimensions); - filter_descriptor.set_layout(filter_dl) - .set_input_feature_map_count( - filter_shape.dimensions(dnums.kernel_input_feature_dimension())) - .set_output_feature_map_count( - filter_shape.dimensions(dnums.kernel_output_feature_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - filter_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim))); - } - - ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); - convolution_descriptor.set_group_count(feature_group_count); - for (int dim = 0; dim < num_dimensions; ++dim) { - convolution_descriptor - .set_zero_padding( - static_cast(effective_num_dimensions - dim - 1), - window.dimensions(dim).padding_low()) - .set_filter_stride( - static_cast(effective_num_dimensions - dim - 1), - window.dimensions(dim).stride()); - } - - BatchDescriptor output_descriptor(effective_num_dimensions); - output_descriptor.set_layout(output_dl) - .set_feature_map_count( - output_shape.dimensions(dnums.output_feature_dimension())) - 
.set_count(output_shape.dimensions(dnums.output_batch_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - output_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - output_shape.dimensions(dnums.output_spatial_dimensions(dim))); - } - - // Add a singleton dimension in the 1D convolution case. - if (num_dimensions == 1) { - input_descriptor.set_spatial_dim(static_cast(0), 1); - output_descriptor.set_spatial_dim(static_cast(0), 1); - filter_descriptor.set_spatial_dim(static_cast(0), 1); - convolution_descriptor.set_zero_padding(static_cast(0), 0) - .set_filter_stride(static_cast(0), 1); - } - - switch (kind) { - case CudnnConvKind::kForward: - if (params.conv_result_scale != 1) { - return InternalError( - "StreamExecutor doesn't support scaled convolution: %lf.", - params.conv_result_scale); - } - stream->ThenConvolveWithAlgorithm( - input_descriptor, input_buf, filter_descriptor, filter_buf, - convolution_descriptor, output_descriptor, &output_buf, - scratch_allocator, algorithm, profile_result); - break; - case CudnnConvKind::kBackwardInput: - if (params.conv_result_scale != 1) { - return InternalError( - "StreamExecutor doesn't support scaled convolution: %lf.", - params.conv_result_scale); - } - stream->ThenConvolveBackwardDataWithAlgorithm( - filter_descriptor, filter_buf, output_descriptor, output_buf, - convolution_descriptor, input_descriptor, &input_buf, - scratch_allocator, algorithm, profile_result); - break; - case CudnnConvKind::kBackwardFilter: - if (params.conv_result_scale != 1) { - return InternalError( - "StreamExecutor doesn't support scaled convolution: %lf.", - params.conv_result_scale); - } - stream->ThenConvolveBackwardFilterWithAlgorithm( - input_descriptor, input_buf, output_descriptor, output_buf, - convolution_descriptor, filter_descriptor, &filter_buf, - scratch_allocator, algorithm, profile_result); - break; - case CudnnConvKind::kForwardActivation: { - BatchDescriptor bias_desc; - 
bias_desc.set_count(1) - .set_height(1) - .set_width(1) - .set_feature_map_count( - output_shape.dimensions(dnums.output_feature_dimension())) - .set_layout(output_dl); - - se::DeviceMemory side_input(params.fusion->side_input_buf); - // If there is no side input, use output as the side input. - if (side_input.is_null()) { - if (params.fusion->side_input_scale != 0) { - return InternalError( - "Side input scale is not 0, yet no side input buffer is " - "provided"); - } - // Since side-input scale is 0, the values in the side input don't - // matter. The simplest thing to do would be to pass in a null buffer - // for the side input, but cudnn doesn't allow this. cudnn does promise - // that if side-input-scale is 0 the side input won't be read, so we - // just pass in the output buffer, since it's handy and has the correct - // size. - side_input = output_buf; - } - - stream->ThenFusedConvolveWithAlgorithm( - input_descriptor, input_buf, params.conv_result_scale, - filter_descriptor, filter_buf, convolution_descriptor, side_input, - params.fusion->side_input_scale, bias_desc, - DeviceMemory(params.fusion->bias_buf), params.fusion->mode, - output_descriptor, &output_buf, scratch_allocator, algorithm, - profile_result); - break; - } - } - - if (!stream->ok()) { - return InternalError( - "Unable to launch convolution with type %s and algorithm (%d, %d)", - CudnnConvKindToString(kind), algorithm.algorithm().algo_id(), - algorithm.algorithm_no_scratch().algo_id()); - } - return Status::OK(); -} - -// Returns the cudnn convolution parameters generated from conv, which must be a -// custom-call to a cudnn convolution. 
-StatusOr GetCudnnConvParams( - const HloCustomCallInstruction* conv, - absl::Span operand_buffers, - se::DeviceMemoryBase result_buffer) { - CudnnConvParams params; - - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - conv->backend_config()); - const auto& target = conv->custom_call_target(); - const auto& lhs_shape = conv->operand(0)->shape(); - const auto& rhs_shape = conv->operand(1)->shape(); - const auto& conv_result_shape = conv->shape().tuple_shapes(0); - - params.window = &conv->window(); - params.dnums = &conv->convolution_dimension_numbers(); - params.feature_group_count = conv->feature_group_count(); - params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc( - backend_config.algorithm(), backend_config.tensor_ops_enabled())); - params.conv_result_scale = backend_config.conv_result_scale(); - - if (target == kCudnnConvForwardCallTarget) { - params.kind = CudnnConvKind::kForward; - params.input_shape = &lhs_shape; - params.filter_shape = &rhs_shape; - params.output_shape = &conv_result_shape; - params.input_buf = operand_buffers[0]; - params.filter_buf = operand_buffers[1]; - params.output_buf = result_buffer; - } else if (target == kCudnnConvBackwardInputCallTarget) { - params.kind = CudnnConvKind::kBackwardInput; - params.input_shape = &conv_result_shape; - params.filter_shape = &rhs_shape; - params.output_shape = &lhs_shape; - params.input_buf = result_buffer; - params.filter_buf = operand_buffers[1]; - params.output_buf = operand_buffers[0]; - } else if (target == kCudnnConvBackwardFilterCallTarget) { - params.kind = CudnnConvKind::kBackwardFilter; - params.input_shape = &lhs_shape; - params.filter_shape = &conv_result_shape; - params.output_shape = &rhs_shape; - params.input_buf = operand_buffers[0]; - params.filter_buf = result_buffer; - params.output_buf = operand_buffers[1]; - } else if (target == kCudnnConvBiasActivationForwardCallTarget) { - params.kind = CudnnConvKind::kForwardActivation; - params.input_shape = 
&lhs_shape; - params.filter_shape = &rhs_shape; - params.output_shape = &conv_result_shape; - params.fusion.emplace(); - auto& fusion = *params.fusion; - if (backend_config.activation_mode() < - static_cast(se::dnn::ActivationMode::kNumActivationModes)) { - fusion.mode = static_cast( - backend_config.activation_mode()); - } else { - return InternalError("Bad activation mode: %s", - backend_config.ShortDebugString()); - } - fusion.side_input_scale = backend_config.side_input_scale(); - params.input_buf = operand_buffers[0]; - params.filter_buf = operand_buffers[1]; - params.output_buf = result_buffer; - params.fusion->bias_buf = operand_buffers[2]; - if (operand_buffers.size() >= 4) { - params.fusion->side_input_buf = operand_buffers[3]; - } - } else { - return InternalError("Unexpected custom call target: %s", target); - } - return params; -} - -} // anonymous namespace - -Status RunCudnnConvolution(const HloCustomCallInstruction* conv, - absl::Span operand_buffers, - se::DeviceMemoryBase result_buffer, - se::DeviceMemoryBase scratch_buf, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { - ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution(conv, operand_buffers, result_buffer, - &scratch_allocator, stream, profile_result); -} - -Status RunCudnnConvolution(const HloCustomCallInstruction* conv, - absl::Span operand_buffers, - se::DeviceMemoryBase result_buffer, - se::ScratchAllocator* scratch_allocator, - se::Stream* stream, - se::dnn::ProfileResult* profile_result) { - TF_ASSIGN_OR_RETURN(CudnnConvParams params, - GetCudnnConvParams(conv, operand_buffers, result_buffer)); - - PrimitiveType output_primitive_type = - conv->shape().tuple_shapes(0).element_type(); - switch (output_primitive_type) { - case F16: - return RunCudnnConvolutionImpl(params, scratch_allocator, - stream, profile_result); - case F32: - return RunCudnnConvolutionImpl(params, scratch_allocator, stream, - profile_result); - case F64: - return 
RunCudnnConvolutionImpl(params, scratch_allocator, stream, - profile_result); - default: - LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); - } -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h deleted file mode 100644 index 61aec1ceccec0f253f9ddaa688d64cacea800cf3..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ - -#include "absl/types/optional.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_instructions.h" -#include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" - -namespace xla { -namespace gpu { - -// This file contains low-level routines for running cudnn convolutions. - -// Calls into cudnn to run the specified convolution. 
-// -// We provide one overload which takes a scratch buffer, and another which takes -// an allocator which is responsible for allocating the scratch space. In -// theory the second one shouldn't be necessary -- users of this function could -// just ask cudnn how much scratch space it needs for a particular convolution. -// But in practice, StreamExecutor does not expose such an API, and in the name -// of parsimony, perhaps it's better not to add it. Instead, the first time you -// call a convolution, you should call the version that takes a scratch -// allocator and take note of how much memory is used. The next time you call -// the same conv, you can provide an explicitly preallocated scratch buffer of -// that size, if you like. -Status RunCudnnConvolution(const HloCustomCallInstruction* conv, - absl::Span operand_buffers, - se::DeviceMemoryBase result_buffer, - se::DeviceMemoryBase scratch_buf, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); - -Status RunCudnnConvolution(const HloCustomCallInstruction* conv, - absl::Span operand_buffers, - se::DeviceMemoryBase result_buffer, - se::ScratchAllocator* scratch_allocator, - se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..cde65ad5745a3c102d029907e0690dc8c34620fd --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -0,0 +1,280 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { +namespace { + +// Describes a matched pattern: +// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); +// Where side_input has the shape of output buffer, and bias is a 1D array with +// the dimension of number of output features. +struct ConvWithRelu { + HloInstruction* maximum; + HloCustomCallInstruction* conv; + HloInstruction* bias; + HloInstruction* side_input; + HloConstantInstruction* alpha_conv; + HloConstantInstruction* alpha_side_input; +}; + +absl::optional FindConvWithRelu(HloInstruction* instr) { + using match::Add; + using match::AddAnyOrder; + using match::AnyOf; + using match::Broadcast; + using match::Constant; + using match::GetTupleElement; + using match::Maximum; + using match::MultiplyAnyOrder; + using match::Op; + + // The pattern we want to match: + // max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); + // + // With its variants involving commute/reassociation of adds, multiplies, and + // max, and omission of alpha1, side_input, alpha2, or bias. 
+ + HloInstruction* relu_input; + + // Match max(0, relu_input). + auto zero_pattern = Broadcast(match::ConstantScalar(0)); + if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) && + !Match(instr, Maximum(Op(&relu_input), zero_pattern))) { + return absl::nullopt; + } + HloInstruction* conv_instr = nullptr; + HloInstruction* alpha_conv_instr = nullptr; + HloInstruction* alpha_side_input_instr = nullptr; + HloInstruction* bias_broadcast_instr = nullptr; + HloInstruction* bias = nullptr; + HloInstruction* side_input = nullptr; + + // These nodes will not be in the returned value, but we need to check them + // for single use. + HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr, + *mul1 = nullptr, *mul2 = nullptr; + + const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias)); + const auto conv_pattern = [&] { + auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr)); + auto conv_pattern = GetTupleElement( + >e, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0); + return AnyOf( + MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern); + }(); + const auto side_input_pattern = [&] { + auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr)); + // If bias is already matched, match arbitrary additional input as side + // input. Note this may force a cheap operation (e.g. broadcast) to be + // materialized into a large buffer, as large as the output buffer. + // + // TODO(timshen): If in practice there are significant false positives, we + // should fix it. + auto side_input_pattern = Op(&side_input); + return AnyOf( + MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern), + side_input_pattern); + }(); + + { + // Try to match any of the following form of add, in any association: + // addends[0] + // addends[0] + addends[1] + // addends[0] + addends[1] + addends[2] + // + // Then try to match each addend with one of the three patterns: bias, conv, + // or side_input. 
Notice that side_input matching must go last, as it + // also matches a conv or a bias. + HloInstruction* addends[3] = {nullptr, nullptr, nullptr}; + auto add3_pattern = [&] { + auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1])); + return AnyOf( + AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern, + Op(&addends[0])); + }(); + CHECK(Match(relu_input, add3_pattern)); + for (auto addend : addends) { + if (addend) { + if (bias == nullptr && Match(addend, bias_pattern)) { + CHECK(bias); + } else if (conv_instr == nullptr && Match(addend, conv_pattern)) { + CHECK(conv_instr); + } else if (side_input == nullptr && Match(addend, side_input_pattern)) { + CHECK(side_input); + } else { + return absl::nullopt; + } + } + } + } + + if (conv_instr == nullptr) { + return absl::nullopt; + } + + for (HloInstruction* instr : + {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) { + if (instr && instr->user_count() > 1) { + return absl::nullopt; + } + } + + auto conv = Cast(conv_instr); + auto bias_broadcast = + CastOrNull(bias_broadcast_instr); + + if (conv->custom_call_target() != kCudnnConvForwardCallTarget) { + return absl::nullopt; + } + + if (bias_broadcast) { + // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}. 
+ if (bias_broadcast_instr->dimensions().size() != 1) { + return absl::nullopt; + } + if (bias_broadcast_instr->dimensions(0) != + conv->convolution_dimension_numbers().output_feature_dimension()) { + return absl::nullopt; + } + } + + return ConvWithRelu{ + instr, + conv, + bias, + side_input, + CastOrNull(alpha_conv_instr), + CastOrNull(alpha_side_input_instr)}; +} + +StatusOr> TryRewriteToCudnnForwardRelu( + ConvWithRelu match) { + auto conv = match.conv; + + HloComputation* computation = conv->parent(); + PrimitiveType element_type = conv->operand(0)->shape().element_type(); + + const auto get_alpha_value = + [](HloConstantInstruction* instr) -> StatusOr { + TF_ASSIGN_OR_RETURN( + auto alpha, + Cast(instr)->literal().Convert(F64)); + return alpha.GetFirstElement(); + }; + + double alpha_conv = 1; + if (match.alpha_conv) { + TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv)); + } + + double alpha_side_input; + if (match.side_input) { + if (match.alpha_side_input) { + TF_ASSIGN_OR_RETURN(alpha_side_input, + get_alpha_value(match.alpha_side_input)); + } else { + alpha_side_input = 1; + } + } else { + CHECK(match.alpha_side_input == nullptr); + alpha_side_input = 0; + } + + auto bias = match.bias; + if (!bias) { + auto zero = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); + + int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions( + conv->convolution_dimension_numbers().output_feature_dimension()); + bias = computation->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShapeWithDescendingLayout(element_type, + {num_output_feature}), + zero, {})); + } + + CHECK(bias); + std::vector args = {conv->mutable_operand(0), + conv->mutable_operand(1), bias}; + if (match.side_input) { + args.push_back(match.side_input); + } + auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall( + conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget)); + 
new_conv->set_window(conv->window()); + new_conv->set_convolution_dimension_numbers( + conv->convolution_dimension_numbers()); + new_conv->set_metadata(conv->metadata()); + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, + conv->backend_config()); + config.set_activation_mode( + static_cast(se::dnn::ActivationMode::kRelu)); + config.set_conv_result_scale(alpha_conv); + config.set_side_input_scale(alpha_side_input); + TF_RETURN_IF_ERROR(new_conv->set_backend_config(config)); + + VLOG(1) << "Replacing convolution " << conv->ToString() << " with " + << new_conv->ToString(); + return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0), + new_conv, 0); +} + +} // namespace + +StatusOr CudnnFusedConvRewriter::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + std::vector matches; + int num_forward_convs = 0; + for (auto instr : computation->instructions()) { + auto match = FindConvWithRelu(instr); + if (match.has_value()) { + matches.push_back(*match); + } + if (auto call = DynCast(instr)) { + if (call->custom_call_target() == kCudnnConvForwardCallTarget) { + num_forward_convs++; + } + } + } + VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size() + << " out of " << num_forward_convs << " forward convs."; + std::vector>> + replacements; + for (const ConvWithRelu& match : matches) { + TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match)); + replacements.push_back({match.maximum, std::move(new_instr)}); + changed = true; + } + for (auto& replacement : replacements) { + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + replacement.first, std::move(replacement.second))); + } + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h new file mode 100644 index 
0000000000000000000000000000000000000000..613ed8dbdc33dfc3684deb5fd3ee8f5b9ea5fc50 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ + +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +class CudnnFusedConvRewriter : public HloModulePass { + public: + absl::string_view name() const override { + return "cudnn-fused-convolution-rewriter"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b7dd07a50c637d514439bb7a8ec799e4cabfee55 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -0,0 +1,310 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; + +class CudnnFusedConvRewriterTest : public HloTestBase { + protected: + string GetOptimizedHlo(absl::string_view hlo_string) { + return backend() + .compiler() + ->RunHloPasses(ParseHloString(hlo_string, GetModuleConfigForTest()) + .ConsumeValueOrDie(), + backend().default_stream_executor(), + backend().memory_allocator()) + .ConsumeValueOrDie() + ->ToString(); + } + + void TestMatchWithAllTypes(absl::string_view hlo_string) { + for (absl::string_view type : {"f16", "f32", "f64"}) { + const string hlo_with_new_type = + absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); + string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type); + EXPECT_THAT(optimized_hlo_string, + Not(HasSubstr(kCudnnConvForwardCallTarget))); + EXPECT_THAT(optimized_hlo_string, + HasSubstr(kCudnnConvBiasActivationForwardCallTarget)); + EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01})) + << optimized_hlo_string; + } + } + + void TestNotMatchWithAllTypes(absl::string_view hlo_string) { + for 
(absl::string_view type : {"f16", "f32", "f64"}) { + const string hlo_with_new_type = + absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); + string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type); + EXPECT_THAT(optimized_hlo_string, HasSubstr(kCudnnConvForwardCallTarget)); + EXPECT_THAT(optimized_hlo_string, + Not(HasSubstr(kCudnnConvBiasActivationForwardCallTarget))); + } + } +}; + +TEST_F(CudnnFusedConvRewriterTest, TestConvOnly) { + // max(0, conv(x, w)); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={} + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestBias) { + // max(0, conv(x, w) + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + bias = TYPE[64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestSideInputOnly) { + // max(0, conv(x, w) + side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 
pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + add1 = TYPE[1,3,3,64] add(conv, side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestBiasAndSideInput) { + // max(0, conv(x, w) + side_input + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + bias = TYPE[64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) + add2 = TYPE[1,3,3,64] add(add1, side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestScaledConv) { + // max(0, 0.999994934 * conv(x, w)); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={} + scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv) + ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndSideInput) { + // max(0, conv(x, w) + 0.899994934 * side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = 
TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + add1 = TYPE[1,3,3,64] add(conv, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInput) { + // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv) + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInputWithBias) { + // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), 
dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + bias = TYPE[64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv) + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias) + add2 = TYPE[1,3,3,64] add(add1, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestMatchMaxZeroOnly) { + // max(0.1, conv(x, w)) shouldn't match. + TestNotMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + point_one = TYPE[] constant(0.1) + point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={} + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, TestMatchBroadcastedBiasOnly) { + // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match. 
+ TestNotMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input1 = TYPE[1,3,3,64] parameter(2) + side_input2 = TYPE[1,3,3,64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + add1 = TYPE[1,3,3,64] add(conv, side_input2) + add2 = TYPE[1,3,3,64] add(add1, side_input1) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) { + const char* kHloString = R"( + HloModule Test + + ENTRY Test { + zero = f32[] constant(0) + zeros = f32[1,32,9,9] broadcast(zero), dimensions={} + + input = f32[1,17,9,9] parameter(0) + filter = f32[3,3,17,32] parameter(1) + + conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1, metadata={op_type="foo"} + ROOT relu = f32[1,32,9,9] maximum(zeros, conv) + })"; + + const string optimized_hlo_string = + backend() + .compiler() + ->RunHloPasses(ParseHloString(kHloString, GetModuleConfigForTest()) + .ConsumeValueOrDie(), + backend().default_stream_executor(), + backend().memory_allocator()) + .ConsumeValueOrDie() + ->ToString(); + EXPECT_THAT( + optimized_hlo_string, + ::testing::ContainsRegex(R"(custom-call.*metadata=\{op_type="foo"\})")); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc deleted file mode 100644 index 3761c19cfcab10e0c6faa17c2d1d535d706ff6c5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc +++ /dev/null @@ -1,278 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h" - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" -#include "tensorflow/compiler/xla/service/pattern_matcher.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" - -namespace xla { -namespace gpu { -namespace { - -// Describes a matched pattern: -// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); -// Where side_input has the shape of output buffer, and bias is a 1D array with -// the dimension of number of output features. 
-struct ConvWithRelu { - HloInstruction* maximum; - HloCustomCallInstruction* conv; - HloInstruction* bias; - HloInstruction* side_input; - HloConstantInstruction* alpha_conv; - HloConstantInstruction* alpha_side_input; -}; - -absl::optional FindConvWithRelu(HloInstruction* instr) { - using match::Add; - using match::AddAnyOrder; - using match::AnyOf; - using match::Broadcast; - using match::Constant; - using match::GetTupleElement; - using match::Maximum; - using match::MultiplyAnyOrder; - using match::Op; - - // The pattern we want to match: - // max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); - // - // With its variants involving commute/reassociation of adds, multiplies, and - // max, and omission of alpha1, side_input, alpha2, or bias. - - HloInstruction* relu_input; - - // Match max(0, relu_input). - auto zero_pattern = Broadcast(match::ConstantScalar(0)); - if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) && - !Match(instr, Maximum(Op(&relu_input), zero_pattern))) { - return absl::nullopt; - } - HloInstruction* conv_instr = nullptr; - HloInstruction* alpha_conv_instr = nullptr; - HloInstruction* alpha_side_input_instr = nullptr; - HloInstruction* bias_broadcast_instr = nullptr; - HloInstruction* bias = nullptr; - HloInstruction* side_input = nullptr; - - // These nodes will not be in the returned value, but we need to check them - // for single use. 
- HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr, - *mul1 = nullptr, *mul2 = nullptr; - - const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias)); - const auto conv_pattern = [&] { - auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr)); - auto conv_pattern = GetTupleElement( - >e, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0); - return AnyOf( - MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern); - }(); - const auto side_input_pattern = [&] { - auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr)); - // If bias is already matched, match arbitrary additional input as side - // input. Note this may force a cheap operation (e.g. broadcast) to be - // materialized into a large buffer, as large as the output buffer. - // - // TODO(timshen): If in practice there are significant false positives, we - // should fix it. - auto side_input_pattern = Op(&side_input); - return AnyOf( - MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern), - side_input_pattern); - }(); - - { - // Try to match any of the following form of add, in any association: - // addends[0] - // addends[0] + addends[1] - // addends[0] + addends[1] + addends[2] - // - // Then try to match each addend with one of the three patterns: bias, conv, - // or side_input. Notice that side_input matching must go last, as it - // also matches a conv or a bias. 
- HloInstruction* addends[3] = {nullptr, nullptr, nullptr}; - auto add3_pattern = [&] { - auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1])); - return AnyOf( - AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern, - Op(&addends[0])); - }(); - CHECK(Match(relu_input, add3_pattern)); - for (auto addend : addends) { - if (addend) { - if (bias == nullptr && Match(addend, bias_pattern)) { - CHECK(bias); - } else if (conv_instr == nullptr && Match(addend, conv_pattern)) { - CHECK(conv_instr); - } else if (side_input == nullptr && Match(addend, side_input_pattern)) { - CHECK(side_input); - } else { - return absl::nullopt; - } - } - } - } - - if (conv_instr == nullptr) { - return absl::nullopt; - } - - for (HloInstruction* instr : - {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) { - if (instr && instr->user_count() > 1) { - return absl::nullopt; - } - } - - auto conv = Cast(conv_instr); - auto bias_broadcast = - CastOrNull(bias_broadcast_instr); - - if (conv->custom_call_target() != kCudnnConvForwardCallTarget) { - return absl::nullopt; - } - - if (bias_broadcast) { - // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}. 
- if (bias_broadcast_instr->dimensions().size() != 1) { - return absl::nullopt; - } - if (bias_broadcast_instr->dimensions(0) != - conv->convolution_dimension_numbers().output_feature_dimension()) { - return absl::nullopt; - } - } - - return ConvWithRelu{ - instr, - conv, - bias, - side_input, - CastOrNull(alpha_conv_instr), - CastOrNull(alpha_side_input_instr)}; -} - -StatusOr> TryRewriteToCudnnForwardRelu( - ConvWithRelu match) { - auto conv = match.conv; - - HloComputation* computation = conv->parent(); - PrimitiveType element_type = conv->operand(0)->shape().element_type(); - - const auto get_alpha_value = - [](HloConstantInstruction* instr) -> StatusOr { - TF_ASSIGN_OR_RETURN( - auto alpha, - Cast(instr)->literal().Convert(F64)); - return alpha.GetFirstElement(); - }; - - double alpha_conv = 1; - if (match.alpha_conv) { - TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv)); - } - - double alpha_side_input; - if (match.side_input) { - if (match.alpha_side_input) { - TF_ASSIGN_OR_RETURN(alpha_side_input, - get_alpha_value(match.alpha_side_input)); - } else { - alpha_side_input = 1; - } - } else { - CHECK(match.alpha_side_input == nullptr); - alpha_side_input = 0; - } - - auto bias = match.bias; - if (!bias) { - auto zero = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); - - int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions( - conv->convolution_dimension_numbers().output_feature_dimension()); - bias = computation->AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShapeWithDescendingLayout(element_type, - {num_output_feature}), - zero, {})); - } - - CHECK(bias); - std::vector args = {conv->mutable_operand(0), - conv->mutable_operand(1), bias}; - if (match.side_input) { - args.push_back(match.side_input); - } - auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall( - conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget)); - 
new_conv->set_window(conv->window()); - new_conv->set_convolution_dimension_numbers( - conv->convolution_dimension_numbers()); - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, - conv->backend_config()); - config.set_activation_mode( - static_cast(se::dnn::ActivationMode::kRelu)); - config.set_conv_result_scale(alpha_conv); - config.set_side_input_scale(alpha_side_input); - TF_RETURN_IF_ERROR(new_conv->set_backend_config(config)); - - VLOG(1) << "Rewriting " << conv->name() << " to " << new_conv->name(); - return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0), - new_conv, 0); -} - -} // namespace - -StatusOr CudnnFusedConvolutionRewriter::Run(HloModule* module) { - bool changed = false; - for (HloComputation* computation : module->MakeNonfusionComputations()) { - std::vector matches; - int num_forward_convs = 0; - for (auto instr : computation->instructions()) { - auto match = FindConvWithRelu(instr); - if (match.has_value()) { - matches.push_back(*match); - } - if (auto call = DynCast(instr)) { - if (call->custom_call_target() == kCudnnConvForwardCallTarget) { - num_forward_convs++; - } - } - } - VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size() - << " out of " << num_forward_convs << " forward convs."; - std::vector>> - replacements; - for (const ConvWithRelu& match : matches) { - TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match)); - replacements.push_back({match.maximum, std::move(new_instr)}); - changed = true; - } - for (auto& replacement : replacements) { - TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( - replacement.first, std::move(replacement.second))); - } - } - return changed; -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h deleted file mode 100644 index 
bd12aadded9dd9e19bc695ddc11e5529931a306a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ - -#include "tensorflow/compiler/xla/service/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" - -namespace xla { -namespace gpu { - -class CudnnFusedConvolutionRewriter : public HloModulePass { - public: - absl::string_view name() const override { - return "cudnn-fused-convolution-rewriter"; - } - - StatusOr Run(HloModule* module) override; -}; - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index c1aaa4bf04ddc31edf723c056805ae5aad994e55..2ab754a471070d5f90a3eaebd0600ff180d2fe5d 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -161,6 +161,16 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( PrimitiveType 
lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); + HloOpcode opcode = op->opcode(); + + if (hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max() && + (opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum)) { + return llvm_ir::EmitCallToIntrinsic( + opcode == HloOpcode::kMaximum ? llvm::Intrinsic::maxnum + : llvm::Intrinsic::minnum, + {lhs_value, rhs_value}, {lhs_value->getType()}, b_); + } + switch (op->opcode()) { case HloOpcode::kRemainder: { return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value}, @@ -358,13 +368,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); const Window& window = hlo->window(); - // TODO(b/31410564): Implement dilation for reduce-window. - if (window_util::HasDilation(window)) { - return Unimplemented( - "Dilation for reduce-window not implemented on GPU. " - "See b/31410564."); - } - PrimitiveType operand_element_type = operand->shape().element_type(); llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), @@ -397,9 +400,24 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( for (size_t i = 0; i < index.size(); ++i) { llvm::Value* stridden_index = NSWMul( index[i], index_typed_const(window.dimensions(i).stride())); + input_index[i] = NSWSub( + NSWAdd(stridden_index, + NSWMul(window_index[i], + index_typed_const( + window.dimensions(i).window_dilation()))), + index_typed_const(window.dimensions(i).padding_low())); + + // We need to verify that we are not in the dilated base area. 
+ llvm::Value* dilation_condition = ICmpEQ( + SRem(input_index[i], + index_typed_const(window.dimensions(i).base_dilation())), + index_typed_const(0)); + in_bounds = And(in_bounds, dilation_condition); + + // Apply base dilation to the index. input_index[i] = - NSWSub(NSWAdd(stridden_index, window_index[i]), - index_typed_const(window.dimensions(i).padding_low())); + SDiv(input_index[i], + index_typed_const(window.dimensions(i).base_dilation())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. This diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 30c1f9088968305ad0207164ecb07ba13cc89ee6..470457935acacb8940af241dadb393d770786939 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -229,7 +229,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) { return user->opcode() == HloOpcode::kFusion && (user->fusion_kind() == HloInstruction::FusionKind::kLoop || - (user->fusion_kind() == HloInstruction::FusionKind::kInput && + (IsReduceInputFusion(*user) && LayoutsAreReduceInputFusionFriendly(*fusion, *user))); })) { VLOG(3) << "Not merging " << fusion->name() diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 9c4a4903667ea1a6c99ce9e912c9d0497b8e389f..27f07b1d58125092c1ed6734b238e4ae0f11c4aa 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -51,7 +52,8 @@ struct MatrixDescriptor { // rhs_matrix, and stores the result to output_matrix. template bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, double alpha, se::Stream* stream) { + MatrixDescriptor output_matrix, double alpha, double beta, + se::Stream* stream) { DCHECK(!output_matrix.transpose); const int64 batch_size = lhs_matrix.batch_size; @@ -73,7 +75,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, lhs_transpose, rhs_transpose, output_matrix.num_rows, output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, - /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, + /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/beta, &output_data, /*leading dim of output=*/output_matrix.num_rows) .ok(); } @@ -88,7 +90,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, /*alpha=*/alpha, lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, lhs_stride, rhs_data, /*leading dim of RHS=*/rhs_matrix.num_rows, rhs_stride, - /*beta=*/0.0, &output_data, + /*beta=*/beta, &output_data, /*leading dim of output=*/output_matrix.num_rows, output_stride, batch_size) .ok(); @@ -112,6 +114,7 @@ template bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, MatrixDescriptor output_matrix, double alpha, + double beta, se::blas::ComputationType computation_type, se::blas::AlgorithmType algorithm, se::Stream* stream, se::blas::ProfileResult* output_profile_result) { @@ -138,7 +141,7 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, /*alpha=*/static_cast(alpha), lhs_data, /*leading dim of 
LHS=*/lhs_matrix.num_rows, rhs_data, /*leading dim of RHS=*/rhs_matrix.num_rows, - /*beta=*/static_cast(0.0f), &output_data, + /*beta=*/static_cast(beta), &output_data, /*leading dim of output=*/output_matrix.num_rows, computation_type, algorithm, output_profile_result) .ok(); @@ -153,7 +156,7 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, template StatusOr DoGemmAutotune( MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, double alpha, + MatrixDescriptor output_matrix, double alpha, double beta, se::blas::ComputationType computation_type, se::Stream* stream) { std::vector algorithms; CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms)); @@ -166,7 +169,7 @@ StatusOr DoGemmAutotune( // non-null ProfileResult, DoGemmWithAlgorithm should always return true, // and the actual success-ness is returned in ProfileResult::is_valid. CHECK(DoGemmWithAlgorithm(lhs_matrix, rhs_matrix, output_matrix, - alpha, computation_type, algorithm, + alpha, beta, computation_type, algorithm, stream, &profile_result)); if (profile_result.is_valid()) { @@ -263,8 +266,9 @@ DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) { } CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion); CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput); - CHECK_EQ(hlo_instruction.fused_expression_root()->opcode(), - HloOpcode::kMultiply); + CHECK(hlo_instruction.fused_expression_root()->opcode() == HloOpcode::kAdd || + hlo_instruction.fused_expression_root()->opcode() == + HloOpcode::kMultiply); // Try to find the dot inside the output fusion node. 
const HloInstruction* dot = hlo_instruction.fused_expression_root()->operand(0); @@ -282,8 +286,9 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape, double alpha, - const HloInstruction* hlo_instruction) + const Shape& output_shape, double alpha, double beta, + const HloInstruction* hlo_instruction, + bool implements_whole_instruction) : Thunk(Kind::kGemm, hlo_instruction), lhs_buffer_(lhs_buffer), rhs_buffer_(rhs_buffer), @@ -291,7 +296,9 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, lhs_shape_(lhs_shape), rhs_shape_(rhs_shape), output_shape_(output_shape), - alpha_(alpha) {} + alpha_(alpha), + beta_(beta), + implements_whole_instruction_(implements_whole_instruction) {} Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream, @@ -386,7 +393,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, // TODO(b/112111608): Implement auto tune for batched gemm. if (batch_size != 1) { return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, - alpha_, stream); + alpha_, beta_, stream); } auto thunk_name = [&] { @@ -398,9 +405,27 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, auto autotune_it = autotune_results_.find(device_name); if (autotune_it == autotune_results_.end()) { VLOG(3) << "Starting autotune of GemmThunk " << thunk_name(); - StatusOr best_algorithm = - GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, - alpha_, computation_type, stream); + + // If the output buffer already contains a bias then autotune into a + // scratch buffer. This avoids overwriting the bias buffer. The scratch + // buffer may contain arbitrary garbage values. 
+ se::DeviceMemoryBase scratch_data = output_data; + std::unique_ptr> scratch_mem; + if (beta_ != 0.0) { + auto temp_status = stream->AllocateTemporaryArray( + ShapeUtil::ByteSizeOf(output_shape_)); + if (!temp_status.ok()) { + return false; + } + scratch_mem = std::move(temp_status).ValueOrDie(); + scratch_data = scratch_mem->device_memory(); + } + const MatrixDescriptor scratch_descriptor( + scratch_data, false, output_num_cols, output_num_rows, batch_size); + + StatusOr best_algorithm = GetGemmAutotuneFn( + element_type)(lhs_matrix, rhs_matrix, scratch_descriptor, alpha_, + beta_, computation_type, stream); autotune_it = autotune_results_.insert({device_name, best_algorithm}).first; @@ -421,18 +446,19 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, VLOG(2) << "Using algorithm " << algorithm << " chosen by autotuning on GemmThunk " << thunk_name(); return GetGemmWithAlgorithmFn(element_type)( - lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type, - algorithm, stream, + lhs_matrix, rhs_matrix, output_matrix, alpha_, beta_, + computation_type, algorithm, stream, /*output_profile_result=*/nullptr); } // Autotune will fail when CUDA 8 and GPU sm_50 or older are used. // Use the older Gemm API in this case. return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, - alpha_, stream); + alpha_, beta_, stream); }; - auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); + auto op_profiler = profiler->MakeScopedInstructionProfiler( + implements_whole_instruction_ ? 
hlo_instruction() : nullptr); bool launch_ok; if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) { launch_ok = launch(lhs_descriptor, rhs_descriptor, diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 12c81f9bfc6bfdac63edf9c826b835057107fa41..cc2d12a39c045fc081292dcf53053f6613d3d9ef 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -41,8 +41,9 @@ class GemmThunk : public Thunk { const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape, double alpha, - const HloInstruction* hlo_instruction); + const Shape& output_shape, double alpha, double beta, + const HloInstruction* hlo_instruction, + bool implements_whole_instruction); GemmThunk(const GemmThunk&) = delete; GemmThunk& operator=(const GemmThunk&) = delete; @@ -70,6 +71,9 @@ class GemmThunk : public Thunk { const Shape output_shape_; const double alpha_; + const double beta_; + + const bool implements_whole_instruction_; // Maps device names (StreamExecutor::DeviceDescription::name()) to autotune // results. 
The map's value is the best algorithm we've found for this thunk diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 57426327822d95a42f407ed7488f35acfd3623d2..ae2e718db29803a085401969a7d9b09abf690a6c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -51,7 +51,7 @@ GpuExecutable::GpuExecutable( const string& ptx, const std::vector& cubin, std::pair compute_capability, std::unique_ptr thunk_schedule, - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr assignment, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 0e276282e40fba0ae4881a51dad0c7c9e8d1c081..2b3c77f5b82aa94f44d8de56caf0f4d31c05e0cb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -54,7 +54,7 @@ class GpuExecutable : public Executable { GpuExecutable(const string& ptx, const std::vector& cubin, std::pair compute_capability, std::unique_ptr thunk_schedule, - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr assignment, std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 2d31fd5570c468b0c42fa308535fd335f3588a79..452e763a8eaadc805cd3a3859a68e2a31598fd36 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -55,7 +55,7 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, }); } -bool IsInputFusibleReduction(const HloInstruction& instr) { +bool IsReduceInputFusion(const HloInstruction& instr) { if (instr.IsMultiOutputFusion()) { for 
(const HloInstruction* operand : instr.fused_expression_root()->operands()) { @@ -67,17 +67,70 @@ bool IsInputFusibleReduction(const HloInstruction& instr) { return true; } } - return false; - } else if (instr.opcode() == HloOpcode::kFusion) { - if (IsReductionToVector(*instr.fused_expression_root())) { - CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) - << " Fusion rooted at reduction-to-vector op must be of kind kInput: " - << instr.ToString(); - return true; + } else if (instr.opcode() == HloOpcode::kFusion && + IsReductionToVector(*instr.fused_expression_root())) { + CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) + << " Fusion rooted at reduction-to-vector op must be of kind kInput: " + << instr.ToString(); + return true; + } + return false; +} + +bool IsInputFusibleReduction(const HloInstruction& instr) { + return IsReduceInputFusion(instr) || IsReductionToVector(instr); +} + +bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, + const HloInstruction& instr2) { + // Returns the instructions that determines the emitter used for lowering, + // sometimes referred to as "the real hero". + auto get_real_hero = + [&](const HloInstruction* instr) -> const HloInstruction* { + if (instr->opcode() == HloOpcode::kFusion) { + auto fused_expression_root = instr->fused_expression_root(); + if (instr->IsMultiOutputFusion()) { + // If possible, we want to pick a reduction-to-vector operand of the + // fusion root, because it has the most constraints. + for (const auto* inst : fused_expression_root->operands()) { + if (IsReductionToVector(*inst)) { + return inst; + } + } + return fused_expression_root->operands()[0]; + } + return fused_expression_root; } + return instr; + }; + + // Multi-output fusion kernels share a common parallel loop. The loop + // dimenstions are determined by instruction shapes. 
+ auto get_loop_shape = [&](const HloInstruction* element_instr) { + // Special-case reduction-to-vector ops: The loop dimensions are determined + // by the shape of the first operand. + if (IsReductionToVector(*element_instr)) { + return element_instr->operand(0)->shape(); + } + return element_instr->shape(); + }; + + // All shapes of the root tuple of multi-output fusions should agree, i.e. all + // root ops should have equal output shapes. An exception are + // reduction-to-vector ops. Here the input shapes of the reduction (first + // operand shape) and the reduction dimensions need to match. + auto* instr_1 = get_real_hero(&instr1); + auto* instr_2 = get_real_hero(&instr2); + // TODO(tjoerg): Relax the shape constraint. The datatype does not matter. + if (IsReductionToVector(*instr_1) && IsReductionToVector(*instr_2) && + (!ShapeUtil::Equal(instr_1->shape(), instr_2->shape()) || + instr_1->dimensions() != instr_2->dimensions())) { return false; } - return IsReductionToVector(instr); + // The elementwise output shapes must be the same (including layout). + // TODO(tjoerg): Further relax the constraint. The datatype does not matter. + return ShapeUtil::EqualIgnoringFpPrecision(get_loop_shape(instr_1), + get_loop_shape(instr_2)); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index f7c24a0d5bbfcc61389ea19ae7f769671e4e974d..e9d7ba1c4cfa865532a0d06c2ed883a2fea4e2cd 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -33,16 +33,29 @@ namespace gpu { bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, const HloInstruction& reduce); -// Whether `instr` is fusible as root of a reduce input fusions, i.e. 
`instr` -// is either an unfused reduction-to-vector op, an input fusion rooted at a -// reduction-to-vector op, or a multi-output input fusion with at least one -// reduction-to-vector op root. // Note that reduction ops are lowered in different ways. Reduce input fusions // are lowered by IrEmitterUnnested::EmitReductionToVector and must be rooted at // reduction-to-vector ops. Other reduction ops are lowered by // GpuElementalIrEmitter and fused like elementwise ops. + +// Whether `instr` is an input fusion rooted at a reduction-to-vector op or a +// multi-output input fusion with at least one reduction-to-vector op root. +bool IsReduceInputFusion(const HloInstruction& instr); + +// Whether `instr` is fusible as root of a reduce input fusions, i.e. `instr` +// is either an unfused reduction-to-vector op or a reduce input fusion. bool IsInputFusibleReduction(const HloInstruction& instr); +// Whether instruction shapes are compatible for multi-output fusion, i.e. +// whether the emitters support lowering the resulting fusion. +// This function works for both, sibling and producer-conumser multi-output +// fusion. +// So far, multi-output fusion is supported for loop fusions and reduce +// input fusions only. It is up to the caller to ensure the instructions +// themselves are fusible! 
+bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, + const HloInstruction& instr2); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc index d91b7bc61fda5a07c163a07ec0e1644d2ad9db49..15d4ee206ce8debcb8a5dbc6ec65d29ba257d302 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc @@ -178,7 +178,7 @@ TEST_F(GpuFusibleTest, EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_ReductionToVector) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_ReductionToVector) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY entry { c0 = f32[] parameter(0) @@ -191,10 +191,11 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_ReductionToVector) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_ElementalReduction) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_ElementalReduction) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY entry { c0 = f32[] parameter(0) @@ -207,10 +208,11 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_ElementalReduction) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputInputReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_SingleOutputInputReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -225,10 +227,11 @@ 
TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputInputReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsReduceInputFusion(*reduce)); EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputLoopReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_SingleOutputLoopReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -243,10 +246,11 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputLoopReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputInputReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_MultiOutputInputReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -263,11 +267,12 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputInputReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsReduceInputFusion(*reduce)); EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } TEST_F(GpuFusibleTest, - IsInputFusibleReduction_MultiOutputInputReduceFusionWithExtraOutputs) { + IsReduceInputFusion_MultiOutputInputReduceFusionWithExtraOutputs) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -284,10 +289,11 @@ TEST_F(GpuFusibleTest, const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsReduceInputFusion(*reduce)); 
EXPECT_TRUE(IsInputFusibleReduction(*reduce)); } -TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputLoopReduceFusion) { +TEST_F(GpuFusibleTest, IsReduceInputFusion_MultiOutputLoopReduceFusion) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -304,11 +310,12 @@ TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputLoopReduceFusion) { const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } TEST_F(GpuFusibleTest, - IsInputFusibleReduction_MultiOutputLoopFusionReduceAndElementwiseOp) { + IsReduceInputFusion_MultiOutputLoopFusionReduceAndElementwiseOp) { auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_reduction { c0 = f32[] parameter(0) @@ -325,8 +332,304 @@ TEST_F(GpuFusibleTest, const HloInstruction* reduce = module->entry_computation()->root_instruction(); ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsReduceInputFusion(*reduce)); EXPECT_FALSE(IsInputFusibleReduction(*reduce)); } +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_LoopFusions) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + fused_computation_2 { + p0.2 = f32[6400]{0} parameter(0) + const.2 = f32[] constant(1) + ROOT div = f32[6400]{0} divide(p0.2, const.2) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2 + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + 
module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_IgnoreFpPrecision) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + fused_computation_2 { + p0.2 = f32[6400]{0} parameter(0) + ROOT convert = f16[6400]{0} convert(p0.2) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2 + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_Reduce) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(0) + reduce = f32[] reduce(p0, const.2), dimensions={0}, to_apply=scalar_add + ROOT root = (f32[6400]{0}, f32[]) tuple(fusion.1, reduce) + })")) + .ValueOrDie(); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* reduce = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion, *reduce)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_Elementwise) { + auto module = 
ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(1) + div = f32[6400]{0} divide(p0, const.2) + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div) + })")) + .ValueOrDie(); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* div = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion, *div)); +} + +TEST_F(GpuFusibleTest, + ShapesCompatibleForMultiOutputFusion_MultiOutputLoopFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) + exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1) + ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp) + } + + fused_computation_2 { + p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + const.2 = f32[] constant(0) + ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, const.2) + } + + ENTRY entry { + p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 + gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0 + gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1 + ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + 
module->entry_computation()->root_instruction()->operand(0)->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1)->operand(0); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_UnfusedOps) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + exp = f32[2,2,2]{2,1,0} exponential(p0) + reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, exp) + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* exp = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*reduce, *exp)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_DifferentLayouts) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{0,1,2} parameter(1) + c0 = f32[] constant(0) + exp = f32[2,2,2]{2,1,0} exponential(p0) + reduce = f32[2,2]{0,1} reduce(p1, c0), dimensions={2}, to_apply=scalar_add + ROOT root = (f32[2,2]{0,1}, f32[2,2,2]{2,1,0}) tuple(reduce, exp) + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* exp = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_FALSE(ShapesCompatibleForMultiOutputFusion(*reduce, *exp)); +} + +TEST_F(GpuFusibleTest, + ShapesCompatibleForMultiOutputFusion_MultiOutputReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_select { + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + c0 = f32[] constant(0) + broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={} + 
greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + c1 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0.2, c1), dimensions={2}, to_apply=scalar_add + mul = f32[2,2,2]{2,1,0} multiply(p0.2, p0.2) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + select = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce + gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0 + gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1 + ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(gte1, gte1, select) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1)->operand(0); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_ReduceFusions) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduce_1 { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} p0.1, f32[] c0), dimensions={0}, to_apply=scalar_add + } + + fused_reduce_2 { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + c1 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={0}, 
to_apply=scalar_add + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + reduce_1 = f32[2,2]{1,0} fusion(p0), kind=kLoop, calls=fused_reduce_1 + reduce_2 = f32[2,2]{1,0} fusion(p1), kind=kLoop, calls=fused_reduce_2 + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce_1, reduce_2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, + ShapesCompatibleForMultiOutputFusion_DifferentReduceDimensions) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduce_1 { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} p0.1, f32[] c0), dimensions={0}, to_apply=scalar_add + } + + fused_reduce_2 { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + c1 = f32[] constant(0) + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={2}, to_apply=scalar_add + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + reduce_1 = f32[2,2]{1,0} fusion(p0), kind=kLoop, calls=fused_reduce_1 + reduce_2 = f32[2,2]{1,0} fusion(p1), kind=kLoop, calls=fused_reduce_2 + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce_1, reduce_2) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_FALSE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + +TEST_F(GpuFusibleTest, + ShapesCompatibleForMultiOutputFusion_NoReductionToVector) { + auto module = 
ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_element_wise { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2) + c1 = f32[] constant(0) + // Note that reduce is not a reduction-to-vector. + ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={1}, to_apply=scalar_add + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(element_wise), kind=kLoop, calls=fused_reduce + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise) + })")) + .ValueOrDie(); + const HloInstruction* fusion_1 = + module->entry_computation()->root_instruction()->operand(0); + const HloInstruction* fusion_2 = + module->entry_computation()->root_instruction()->operand(1); + EXPECT_FALSE(ShapesCompatibleForMultiOutputFusion(*fusion_1, *fusion_2)); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 02a0d028c118aba23996f9b97d05443bb4a00c88..1126943624a3771433ecac591545d335c1890115 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -37,12 +37,12 @@ class GpuHloOrdering : public PredecessorHloOrdering { public: GpuHloOrdering(const HloModule* module, const StreamAssignment& stream_assignment, - const std::vector& thunk_launch_order); + const std::vector& thunk_launch_order); ~GpuHloOrdering() override = default; // Only the entry computation can possibly be sequentially ordered, and only // if we've assigned all instructions to a single stream. 
- const std::vector* SequentialOrder( + const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const override { return &computation == module_->entry_computation() ? entry_sequence_.get() : nullptr; @@ -51,17 +51,17 @@ class GpuHloOrdering : public PredecessorHloOrdering { string ToString() const override { return ToStringHelper("GpuHloOrdering"); } private: - std::unique_ptr> entry_sequence_; + std::unique_ptr entry_sequence_; }; GpuHloOrdering::GpuHloOrdering( const HloModule* module, const StreamAssignment& stream_assignment, - const std::vector& thunk_launch_order) + const std::vector& thunk_launch_order) : PredecessorHloOrdering(module) { // The entry computation has a total order when there's only one stream. if (stream_assignment.StreamCount() == 1) { - entry_sequence_ = absl::make_unique>( - thunk_launch_order); + entry_sequence_ = + absl::make_unique(thunk_launch_order); } // The ordering of instructions for the entry computation is determined by the @@ -124,7 +124,8 @@ GpuHloOrdering::GpuHloOrdering( for (auto* computation : module->computations()) { if (computation != module->entry_computation() && !computation->IsFusionComputation()) { - predecessors_.emplace(computation, computation->ComputeReachability()); + predecessors_.emplace(computation, + HloReachabilityMap::Build(computation)); } } } @@ -149,7 +150,7 @@ GpuHloOrdering::GpuHloOrdering( // However, if the total order is A,B,D,C,E, then C and E can run // concurrently. void BFSLaunchOrder(const HloComputation* computation, - std::vector* launch_order) { + std::vector* launch_order) { // This topological sort uses two data structures: // 1. `incoming_edge_count` which keeps track of the number of incoming // edges to each HLO; @@ -157,9 +158,9 @@ void BFSLaunchOrder(const HloComputation* computation, // // The sorting algorithm repeatedly pops the top from the queue and deletes // that HLO from the graph, making more HLOs incoming-edge free. 
- std::deque queue; + std::deque queue; std::unordered_map incoming_edge_count; - for (const auto& hlo : computation->instructions()) { + for (auto* hlo : computation->instructions()) { if (hlo->operand_count() == 0) { queue.push_back(hlo); } else { @@ -171,10 +172,10 @@ void BFSLaunchOrder(const HloComputation* computation, } while (!queue.empty()) { - const HloInstruction* x = queue.front(); + HloInstruction* x = queue.front(); queue.pop_front(); launch_order->push_back(x); - for (const HloInstruction* y : x->users()) { + for (HloInstruction* y : x->users()) { --incoming_edge_count[y]; if (incoming_edge_count[y] == 0) { queue.push_back(y); @@ -194,14 +195,14 @@ StatusOr> GpuHloSchedule::Build( std::unique_ptr schedule(new GpuHloSchedule); // Initialize thunk_launch_order_, the total order of thunk launches. - const HloComputation* entry_computation = module.entry_computation(); + HloComputation* entry_computation = module.entry_computation(); if (stream_assignment.StreamCount() == 1) { // All kernels are launched on a single stream, so there's no loss of // concurrency by optimizing for minimal memory usage. TF_ASSIGN_OR_RETURN( HloInstructionSequence sequence, ScheduleComputation( - *entry_computation, [pointer_size](const BufferValue& buffer) { + entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); schedule->thunk_launch_order_ = sequence.instructions(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h index 07a7fc67aa555845c3de57e574ab582403ec0490..7f224ffe4f03f8f05b0f1907628d99d9df387770 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h @@ -46,7 +46,7 @@ class GpuHloSchedule { // Returns the total order of thunk launches, represented in terms of HLO // instructions. 
- const std::vector& ThunkLaunchOrder() const { + const std::vector& ThunkLaunchOrder() const { return thunk_launch_order_; } @@ -60,7 +60,7 @@ class GpuHloSchedule { private: GpuHloSchedule(); - std::vector thunk_launch_order_; + std::vector thunk_launch_order_; std::unique_ptr hlo_ordering_; }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index b857fa775a76ec999b505a2a64332cc0c54cf00b..91db7151f22fd75b20244878bee86d65acd1d304 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -24,16 +24,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class GpuHloScheduleTest : public HloVerifiedTestBase { +class GpuHloScheduleTest : public HloTestBase { protected: - using HloVec = std::vector; + using HloVec = std::vector; // Pre-canned shapes. 
Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); @@ -44,7 +44,7 @@ class GpuHloScheduleTest : public HloVerifiedTestBase { .ConsumeValueOrDie(); } - std::unique_ptr CreateNewModule() { + std::unique_ptr CreateNewVerifiedModule() { HloModuleConfig config; auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); @@ -79,7 +79,7 @@ TEST_F(GpuHloScheduleTest, SequentialMatMul) { HloInstruction* dot2 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(dot2)); std::unique_ptr streams = AssignStreams(*module); @@ -139,7 +139,7 @@ TEST_F(GpuHloScheduleTest, SequentialAdd) { HloInstruction* add3 = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(add3)); std::unique_ptr streams = AssignStreams(*module); @@ -209,7 +209,7 @@ TEST_F(GpuHloScheduleTest, ConcurrentMatMul) { HloInstruction* add = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(add)); std::unique_ptr streams = AssignStreams(*module); @@ -288,7 +288,7 @@ TEST_F(GpuHloScheduleTest, LatticeMatMul) { HloInstruction* d40 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(d40)); std::unique_ptr streams = AssignStreams(*module); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc index 27a4d0b601f3807fe6b94dd6171a44f292921ede..b511155f85fb24adc1828cbef7f3fb60778ef7ab 100644 --- 
a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class GpuHloSupportCheckerTest : public HloVerifiedTestBase { +class GpuHloSupportCheckerTest : public HloTestBase { protected: GpuHloSupportChecker& checker() { return checker_; } @@ -42,10 +42,10 @@ TEST_F(GpuHloSupportCheckerTest, Add) { HloInstruction::CreateParameter(1, scalar_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module).status()); + TF_ASSERT_OK(checker().Run(module.get()).status()); } TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { @@ -57,10 +57,13 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { HloInstruction::CreateParameter(1, sparse_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( sparse_shape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); + // Since verifier is reporting sparse layouts as errors, we should + // use a regular HloModule instead of VerifiedHloModule to avoid + // verifier errors being triggered in the destructor. 
+ auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(builder.Build()); - Status status = checker().Run(module).status(); + Status status = checker().Run(module.get()).status(); ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); EXPECT_THAT(status.error_message(), HasSubstr("GPU backend does not support")); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 74352f26aa9c3a2ca597da21735438df92f863ab..f59da2caa18646676297e66dd329c66fb5fddf1b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -66,8 +65,8 @@ HeuristicLayoutAssignment(const HloInstruction* instr, VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString(); - // Empirically we've found with Volta and cudnn 7 that backward-input convs - // with stride are significantly faster with NCHW layouts. + // Empirically we've found with Volta and cudnn <= 7.3 that backward-input + // convs with stride are significantly faster with NCHW layouts. // // We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW), // which on paper gives good performance. However, there are two observations: @@ -76,11 +75,17 @@ HeuristicLayoutAssignment(const HloInstruction* instr, // * we've also observed that for mixed layouts, cuDNN transposes data back // and forth from a different layout combination. If we end up with // transposes anyway, we prefer to have them in XLA, as they can be fused. 
- // TODO(timshen): Figure out the exact condition. This may be achieved by - // auto-tuning layouts offline. - if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && - window_util::HasStride(instr->window())) { - return kAllNCHW; + if (auto* dnn = stream_executor->AsDnn()) { + auto version_status = dnn->GetVersion(); + if (version_status.ok()) { + auto version = version_status.ConsumeValueOrDie(); + if (std::make_tuple(version.major_version(), version.minor_version()) <= + std::make_tuple(7, 3) && + instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && + window_util::HasStride(instr->window())) { + return kAllNCHW; + } + } } // For other Volta f16 convolutions, use NHWC. @@ -125,14 +130,8 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( DataLayout input; FilterLayout filter; DataLayout output; - if (ConvUseLayoutHeuristic(instr->GetModule()->config())) { - std::tie(input, filter, output) = - HeuristicLayoutAssignment(instr, stream_executor_); - } else { - input = DataLayout::kBatchDepthYX; - filter = FilterLayout::kOutputInputYX; - output = DataLayout::kBatchDepthYX; - } + std::tie(input, filter, output) = + HeuristicLayoutAssignment(instr, stream_executor_); TF_ASSIGN_OR_RETURN( std::tie(*input_shape->mutable_layout(), @@ -215,21 +214,37 @@ Status GpuLayoutAssignment::AddBackendConstraints( constraints->SetOperandLayout(op1_shape, instruction, 1)); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, instruction)); + } else if (instruction->opcode() == HloOpcode::kSort && + ShapeUtil::Rank(instruction->operand(0)->shape()) > 1) { + // Make sure that all the operands and the output(s) have the same layout. 
+ Shape keys_shape = instruction->operand(0)->shape(); + Layout keys_layout = + LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(keys_shape)); + for (int64 i = 0; i < instruction->operand_count(); ++i) { + Shape shape = instruction->operand(i)->shape(); + *shape.mutable_layout() = keys_layout; + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(shape, instruction, i)); + const LogicalBuffer* output_buffer; + if (ShapeUtil::IsArray(instruction->shape())) { + TF_ASSIGN_OR_RETURN( + output_buffer, + constraints->points_to_analysis().GetBufferDefinedAt(instruction, + {})); + } else { + TF_ASSIGN_OR_RETURN( + output_buffer, + constraints->points_to_analysis().GetBufferDefinedAt(instruction, + {i})); + } + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(keys_layout, *output_buffer)); + } } } return Status::OK(); } -bool GpuLayoutAssignment::CustomCallRequiresMajorFirstLayout( - const HloInstruction* instruction) { - // - Inputs to cudnn batchnorm custom calls don't need the major-first layout - // (i.e. {n, n-1, ...0}) -- we can handle any layout. - // - Inputs to cudnn convolution require custom layouts handled in - // AddBackendConstraints. 
- return !IsCustomCallToDnnBatchNorm(*instruction) && - !IsCustomCallToDnnConvolution(*instruction); -} - Status GpuLayoutAssignment::PropagateOperandConstraint( const OperandLayoutConstraint& layout_constraint, LayoutConstraints* constraints) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index 4ba7989e9cba9abe6cdc1fcabd5f011bd9cfb0ec..6a48e55fd2e784f80a50f4565107db177fb43bfc 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -46,8 +46,6 @@ class GpuLayoutAssignment : public LayoutAssignment { Status PropagateBufferConstraint( const BufferLayoutConstraint& buffer_constraint, LayoutConstraints* constraints) override; - bool CustomCallRequiresMajorFirstLayout( - const HloInstruction* instruction) override; private: Status AddBackendConstraintsToDnnConvCustomCall( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 04681cfcec792d86eed95585262691932b07b269..2ffc8bfb49b205dced0d540ba72426e72d95e596 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -61,7 +61,7 @@ TEST_F(LayoutAssignmentTest, Elementwise) { HloInstruction::CreateParameter(1, ashape, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(add)); @@ -148,7 +148,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { {operand, scale, offset, mean, variance, epsilon, feature_index}, kCudnnBatchNormForwardInferenceCallTarget)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = 
module->AddEntryComputation(builder.Build(batchnorm)); @@ -217,7 +217,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { batchnorm_shape, {operand, scale, offset, epsilon, feature_index}, kCudnnBatchNormForwardTrainingCallTarget)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); @@ -298,7 +298,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { feature_index}, kCudnnBatchNormBackwardCallTarget)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build(batchnorm)); @@ -351,7 +351,8 @@ TEST_F(LayoutAssignmentTest, DotLayout) { ParseHloString(hlo_text)); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + module->entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( &computation_layout, LayoutAssignment::InstructionCanChangeLayout, backend().default_stream_executor()); @@ -364,6 +365,34 @@ TEST_F(LayoutAssignmentTest, DotLayout) { op::ShapeWithLayout(expected_shape))); } +TEST_F(LayoutAssignmentTest, SortLayout) { + const char* hlo_text = R"( + HloModule SortLayout + ENTRY sort { + keys = f32[3,2]{0,1} constant(f32[3,2]{0,1}{{0,1},{0,1},{0,1}}) + values = f32[2,3]{1,0} parameter(0) + transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0} + ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose), + dimensions={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); + GpuLayoutAssignment layout_assignment( + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + Shape 
expected_shape = ShapeUtil::MakeShapeWithLayout(F32, {3, 2}, {1, 0}); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sort(op::ShapeWithLayout(expected_shape), + op::ShapeWithLayout(expected_shape))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/compiler/xla/service/gpu/gpu_options.cc deleted file mode 100644 index 35b4b4e20b633792de4251a4b0e89f4b579053ce..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/gpu_options.cc +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" -#include "tensorflow/core/lib/gtl/map_util.h" - -namespace xla { -namespace gpu { - -bool ConvUseLayoutHeuristic(const HloModuleConfig& config) { - return !config.debug_options().xla_backend_extra_options().count( - "xla_gpu_experimental_conv_disable_layout_heuristic"); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/gpu_options.h deleted file mode 100644 index 498d4a94955cb2c50e0b165f28ded44ac1c0bfff..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/gpu_options.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ - -#include "tensorflow/compiler/xla/service/hlo_module_config.h" - -// Helper functions for querying options that are specific to the GPU backend. - -namespace xla { -namespace gpu { - -// Returns true if we should use heuristics to assign convolution layouts, as -// opposed to always assigning NCHW. 
-bool ConvUseLayoutHeuristic(const HloModuleConfig& config); - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index b61f0387392d2301109a484ca5c1f65f18882265..6151dd8ff4c92bb81bd756c68cc9377633c8c9d5 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -47,6 +47,8 @@ bool IsFusible(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || hlo.opcode() == HloOpcode::kReshape || + hlo.opcode() == HloOpcode::kReverse || + hlo.opcode() == HloOpcode::kScatter || hlo.opcode() == HloOpcode::kSlice || hlo.opcode() == HloOpcode::kTranspose; } @@ -78,7 +80,7 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { // This function limits the maximum number of operands to a fusion. // // There's a cap on how many parameters we can pass to a CUDA kernel, but -// exactly what that limit is is hazy, as it depends on (among other things) how +// exactly what that limit is hazy, as it depends on (among other things) how // much GPU constant memory is in use for other purposes. // // Moreover, we don't even know at the point that we're running fusion how many @@ -178,6 +180,11 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, IsIEEEFloatingPointScalarConstant(alpha->operand(0))) { return true; } + } else if (consumer->operand_count() == 2 && + consumer->opcode() == HloOpcode::kAdd && + consumer->operand(other_operand_index) != producer) { + // Fuse a bias add into the output of the dot. + return true; } } @@ -223,6 +230,11 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } + // Scatter is only supported at the root of a kInput fusion. 
+ if (producer->opcode() == HloOpcode::kScatter) { + return false; + } + // Do not fuse into reduce input fusions if the resulting kernel would suffer // from poor data locality (due to unfriendly input layouts). if (IsInputFusibleReduction(*consumer) && @@ -246,12 +258,17 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } - // Fuse scalar constants into loop fusion nodes, this reduces the number of + // Fuse scalar constants into loop fusion nodes. This reduces the number of // parameters and makes matching scalar broadcasts easier. - if (ShapeUtil::IsEffectiveScalar(producer->shape()) && - consumer->opcode() == HloOpcode::kFusion && - producer->opcode() == HloOpcode::kConstant) { - return true; + // + // Don't fuse other constants: Unfused constants in GPU land can be + // represented as an external constant (i.e. not emitted in LLVM IR / PTX), + // but fused constants are handled by shared CPU/GPU code and always emitted + // in the IR/PTX. The external constant representation makes for faster + // compiles and significantly smaller assembly code. 
+ if (producer->opcode() == HloOpcode::kConstant) { + return ShapeUtil::IsEffectiveScalar(producer->shape()) && + consumer->opcode() == HloOpcode::kFusion; } if (!IsFusible(*producer) || !IsFusible(*consumer) || @@ -285,7 +302,8 @@ bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) { - if (IsReductionToVector(*consumer)) { + if (IsReductionToVector(*consumer) || + consumer->opcode() == HloOpcode::kScatter) { return HloInstruction::FusionKind::kInput; } if (producer->opcode() == HloOpcode::kDot || diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 96bfe0c12eb9cd6ef25804d6b34767471616f7e4..688604cd36e5a45debf855aacd29d05ecda92341 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -41,7 +41,7 @@ TEST_F(InstructionFusionTest, builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(S32, {1}), exp1, {0})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(broadcast2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -61,7 +61,7 @@ TEST_F(InstructionFusionTest, builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(S32, {1}), negate1, {0})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(broadcast2, computation->root_instruction()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -80,7 +80,7 @@ TEST_F(InstructionFusionTest, HloInstruction* reshape2 = builder.AddInstruction( 
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), exp1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -99,7 +99,7 @@ TEST_F(InstructionFusionTest, HloInstruction* transpose2 = builder.AddInstruction( HloInstruction::CreateTranspose(ShapeUtil::MakeShape(S32, {}), exp1, {})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -117,7 +117,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -134,7 +134,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(transpose2, computation->root_instruction()); EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) @@ -331,6 +331,56 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { op::Broadcast(op::Constant()))); } +TEST_F(InstructionFusionTest, DotOutputFusionBiasAdd) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} 
broadcast(alpha), dimensions={} + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,4]{1,0} parameter(2) + transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0} + dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT add = f32[4,4] add(dot, p2) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput); + EXPECT_THAT(root->fused_expression_root(), + op::Add(op::Dot(op::Parameter(), op::Transpose(op::Parameter())), + op::Parameter())); +} + +TEST_F(InstructionFusionTest, + DotOperationFusion_DontOutputFuseDuplicateOperands) { + absl::string_view module_string = R"( +HloModule module + +ENTRY main { + a = f32[50,60]{1,0} parameter(0) + b = f32[60,1]{1,0} parameter(1) + c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT d = f32[50,1]{1,0} add(c, c) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool fused_something, + GpuInstructionFusion(/*may_duplicate=*/false).Run(module.get())); + EXPECT_FALSE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is // duplicated and fused into both reduces. 
TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { @@ -696,7 +746,7 @@ TEST_F(InstructionFusionTest, AvoidsLargeFusion) { sum = b.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sum, param)); } - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(b.Build()); EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) @@ -709,5 +759,95 @@ TEST_F(InstructionFusionTest, AvoidsLargeFusion) { } } +TEST_F(InstructionFusionTest, FuseIntoScatter) { + auto module = ParseHloString(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY FuseIntoScatter { + p0 = s32[3,3] parameter(0) + operand = s32[3,3] add(p0, p0) + p1 = s32[2] parameter(1) + indices = s32[2] add(p1, p1) + p2 = s32[2,3] parameter(2) + updates = s32[2,3] add(p2, p2) + scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + ROOT add = s32[3,3] add(scatter, scatter) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Fusion(), op::Fusion())); + EXPECT_EQ(root->operand(0)->fusion_kind(), + HloInstruction::FusionKind::kInput); + EXPECT_THAT(root->operand(0)->fused_expression_root(), + op::Scatter(op::Add(), op::Add(), op::Add())); +} + +TEST_F(InstructionFusionTest, NonscalarConstantsNotFused) { + auto module = ParseHloString(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY BroadcastIntoReduce { + constant = f32[16] constant({0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}) + broadcast = f32[16,16,16,16]{3,2,1,0} broadcast(constant), 
dimensions={0} + constant.1 = f32[] constant(0) + ROOT reduce = f32[] reduce(broadcast, constant.1), dimensions={0,1,2,3}, + to_apply=add + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + // The f32[16] constant should not be fused into the reduce, but the f32[] + // constant should be. + auto* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_instructions_computation()->root_instruction(), + op::Reduce(op::Broadcast(op::Parameter()), op::Constant())); +} + +TEST_F(InstructionFusionTest, FuseReverse) { + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY Reverse { + p0 = f32[50,96,1024]{2,1,0} parameter(0) + add = f32[50,96,1024]{2,1,0} add(p0, p0) + ROOT reverse = f32[50,96,1024] reverse(add), dimensions={0} + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Reverse(op::Add(op::Parameter(), op::Parameter()))); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index ec3d8f9405840bb7be97ba5cd5725a4ac68a15a8..33e41a2782b5932430eea621d3cea2c6634f292f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -38,10 +38,9 @@ namespace gpu { namespace { -// Return whether the given shape is a matrix with no padding. -bool IsRank2WithNoPadding(const Shape& shape, int64 batch_dimensions_size) { - return ShapeUtil::Rank(shape) == batch_dimensions_size + 2 && - !LayoutUtil::IsPadded(shape); +// Return whether the given shape is rank 2 excluding the batch dimensions. 
+bool IsRank2(const Shape& shape, int64 batch_dimensions_size) { + return ShapeUtil::Rank(shape) == batch_dimensions_size + 2; } // In a gemm operation where output = lhs * rhs, check whether the given shapes @@ -56,10 +55,9 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, bool type_is_allowed = (output_primitive_type == F16 || output_primitive_type == F32 || output_primitive_type == F64 || output_primitive_type == C64); - return type_is_allowed && - IsRank2WithNoPadding(lhs_shape, batch_dimensions_size) && - IsRank2WithNoPadding(rhs_shape, batch_dimensions_size) && - IsRank2WithNoPadding(output_shape, batch_dimensions_size) && + return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) && + IsRank2(rhs_shape, batch_dimensions_size) && + IsRank2(output_shape, batch_dimensions_size) && !ShapeUtil::IsZeroElementArray(lhs_shape) && !ShapeUtil::IsZeroElementArray(rhs_shape); } @@ -93,7 +91,8 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { if (hlo.opcode() == HloOpcode::kFusion && hlo.fusion_kind() == HloInstruction::FusionKind::kOutput && - hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply) { + (hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply || + hlo.fused_expression_root()->opcode() == HloOpcode::kAdd)) { // Try to find the dot inside the output fusion node. 
const HloInstruction* dot = hlo.fused_expression_root()->operand(0); if (dot->opcode() != HloOpcode::kDot) { @@ -269,5 +268,17 @@ string CudnnConvKindToString(CudnnConvKind kind) { } } +llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) { + return b->CreateAnd( + b->CreateICmpEQ( + b->getInt32(0), + llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b)), + b->CreateICmpEQ( + b->getInt32(0), + llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b))); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index a64a616ab1329422d0197f4a7f99ec557a95f8ed..ebf4d926b7a280e10b09a2532caba7ad6ab3ceb2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -108,9 +108,9 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); // memory used by cudnn. Callers shouldn't inspect scratch_memory, as its value // is not well-defined. // -// CudnnConvolutionRewriter lowers kConvolution HLOs to these custom calls. +// CudnnConvRewriter lowers kConvolution HLOs to these custom calls. // When it does so, it chooses algorithm -1 and 0 bytes of scratch space. Later -// on in the pipeline, CudnnConvolutionAlgorithmChooser chooses an explicit +// on in the pipeline, CudnnConvAlgorithmChooser chooses an explicit // algorithm for each conv and sets the amount of scratch space needed. // // (Representing the scratch memory as an output may seem strange at first, but @@ -155,6 +155,10 @@ llvm::Value* EmitPrintf(absl::string_view fmt, llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder); +// Emits code that determines whether the current thread is thread 0 within +// block 0 of the kernel. 
+llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index b7c37bcf3ca910f10d18339dfe7f1d29f2a55c9e..6693f66d62d8b04d1b78e001fdb515b34539c67f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -63,9 +63,6 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, &ir_emitter_context->buffer_assignment(), &b_, module_, is_nested), hlo_module_config_(hlo_module_config) { - b_.setFastMathFlags(llvm_ir::GetFastMathFlags( - /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_gpu_enable_fast_math())); } Status IrEmitter::DefaultAction(HloInstruction* hlo) { @@ -97,6 +94,18 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) { + VLOG(2) << "HandleAddDependency: " << add_dependency->ToString(); + const HloInstruction* operand = add_dependency->operand(0); + // Add_Dependency is a no-op, but we still want to bind it to an llvm::Value + // sometimes, e.g., when its operand is a constant or a bitcast of a + // constant. + if (bindings_.BoundToIrValue(*operand)) { + bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand)); + } + return Status::OK(); +} + Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { auto operand = get_tuple_element->operand(0); CHECK(bindings_.BoundToIrValue(*operand)); @@ -179,6 +188,21 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( bool is_atomic_integral = element_type == S32 || element_type == U32 || element_type == S64 || element_type == U64; llvm::Value* source = Load(source_address, "source"); + + // kCopy of RHS -> atomic store. 
+ if (root_opcode == HloOpcode::kCopy && + (element_type == F32 || is_atomic_integral) && + computation.root_instruction()->operand(0)->opcode() == + HloOpcode::kParameter && + computation.root_instruction()->operand(0)->parameter_number() == 1) { + llvm::StoreInst* store = Store(source, output_address); + store->setAtomic(llvm::AtomicOrdering::Unordered); + // Derive a minimum alignment from the type. The optimizer can increase it + // later. + store->setAlignment(ShapeUtil::ByteSizeOfPrimitiveType(element_type)); + return true; + } + if (root_opcode == HloOpcode::kAdd) { // NVPTX supports atomicAdd on F32 and integer types. if (element_type == F32) { @@ -480,18 +504,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { TF_RET_CHECK(!ShapeUtil::IsScalar(lhs_shape) && !ShapeUtil::IsScalar(rhs_shape)); - // Reduce along the last dimension of the LHS and the second-to-last dimension - // of the RHS. Vectors are a special case where the reduction dimension is 0 - // for both LHS and RHS. This results in a vector dot product producing a - // scalar. - const int64 lhs_reduction_dimension = - ShapeUtil::GetDimensionNumber(lhs_shape, -1); - const int64 rhs_reduction_dimension = - ShapeUtil::Rank(rhs_shape) >= 2 + dnums.lhs_batch_dimensions_size() - ? ShapeUtil::GetDimensionNumber(rhs_shape, -2) - : dnums.lhs_batch_dimensions_size(); - - // Check that the batch dims don't cover the last two dims. + const int64 lhs_reduction_dimension = dnums.lhs_contracting_dimensions(0); + const int64 rhs_reduction_dimension = dnums.rhs_contracting_dimensions(0); + + // Check that the batch dims don't cover the reduction dimensions. for (int64 batch_dim : dnums.lhs_batch_dimensions()) { CHECK_NE(lhs_reduction_dimension, batch_dim); CHECK_NE(rhs_reduction_dimension, batch_dim); @@ -499,7 +515,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Verify the reduction dimension in the two operands are the same size. 
TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == - rhs_shape.dimensions(rhs_reduction_dimension)); + rhs_shape.dimensions(rhs_reduction_dimension)) + << "lhs_shape.dimensions(" << lhs_reduction_dimension + << ") = " << lhs_shape.dimensions(lhs_reduction_dimension) + << ", and rhs_shape.dimensions(" << rhs_reduction_dimension + << ") = " << rhs_shape.dimensions(rhs_reduction_dimension); // Create loop nests which loop through the LHS operand dimensions and the RHS // operand dimensions. The reduction dimension of the LHS and RHS are handled @@ -686,15 +706,11 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { Status IrEmitter::HandleFusion(HloInstruction* fusion) { // kFusion for library calls should be handled by // IrEmitterUnnested::HandleFusion. - CHECK(HloInstruction::FusionKind::kLoop == fusion->fusion_kind()); - - std::vector parameter_arrays; - for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand, *fusion)); - } + CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind()); GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_, GetNestedComputer()); - FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion), + &elemental_emitter); TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator()); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 880520148005838cc25a5be9e26c8bc9028a70ce..2da46c016935d0e927879bbfb0d05cfc4899d818 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -68,6 +68,9 @@ namespace gpu { class IrEmitter : public DfsHloVisitorWithDefault, public IrBuilderMixin { public: + using GeneratorForOperandIrArrays = + std::function()>; + IrEmitter(const 
IrEmitter&) = delete; IrEmitter& operator=(const IrEmitter&) = delete; @@ -97,6 +100,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status HandleBatchNormInference(HloInstruction* batch_norm) override; Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormGrad(HloInstruction* batch_norm) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } @@ -179,6 +183,20 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Hlo configuration data used during code generation. const HloModuleConfig& hlo_module_config_; + protected: + GeneratorForOperandIrArrays GetGeneratorForOperandIrArrays( + HloInstruction* fusion) { + return [=]() { + std::vector ir_arrays; + ir_arrays.reserve(fusion->operand_count()); + absl::c_transform(fusion->operands(), std::back_inserter(ir_arrays), + [&](const HloInstruction* operand) { + return GetIrArray(*operand, *fusion); + }); + return ir_arrays; + }; + } + private: // A helper method for EmitAtomicOperationForNestedComputation. Certain // computations, such as floating-point addition and integer maximization, can diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index c792dd2ddb0faeba076548ba104aa291e0814140..fb040aff30d48bf5817946ce53d37bc6685941e4 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -34,6 +34,7 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" @@ -43,7 +44,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" @@ -64,11 +65,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" -#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" @@ -87,6 +88,8 @@ limitations under the License. namespace xla { namespace gpu { +using llvm_ir::KernelMappingScheme; + namespace { using absl::InlinedVector; @@ -336,34 +339,26 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, } // namespace Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { - int unroll_factor = 1; - // Unfused elementwise operations are usually memory bound, unroll them. 
- if (hlo->IsElementwise()) { - unroll_factor = ComputeMaxUnrollFactor(hlo); - } - - thunk_sequence_->emplace_back(BuildKernelThunk( - hlo, /*implements_whole_instruction=*/true, unroll_factor)); return IrEmitter::DefaultAction(hlo); } Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { if (ImplementedAsGemm(*dot)) { - thunk_sequence_->emplace_back(BuildGemmThunk(dot)); + AddThunkToThunkSequence(BuildGemmThunk(dot)); return Status::OK(); } - thunk_sequence_->emplace_back( + AddThunkToThunkSequence( BuildKernelThunk(dot, /*implements_whole_instruction=*/true)); return IrEmitter::HandleDot(dot); } Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { - thunk_sequence_->emplace_back(BuildConditionalThunk(conditional)); + AddThunkToThunkSequence(BuildConditionalThunk(conditional)); return Status::OK(); } Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { - thunk_sequence_->emplace_back( + AddThunkToThunkSequence( BuildKernelThunk(convolution, /*implements_whole_instruction=*/true)); return IrEmitter::HandleConvolution(convolution); } @@ -385,7 +380,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { CHECK(feature_index->IsConstant()); int64 feature_index_value = feature_index->literal().Get({}); - thunk_sequence_->emplace_back( + AddThunkToThunkSequence( absl::make_unique( /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), @@ -415,7 +410,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto output_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); - thunk_sequence_->emplace_back( + AddThunkToThunkSequence( absl::make_unique( /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), @@ -446,20 
+441,19 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); auto output_grad_offset = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); - thunk_sequence_->emplace_back( - absl::make_unique( - /*operand=*/GetAllocationSlice(*custom_call->operand(0)), - /*scale=*/GetAllocationSlice(*custom_call->operand(1)), - /*mean=*/GetAllocationSlice(*custom_call->operand(2)), - /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), - /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, - /*output_grad_data=*/output_grad_data, - /*output_grad_scale=*/output_grad_scale, - /*output_grad_offset=*/output_grad_offset, - /*output_tuple=*/GetAllocationSlice(*custom_call), - /*hlo=*/custom_call)); + AddThunkToThunkSequence(absl::make_unique( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*mean=*/GetAllocationSlice(*custom_call->operand(2)), + /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), + /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output_grad_data=*/output_grad_data, + /*output_grad_scale=*/output_grad_scale, + /*output_grad_offset=*/output_grad_offset, + /*output_tuple=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); return Status::OK(); } @@ -474,7 +468,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - thunk_sequence_->emplace_back(absl::make_unique( + AddThunkToThunkSequence(absl::make_unique( Cast(custom_call), std::move(operand_slices), conv_result_slice, scratch_slice, tuple_result_slice)); return Status::OK(); @@ -487,19 
+481,68 @@ Status IrEmitterUnnested::HandleFft(HloInstruction* fft) { TF_RET_CHECK( LayoutUtil::IsMonotonicWithDim0Major(fft->operand(0)->shape().layout())); TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout())); - thunk_sequence_->emplace_back(BuildFftThunk(fft)); + AddThunkToThunkSequence(BuildFftThunk(fft)); return Status::OK(); } Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); - // HandleFusion specializes reduction from a multi-dimensional array to a 1D - // array. The specialized version requires a initializer thunk that - // initializes the output array to the initial value of the reduce. if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) { switch (root->opcode()) { + case HloOpcode::kScatter: { + std::vector> thunks; + // The initialization from 'operand' is using different loop bounds, so + // emit it in a separate kernel. Treat it like a loop fusion, writing to + // the output buffer. + { + int unroll_factor = ComputeMaxUnrollFactor(fusion); + thunks.push_back(BuildKernelThunk( + fusion, /*implements_whole_instruction=*/false, unroll_factor)); + + GpuElementalIrEmitter operand_elemental_emitter( + hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, + GetNestedComputer()); + FusedIrEmitter operand_fused_emitter( + GetGeneratorForOperandIrArrays(fusion), + &operand_elemental_emitter); + TF_RETURN_IF_ERROR( + root->mutable_operand(0)->Accept(&operand_fused_emitter)); + + TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk( + *fusion, operand_fused_emitter.GetGenerator(root->operand(0)), + static_cast(thunks.back().get()))); + } + + // Now build the actual scatter, reading and writing to the freshly + // filled output buffer. + { + thunks.push_back( + BuildKernelThunk(fusion, + /*implements_whole_instruction=*/false)); + // Spin up a new fused emitter for the scatter kernel and emit it. 
+ GpuElementalIrEmitter scatter_elemental_emitter( + hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, + GetNestedComputer()); + FusedIrEmitter scatter_fused_emitter( + GetGeneratorForOperandIrArrays(fusion), + &scatter_elemental_emitter); + TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter)); + TF_RETURN_IF_ERROR(EmitScatter( + thunks.back().get(), root, + /*scatter_indices_gen=*/ + scatter_fused_emitter.GetGenerator(root->operand(1)), + /*updates_gen=*/ + scatter_fused_emitter.GetGenerator(root->operand(2)))); + } + AddThunkToThunkSequence( + absl::make_unique(std::move(thunks), fusion)); + return Status::OK(); + } case HloOpcode::kTuple: case HloOpcode::kReduce: { + // HandleFusion specializes reduction from a multi-dimensional array to + // a 1D array. The specialized version requires a initializer thunk that + // initializes the output array to the initial value of the reduce. if (root->opcode() == HloOpcode::kReduce && ShapeUtil::IsTuple(root->shape())) { // TODO(b/112040122): Support variadic reduce. 
@@ -528,18 +571,13 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { } } CHECK(first_reduce != nullptr); - thunks.push_back( - BuildKernelThunk(fusion, /*implements_whole_instruction=*/false)); - thunk_sequence_->emplace_back( - absl::make_unique(std::move(thunks), fusion)); - std::vector parameter_arrays; - for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand, *fusion)); - } + std::unique_ptr kernel_thunk = + BuildKernelThunk(fusion, /*implements_whole_instruction=*/false); GpuElementalIrEmitter elemental_emitter( hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, GetNestedComputer()); - FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion), + &elemental_emitter); TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); // For multi-output fusion CHECK the constraints and feed all the @@ -586,10 +624,15 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { } } const Shape& input_shape = first_reduce->operand(0)->shape(); - return EmitReductionToVector(first_reduce, input_shape, input_gens, - init_value_gens, - first_reduce->dimensions(), reducers, - reduce_output_shapes, extra_output_gens); + TF_CHECK_OK(EmitReductionToVector( + kernel_thunk.get(), first_reduce, input_shape, input_gens, + init_value_gens, first_reduce->dimensions(), reducers, + reduce_output_shapes, extra_output_gens)); + thunks.push_back(std::move(kernel_thunk)); + std::unique_ptr sequential_thunk = + absl::make_unique(std::move(thunks), fusion); + AddThunkToThunkSequence(std::move(sequential_thunk)); + return Status::OK(); } default: LOG(FATAL) << "Bad opcode for input fusion: " @@ -603,12 +646,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // touching the un-updated elements. // Set up kernel thunk and fused ir emitter. 
- thunk_sequence_->emplace_back( - BuildKernelThunk(fusion, /*implements_whole_instruction=*/true)); - std::vector operand_arrays; - for (HloInstruction* operand : fusion->operands()) { - operand_arrays.push_back(GetIrArray(*operand, *fusion)); - } + std::unique_ptr fusion_thunk = + BuildKernelThunk(fusion, /*implements_whole_instruction=*/true); GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, GetNestedComputer()); @@ -622,18 +661,17 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { LaunchDimensions launch_dimensions = CalculateLaunchDimensions( update_shape, ir_emitter_context_->device_description()); - CHECK(Thunk::Kind::kKernel == LastThunk()->kind()); - UpdateLaunchDimensions(launch_dimensions, - static_cast(LastThunk()), + UpdateLaunchDimensions(launch_dimensions, fusion_thunk.get(), ir_emitter_context_->llvm_module()); + AddThunkToThunkSequence(std::move(fusion_thunk)); return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace( - fusion, operand_arrays, output_array, &elemental_emitter, - launch_dimensions, &b_); + fusion, GetGeneratorForOperandIrArrays(fusion), output_array, + &elemental_emitter, launch_dimensions, &b_); } if (ImplementedAsGemm(*fusion)) { - thunk_sequence_->emplace_back(BuildGemmThunk(fusion)); + AddThunkToThunkSequence(BuildGemmThunk(fusion)); return Status::OK(); } @@ -643,10 +681,6 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { return Status::OK(); } - int unroll_factor = ComputeMaxUnrollFactor(fusion); - - thunk_sequence_->emplace_back(BuildKernelThunk( - fusion, /*implements_whole_instruction=*/true, unroll_factor)); return IrEmitter::HandleFusion(fusion); } @@ -657,7 +691,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { if (LayoutUtil::Equal(copy->operand(0)->shape().layout(), copy->shape().layout()) && buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) { - 
thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy)); + AddThunkToThunkSequence(BuildDeviceToDeviceCopyThunk(copy)); return Status::OK(); } if (CheckAndEmitHloWithTile021(copy)) { @@ -685,7 +719,7 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( } Status IrEmitterUnnested::EmitReductionToScalar( - HloInstruction* reduce, const Shape& input_shape, + KernelThunk* kernel_thunk, HloInstruction* reduce, const Shape& input_shape, absl::Span input_gens, absl::Span init_value_gens, absl::Span reducers, @@ -888,18 +922,16 @@ Status IrEmitterUnnested::EmitReductionToScalar( }; // Emit a parallel loop that iterates through all input tiles, one per thread. - CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); - UpdateLaunchDimensions( - launch_dimensions, - static_cast(LastThunk())->thunks().back().get(), - ir_emitter_context_->llvm_module()); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk, + ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, launch_dimensions, &b_) .EmitLoop(IrName(reduce), index_ty); } Status IrEmitterUnnested::EmitColumnReduction( - int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, + KernelThunk* kernel_thunk, int64 height, int64 width, + HloInstruction* reduce, const Shape& input_shape, absl::Span input_gens, absl::Span init_value_gens, absl::Span reducers, @@ -1151,17 +1183,14 @@ Status IrEmitterUnnested::EmitColumnReduction( }; // Emit a parallel loop that iterate through all input tiles. 
- CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); - UpdateLaunchDimensions( - launch_dimensions, - static_cast(LastThunk())->thunks().back().get(), - ir_emitter_context_->llvm_module()); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk, + ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, launch_dimensions, &b_) .EmitLoop(IrName(reduce), index_ty); } -static std::pair ComputeTilingSchemeForReduction( +static std::pair ComputeKernelMappingSchemeForReduction( int64 depth, int64 width, int64 kWarpSize) { constexpr int64 kTargetNumElementsPerThread = 64; int64 x_tile_size = kTargetNumElementsPerThread; @@ -1186,8 +1215,8 @@ static std::pair ComputeTilingSchemeForReduction( } Status IrEmitterUnnested::EmitRowReduction( - int64 depth, int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, + KernelThunk* kernel_thunk, int64 depth, int64 height, int64 width, + HloInstruction* reduce, const Shape& input_shape, absl::Span input_gens, absl::Span init_value_gens, absl::Span reducers, @@ -1295,7 +1324,7 @@ Status IrEmitterUnnested::EmitRowReduction( int64 x_tile_size; int64 z_tile_size; std::tie(x_tile_size, z_tile_size) = - ComputeTilingSchemeForReduction(depth, width, kWarpSize); + ComputeKernelMappingSchemeForReduction(depth, width, kWarpSize); // Round the width in tiles up to the nearest multiple of kWarpSize, so that // the use of shfl_down is valid. @@ -1522,11 +1551,8 @@ Status IrEmitterUnnested::EmitRowReduction( }; // Emit a parallel loop that iterates through every input tiles. 
- CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); - UpdateLaunchDimensions( - launch_dimensions, - static_cast(LastThunk())->thunks().back().get(), - ir_emitter_context_->llvm_module()); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk, + ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, launch_dimensions, &b_) .EmitLoop(IrName(reduce), index_ty); @@ -1539,7 +1565,7 @@ Status IrEmitterUnnested::EmitRowReduction( // and, if `reduce` is fused, the fused subgraph is pure // elementwise. Status IrEmitterUnnested::EmitReductionToVector( - HloInstruction* reduce, const Shape& input_shape, + KernelThunk* kernel_thunk, HloInstruction* reduce, const Shape& input_shape, absl::Span input_gens, absl::Span init_value_gens, absl::Span dimensions_to_reduce, @@ -1580,8 +1606,8 @@ Status IrEmitterUnnested::EmitReductionToVector( // the dimensions to keep are contiguous, by prerequisite of // `EmitReductionToVector`, we only need to check whether the minormost // dimension of the input is to keep. - if (input_dims_to_keep.empty()) { - return EmitReductionToScalar(reduce, input_shape, input_gens, + if (ShapeUtil::IsEffectiveScalar(reduce->shape())) { + return EmitReductionToScalar(kernel_thunk, reduce, input_shape, input_gens, init_value_gens, reducers, reduce_output_shapes, extra_output_gens); } else if (input_dims_to_keep.front() == @@ -1600,9 +1626,9 @@ Status IrEmitterUnnested::EmitReductionToVector( height *= input_shape.dimensions(input_dim); } } - return EmitColumnReduction(height, width, reduce, input_shape, input_gens, - init_value_gens, reducers, reduce_output_shapes, - extra_output_gens); + return EmitColumnReduction(kernel_thunk, height, width, reduce, input_shape, + input_gens, init_value_gens, reducers, + reduce_output_shapes, extra_output_gens); } else { // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a // 3D tensor. 
The size of dimension 1 (the height) is the size of the @@ -1627,8 +1653,8 @@ Status IrEmitterUnnested::EmitReductionToVector( } } const int64 height = ShapeUtil::ElementsIn(reduce->shape()); - return EmitRowReduction(depth, height, width, reduce, input_shape, - input_gens, init_value_gens, reducers, + return EmitRowReduction(kernel_thunk, depth, height, width, reduce, + input_shape, input_gens, init_value_gens, reducers, reduce_output_shapes, extra_output_gens); } } @@ -1650,28 +1676,40 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { BuildInitializerThunk(reduce)); std::vector> thunks; thunks.push_back(std::move(initializer_thunk)); - thunks.push_back( - BuildKernelThunk(reduce, /*implements_whole_instruction=*/false)); - thunk_sequence_->emplace_back( - absl::make_unique(std::move(thunks), reduce)); + std::unique_ptr kernel_thunk = + BuildKernelThunk(reduce, /*implements_whole_instruction=*/false); - return EmitReductionToVector( - reduce, input->shape(), {[&](const IrArray::Index& index) { + TF_CHECK_OK(EmitReductionToVector( + kernel_thunk.get(), reduce, input->shape(), + {[&](const IrArray::Index& index) { return GetIrArray(*input, *reduce).EmitReadArrayElement(index, &b_); }}, {[&](const IrArray::Index& index) { return GetIrArray(*init_value, *reduce) .EmitReadArrayElement(index, &b_); }}, - dimensions_to_reduce, {reducer}, {{}}, {}); + dimensions_to_reduce, {reducer}, {{}}, {})); + + thunks.push_back(std::move(kernel_thunk)); + + std::unique_ptr sequential_thunk = + absl::make_unique(std::move(thunks), reduce); + AddThunkToThunkSequence(std::move(sequential_thunk)); + return Status::OK(); } - thunk_sequence_->emplace_back( - BuildKernelThunk(reduce, /*implements_whole_instruction=*/true)); return IrEmitter::HandleReduce(reduce); } Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { + // For the root node of the entry computation we can elide writing the tuple + // buffer. 
We can always figure out the contents of the tuples from buffer + // assignment because we insert copies to ensure non-ambiguous output buffers. + // GpuExecutable never reads the tuple buffer. + if (tuple == + tuple->parent()->parent()->entry_computation()->root_instruction()) { + return Status::OK(); + } bool all_tuple_elements_have_buffer = absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { return ir_emitter_context_->buffer_assignment() @@ -1695,11 +1733,11 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { for (const HloInstruction* tuple_element : tuple->operands()) { tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element)); } - thunk_sequence_->emplace_back(absl::make_unique( + AddThunkToThunkSequence(absl::make_unique( tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); return Status::OK(); } - thunk_sequence_->emplace_back( + AddThunkToThunkSequence( BuildKernelThunk(tuple, /*implements_whole_instruction=*/true)); return IrEmitter::HandleTuple(tuple); } @@ -1727,8 +1765,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( thunks.push_back(std::move(initializer_thunk)); thunks.push_back(BuildKernelThunk(select_and_scatter, /*implements_whole_instruction=*/false)); - thunk_sequence_->emplace_back(absl::make_unique( - std::move(thunks), select_and_scatter)); + std::unique_ptr select_and_scatter_thunk = + absl::make_unique(std::move(thunks), select_and_scatter); // TODO(b/31410564): Implement dilation rate for select-and-scatter. if (window_util::HasDilation(window)) { @@ -1894,8 +1932,9 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // consisting of two thunks, an initializer KernelThunk that initializes // the output and another KernelThunk that accumulates the scattered // elements. 
- static_cast(LastThunk())->thunks().back().get(), + select_and_scatter_thunk->thunks().back().get(), ir_emitter_context_->llvm_module()); + AddThunkToThunkSequence(std::move(select_and_scatter_thunk)); return ParallelLoopEmitter(loop_body_emitter, source->shape(), launch_dimensions, &b_) .EmitLoop(IrName(select_and_scatter), index_type); @@ -1909,10 +1948,10 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { // Build ForThunk for conformant while loops, otherwise build WhileThunk. // TODO(b/112163966): Move trip count computation earlier in the pipeline. if (auto loop_trip_count = ComputeWhileLoopTripCount(xla_while)) { - thunk_sequence_->emplace_back(BuildForThunk(xla_while, *loop_trip_count)); + AddThunkToThunkSequence(BuildForThunk(xla_while, *loop_trip_count)); VLOG(3) << "Built ForThunk for while: " << xla_while->name(); } else { - thunk_sequence_->emplace_back(BuildWhileThunk(xla_while)); + AddThunkToThunkSequence(BuildWhileThunk(xla_while)); VLOG(3) << "Built WhileThunk for while: " << xla_while->name(); } return Status::OK(); @@ -1923,79 +1962,257 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) { // // Unroll the kernel so that the duplicated computation that calculates the // 128 bit sample can be optimized away by LLVM. 
- thunk_sequence_->emplace_back( - BuildKernelThunk(rng, /*implements_whole_instruction=*/false, - ComputeMaxUnrollFactor(rng))); + std::unique_ptr rng_thunk = BuildKernelThunk( + rng, /*implements_whole_instruction=*/false, ComputeMaxUnrollFactor(rng)); ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : rng->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { return GetIrArray(*operand, *rng).EmitReadArrayElement(index, &b_); }; } - TF_RETURN_IF_ERROR(EmitTargetElementLoop( - *rng, GpuElementalIrEmitter(hlo_module_config_, module_, &b_, - GetNestedComputer()) - .MakeElementGenerator(rng, operand_to_generator))); - std::unique_ptr rng_thunk = std::move(thunk_sequence_->back()); - thunk_sequence_->pop_back(); + TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk( + *rng, + GpuElementalIrEmitter(hlo_module_config_, module_, &b_, + GetNestedComputer()) + .MakeElementGenerator(rng, operand_to_generator), + rng_thunk.get())); // Emit a kernel to increment the global state for Philox RNG algorithm. - thunk_sequence_->emplace_back( - BuildKernelThunk(rng, /*implements_whole_instruction=*/false)); - llvm_ir::IncrementVariableForPhiloxRngState(1, module_, &b_); std::unique_ptr increment_seed_thunk = - std::move(thunk_sequence_->back()); - thunk_sequence_->pop_back(); + BuildKernelThunk(rng, /*implements_whole_instruction=*/false); + llvm_ir::IncrementVariableForPhiloxRngState(1, module_, &b_); // Build the SequentialThunk for the RNG hlo. 
std::vector> thunks; thunks.reserve(2); thunks.push_back(std::move(rng_thunk)); thunks.push_back(std::move(increment_seed_thunk)); - thunk_sequence_->emplace_back( + AddThunkToThunkSequence( absl::make_unique(std::move(thunks), rng)); return Status::OK(); } +Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { + const HloInstruction* operand = scatter->operand(0); + const HloInstruction* scatter_indices = scatter->operand(1); + const HloInstruction* updates = scatter->operand(2); + + std::vector> thunks; + + // Copy the operand into the output if it's not the same buffer already. + auto operand_buffer = GetAllocationSlice(*operand); + auto destination_buffer = GetAllocationSlice(*scatter); + if (operand_buffer != destination_buffer) { + thunks.push_back(absl::make_unique( + /*source_address=*/operand_buffer, + /*destination_buffer=*/destination_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()), scatter)); + } + + thunks.push_back( + BuildKernelThunk(scatter, + /*implements_whole_instruction=*/thunks.empty())); + + TF_RETURN_IF_ERROR( + EmitScatter(thunks.back().get(), scatter, + /*scatter_indices_gen=*/ + [=](const IrArray::Index& index) { + return GetIrArray(*scatter_indices, *scatter) + .EmitReadArrayElement(index, &b_, "scatter_index"); + }, + /*updates_gen=*/ + [=](const IrArray::Index& index) { + return GetIrArray(*updates, *scatter) + .EmitReadArrayElement(index, &b_, "update"); + })); + + // Elide the sequential thunk if there's no copy. 
+ if (thunks.size() == 1) { + AddThunkToThunkSequence(std::move(thunks[0])); + } else { + AddThunkToThunkSequence( + absl::make_unique(std::move(thunks), scatter)); + } + + return Status::OK(); +} + +Status IrEmitterUnnested::EmitScatter( + Thunk* thunk, HloInstruction* scatter, + const llvm_ir::ElementGenerator& scatter_indices_gen, + const llvm_ir::ElementGenerator& updates_gen) { + const HloInstruction* operand = scatter->operand(0); + const HloInstruction* scatter_indices = scatter->operand(1); + const HloInstruction* updates = scatter->operand(2); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + CHECK(ShapeUtil::Equal(scatter->shape(), operand->shape())); + + auto loop_body_emitter = [&](const IrArray::Index& index) -> Status { + std::vector raw_window_multidim; + std::vector input_scatter_multidim; + std::vector raw_window_bounds; + + // Partition the index into window indices and scatter indices. + for (int64 i = 0, e = index.size(); i != e; ++i) { + // For window indices also remember the window size, this comes in handy + // later. + if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) { + raw_window_multidim.push_back(index[i]); + raw_window_bounds.push_back(updates->shape().dimensions(i)); + } else { + input_scatter_multidim.push_back(index[i]); + } + } + DCHECK_EQ(raw_window_multidim.size(), + dim_numbers.update_window_dims_size()); + + // Apply inserted_window_dims to the window dimensions. + int64 raw_window_multidim_idx = 0; + std::vector input_window_multidim; + std::vector input_window_bounds; + for (int64 i = 0, e = ShapeUtil::Rank(operand->shape()); i != e; ++i) { + if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { + input_window_bounds.push_back(1); // Trivial dimension. 
+ input_window_multidim.push_back(index.GetConstantWithIndexType(0)); + } else { + input_window_bounds.push_back( + raw_window_bounds[raw_window_multidim_idx]); + input_window_multidim.push_back( + raw_window_multidim[raw_window_multidim_idx]); + ++raw_window_multidim_idx; + } + } + DCHECK_EQ(input_window_multidim.size(), ShapeUtil::Rank(operand->shape())); + + // Insert a 1 dimension at the end if index_vector_dim requests one. + Shape scatter_indices_shape = scatter_indices->shape(); + if (dim_numbers.index_vector_dim() == + ShapeUtil::Rank(scatter_indices_shape)) { + scatter_indices_shape.add_dimensions(1); + scatter_indices_shape.mutable_layout()->add_minor_to_major( + dim_numbers.index_vector_dim()); + } + + // Now load the indices corresponding to the current window from + // scatter_indices. + llvm_ir::IrArray::Index raw_scatter_index_index(input_scatter_multidim, + index.GetType()); + raw_scatter_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr); + llvm::Value* is_in_bounds = b_.getTrue(); + for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size(); + i != e; ++i) { + // Our index is stored along index_vector_dim, insert that into the lookup + // index into scatter_indices. + raw_scatter_index_index[dim_numbers.index_vector_dim()] = + raw_scatter_index_index.GetConstantWithIndexType(i); + + int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i); + TF_ASSIGN_OR_RETURN( + llvm::Value* const loaded_scatter_index, + scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape( + scatter_indices_shape, scatter_indices->shape(), &b_))); + // And add the index to our window index. This yields the output index. + llvm::Value* casted_scatter_index = + IntCast(loaded_scatter_index, index.GetType(), + /*isSigned=*/true); + llvm::Value* dim_offset = + Add(input_window_multidim[operand_dim], casted_scatter_index); + input_window_multidim[operand_dim] = dim_offset; + + // Also do the bounds check now. 
+ int64 max_index = operand->shape().dimensions(operand_dim) - + input_window_bounds[operand_dim] + 1; + // is_in_bounds = index >= 0 && index < dim_size-window_size+1 + // --> index u< dim_size-window_size+1 + is_in_bounds = + And(is_in_bounds, ICmpULT(casted_scatter_index, + index.GetConstantWithIndexType(max_index))); + } + + llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse( + is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false); + llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_); + // All done, now just read from the calculated input from the window, and do + // an atomic store to the calculated location in the output. + llvm_ir::IrArray::Index input_window_index(input_window_multidim, + index.GetType()); + HloInstruction* output_hlo = + scatter->IsFused() ? scatter->parent()->FusionInstruction() : scatter; + llvm::Value* output_address = + GetIrArray(*output_hlo, *output_hlo) + .EmitArrayElementAddress(input_window_index, &b_); + llvm::Value* input_address = Alloca(llvm_ir::PrimitiveTypeToIrType( + updates->shape().element_type(), module_)); + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index)); + Store(input_ir_value, input_address); + return EmitAtomicOperationForNestedComputation( + *scatter->to_apply(), output_address, input_address); + }; + + // Launch a kernel that reads every element in the updates tensor. We could + // also do one kernel per window instead if bounds checks turn out to be a + // bottleneck. 
+ LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + updates->shape(), ir_emitter_context_->device_description()); + UpdateLaunchDimensions(launch_dimensions, thunk, + ir_emitter_context_->llvm_module()); + + return ParallelLoopEmitter(loop_body_emitter, updates->shape(), + launch_dimensions, &b_) + .EmitLoop(IrName(scatter), + GetIndexTypeForKernel(scatter, launch_dimensions.launch_bound(), + &b_)); +} + Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { - thunk_sequence_->push_back( - BuildKernelThunk(select, /*implements_whole_instruction=*/true)); return IrEmitter::HandleSelect(select); } Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { std::vector> thunks; - auto keys = sort->operand(0); - auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr; - ShapeIndex keys_shape_index({}); - ShapeIndex values_shape_index({}); - if (values != nullptr) { - keys_shape_index = ShapeIndex({0}); - values_shape_index = ShapeIndex({1}); - } - auto keys_destination = GetAllocationSlice(*sort, keys_shape_index); - auto values_destination = GetAllocationSlice(*sort, values_shape_index); - - if (keys_destination != GetAllocationSlice(*keys)) { - thunks.push_back(absl::make_unique( - /*source_address=*/GetAllocationSlice(*keys), - /*destination_buffer=*/keys_destination, - /*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr)); - } - if (values != nullptr && values_destination != GetAllocationSlice(*values)) { - // TODO(b/26783907): Figure out why we never seem to share buffers for - // key/value sort. - thunks.push_back(absl::make_unique( - /*source_address=*/GetAllocationSlice(*values), - /*destination_buffer=*/values_destination, - /*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr)); + Shape keys_shape = sort->operand(0)->shape(); + int64 dimension_to_sort = sort->dimensions(0); + // In case there is a 'values' parameter that is a iota, we take note and use + // it later to ensure a stable sort. 
Otherwise, we don't guarantee a stable + // sort. + int64 iota_values_parameter_index = -1; + for (int64 i = 0; i < sort->operand_count(); ++i) { + if (i > 0 && sort->operand(i)->opcode() == HloOpcode::kIota && + ShapeUtil::ElementIsIntegral(sort->operand(i)->shape()) && + Cast(sort->operand(i))->iota_dimension() == + dimension_to_sort) { + iota_values_parameter_index = i; + } + ShapeIndex shape_index = + sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); + // We assume that the layout of all involved operands and outputs is the + // same. + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape, + sort->operand(i)->shape())); + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( + keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index))); + + // If possible, we share buffers. If that is not possible, we need to copy + // the values, because the emitter does the sorting in-place. + auto destination_buffer = GetAllocationSlice(*sort, shape_index); + auto source_address = GetAllocationSlice(*sort->operand(i)); + if (destination_buffer != source_address) { + // TODO(b/26783907): Figure out why we never seem to share buffers for + // key/value sort. 
+ thunks.push_back(absl::make_unique( + /*source_address=*/source_address, + /*destination_buffer=*/destination_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()), + nullptr)); + } } - int64 dimension_to_sort = sort->dimensions(0); - int64 dimension_to_sort_bound = keys->shape().dimensions(dimension_to_sort); + uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); - auto index_type = b_.getInt64Ty(); + CHECK_GE(1ULL << num_stages, dimension_to_sort_bound); + CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound); // Naive C++ code for the outer loops: // @@ -2009,41 +2226,128 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { // } // } // - // This follows the algorithm described on Wikipedia: - // https://en.wikipedia.org/wiki/Bitonic_sorter - + // This follows the alternative representation of the algorithm described on + // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter + // + // Each mask specifies how to derive from one position in the array the + // position with which it should be compared (we calculate the xor of the + // position with the mask). + // As an optimization, we can move the 'mask' loop to inside the + // sorting/comparison loop if the comparisons happen within a small block of + // the array. To make this work, we collect all consecutive masks that are + // smaller than our chosen power of 2 tile size, and pass them to SortInPlace. + // Each thread then processes one tile of data. + + const uint64 kTileSize = std::min(2048ULL, 1ULL << num_stages); + + // If we cannot combine several xor masks together, we don't use tiling, so we + // calculate the standard launch dimensions for the shape. However we only + // need to iterate through ~half of the dimension to sort (rounded up to the + // next highest power of 2), because each iteration compares one pair of + // elements. 
+ Shape standard_iteration_shape = keys_shape; + uint64 standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1); + standard_iteration_shape.set_dimensions(dimension_to_sort, + standard_num_iterations_in_sort_dim); + LaunchDimensions standard_launch_dimensions = CalculateLaunchDimensions( + standard_iteration_shape, ir_emitter_context_->device_description()); + + // Calculate the launch dimensions for the case where we use tiling. We split + // the dimension that should be sorted into tiles of size 'kTileSize'. This + // means we first need to round 'dimension_to_sort_bound' up to be a multiple + // of the tile size. + int64 rounded_bound = RoundUpToNearest(dimension_to_sort_bound, kTileSize); + Shape iteration_shape = keys_shape; + + // We iterate through the element pairs that should be compared. + uint64 num_iterations_in_sort_dim = rounded_bound / 2; + iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim); + uint64 num_iterations = ShapeUtil::ElementsIn(iteration_shape); + + // For correctness reasons we need exactly 'kTileSize' / 2 many threads per + // block. Each thread is responsible for copying exactly two adjacent elements + // into shared memory, and then does a comparison of two possibly different + // elements taken from shared memory. + const uint64 kThreadsPerBlock = kTileSize / 2; + + // Check whether we should use any tiling. We might not be able to use it if + // we have not enough threads, or not enough shared memory. Also it does not + // give a speedup if the tile size is < 128. 
+ int64 total_shared_memory_needed = 0; + for (int64 i = 0; i < sort->operand_count(); ++i) { + total_shared_memory_needed += + kTileSize * ShapeUtil::ByteSizeOfPrimitiveType( + sort->operand(i)->shape().element_type()); + } + bool no_tiling = + kTileSize < 128 || + kThreadsPerBlock > + ir_emitter_context_->device_description().threads_per_block_limit() || + total_shared_memory_needed > + ir_emitter_context_->device_description().shared_memory_per_block(); + + uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock); + LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock); + + auto emit_kernel = [&](absl::Span xor_masks) { + thunks.push_back( + BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); + LaunchDimensions launch_dimensions = xor_masks.size() > 1 + ? tiled_launch_dimensions + : standard_launch_dimensions; + UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), + ir_emitter_context_->llvm_module()); + IrArray keys_array; + std::vector values_arrays; + values_arrays.reserve(sort->operand_count() - 1); + for (int64 i = 0; i < sort->operand_count(); ++i) { + ShapeIndex shape_index = + sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); + if (i == 0) { + keys_array = GetIrArray(*sort, *sort, shape_index); + } else { + values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); + } + } + return llvm_ir::EmitSortInPlace( + dimension_to_sort, keys_array, values_arrays, + iota_values_parameter_index, IrName(sort), xor_masks, &b_, + launch_dimensions, + xor_masks.size() > 1 ? 
num_iterations_in_sort_dim + : standard_num_iterations_in_sort_dim, + kTileSize); + }; + std::vector xor_masks; for (int64 stage = 0; stage < num_stages; ++stage) { for (int64 mask = stage; mask >= 0; --mask) { - thunks.push_back( - BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - keys->shape(), ir_emitter_context_->device_description()); - UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), - ir_emitter_context_->llvm_module()); - - llvm::Value* xor_mask; + int64 xor_mask; if (mask == stage) { - xor_mask = llvm::ConstantInt::get(index_type, (1LL << (stage + 1)) - 1); + xor_mask = (1LL << (stage + 1)) - 1; } else { - xor_mask = llvm::ConstantInt::get(index_type, 1LL << mask); + xor_mask = 1LL << mask; + } + if (xor_mask >= kTileSize || no_tiling) { + if (!xor_masks.empty()) { + TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); + xor_masks.clear(); + } + TF_RETURN_IF_ERROR(emit_kernel({xor_mask})); + } else { + xor_masks.push_back(xor_mask); } - - TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace( - dimension_to_sort, GetIrArray(*sort, *sort, keys_shape_index), - values != nullptr ? 
absl::make_optional( - GetIrArray(*sort, *sort, values_shape_index)) - : absl::nullopt, - IrName(sort), xor_mask, &b_, &launch_dimensions)); } } + if (!xor_masks.empty()) { + TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); + } - thunk_sequence_->emplace_back( + AddThunkToThunkSequence( absl::make_unique(std::move(thunks), sort)); return Status::OK(); } Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) { - thunk_sequence_->push_back( + AddThunkToThunkSequence( BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true)); return IrEmitter::HandleTupleSelect(tuple_select); } @@ -2065,7 +2369,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { if (crs->operand_count() == 1) { CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); - thunk_sequence_->push_back(absl::make_unique( + AddThunkToThunkSequence(absl::make_unique( /*source_address=*/GetAllocationSlice(*crs->operand(0)), /*destination_buffer=*/GetAllocationSlice(*crs), /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs)); @@ -2089,22 +2393,22 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { // Output a tuple of the buffers above. 
thunks.push_back(absl::make_unique( tuple_element_buffers, GetAllocationSlice(*crs), nullptr)); - thunk_sequence_->push_back( + AddThunkToThunkSequence( absl::make_unique(std::move(thunks), crs)); return Status::OK(); } -Status IrEmitterUnnested::HandleAfterAll(HloInstruction* gen_token) { +Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) { return Status::OK(); } Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { - thunk_sequence_->emplace_back(BuildInfeedThunk(infeed)); + AddThunkToThunkSequence(BuildInfeedThunk(infeed)); return Status::OK(); } Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) { - thunk_sequence_->emplace_back(BuildOutfeedThunk(outfeed)); + AddThunkToThunkSequence(BuildOutfeedThunk(outfeed)); return Status::OK(); } @@ -2413,28 +2717,43 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( rhs->shape(), // The shape of RHS. inst->shape(), // The shape of the output. 1.0, // alpha. - inst); + 0.0, // beta. + inst, /*implements_whole_instruction=*/true); } if (inst->opcode() == HloOpcode::kFusion) { CHECK_EQ(inst->fusion_kind(), HloInstruction::FusionKind::kOutput); - const HloInstruction* mul = inst->fused_expression_root(); - const HloInstruction* dot = mul->operand(0); - const HloInstruction* alpha = mul->operand(1); - if (dot->opcode() != HloOpcode::kDot) { - std::swap(dot, alpha); - } - if (alpha->opcode() == HloOpcode::kBroadcast) { - alpha = alpha->operand(0); - } - if (alpha->opcode() == HloOpcode::kParameter) { - alpha = inst->operand(alpha->parameter_number()); - } - // TODO(b/74185543): Remove the following if block once we support fusion - // with a non-constant as well. Then we will just always use the constant - // on the device. 
- if (alpha->opcode() == HloOpcode::kCopy) { - alpha = alpha->operand(0); + const HloInstruction* output_fused_op = inst->fused_expression_root(); + + double alpha_value = 1.0; + const HloInstruction* bias = nullptr; + const HloInstruction* dot = output_fused_op->operand(0); + if (output_fused_op->opcode() == HloOpcode::kMultiply) { + const HloInstruction* alpha = output_fused_op->operand(1); + if (dot->opcode() != HloOpcode::kDot) { + std::swap(dot, alpha); + } + if (alpha->opcode() == HloOpcode::kBroadcast) { + alpha = alpha->operand(0); + } + if (alpha->opcode() == HloOpcode::kParameter) { + alpha = inst->operand(alpha->parameter_number()); + } + // TODO(b/74185543): Remove the following if block once we support fusion + // with a non-constant as well. Then we will just always use the constant + // on the device. + if (alpha->opcode() == HloOpcode::kCopy) { + alpha = alpha->operand(0); + } + alpha_value = GetScalarConstantAsDouble(alpha->literal()); + } else { + // Fused bias add. + CHECK_EQ(output_fused_op->opcode(), HloOpcode::kAdd); + bias = output_fused_op->operand(1); + if (dot->opcode() != HloOpcode::kDot) { + std::swap(dot, bias); + } + bias = inst->operand(bias->parameter_number()); } DCHECK(dot->opcode() == HloOpcode::kDot); @@ -2447,15 +2766,38 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( const HloInstruction* rhs = inst->operand(rhs_parameter->parameter_number()); + // The bias is passed inside the output buffer. If those buffers are shared + // we can just use it, otherwise copy the bias values into the output buffer + // first. + if (bias != nullptr && + GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) { + std::vector> thunks; + thunks.push_back(absl::make_unique( + /*source_buffer=*/GetAllocationSlice(*bias), + /*destination_buffer=*/GetAllocationSlice(*inst), + /*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()), nullptr)); + thunks.push_back(absl::make_unique( + GetAllocationSlice(*lhs), // The buffer assigned to LHS. 
+ GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + alpha_value, // alpha. + 1.0, // beta. + inst, /*implements_whole_instruction=*/false)); + return absl::make_unique(std::move(thunks), inst); + } return absl::make_unique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*inst), // The output buffer. - lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - GetScalarConstantAsDouble(alpha->literal()), // alpha. - inst); + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + alpha_value, // alpha. + bias != nullptr ? 1.0 : 0.0, // beta. + inst, /*implements_whole_instruction=*/true); } LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString(); @@ -2564,15 +2906,12 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( if (fused) { // If init_value was fused into this reduce we have to generate it first. 
- std::vector parameter_arrays; - for (HloInstruction* operand : hlo->operands()) { - parameter_arrays.push_back(GetIrArray(*operand, *hlo)); - } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, GetNestedComputer()); - FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo), + &elemental_emitter); TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter)); TF_RETURN_IF_ERROR( ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand), @@ -2777,8 +3116,18 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), &b_)); } - // For multioutput fusion, we need to emit each operand and the root. + // Emit the tuple pointers in one thread. We could do this at any point in + // the kernel, but we do it at the beginning in the hopes of reducing register + // pressure, since we touch threadIdx.x and blockIdx.x at the beginning of the + // kernel *anyway*. std::vector output_arrays = ConstructIrArrayForOutputs(hlo); + TF_RETURN_IF_ERROR( + KernelSupportLibrary(&b_).If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + return Status::OK(); + })); + + // For multioutput fusion, we need to emit each operand and the root. 
TF_RETURN_IF_ERROR( ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, &b_, unroll_factor) @@ -2787,17 +3136,25 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( &hlo, launch_dimensions.launch_bound(), &b_))); b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); - return Status::OK(); } Status IrEmitterUnnested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { - CHECK_EQ(Thunk::Kind::kKernel, LastThunk()->kind()); - return EmitTargetElementLoopInThunk(hlo, element_generator, - static_cast(LastThunk())); + int unroll_factor = 1; + // Unfused elementwise operations are usually memory bound, unroll them. + if (hlo.IsElementwise() || hlo.opcode() == HloOpcode::kFusion) { + unroll_factor = ComputeMaxUnrollFactor(&hlo); + } + + std::unique_ptr kernel_thunk = BuildKernelThunk( + &hlo, /*implements_whole_instruction=*/true, unroll_factor); + Status emit_status = + EmitTargetElementLoopInThunk(hlo, element_generator, kernel_thunk.get()); + thunk_sequence_->emplace_back(std::move(kernel_thunk)); + + return emit_status; } std::vector IrEmitterUnnested::ConstructIrArrayForInputs( @@ -2810,31 +3167,6 @@ std::vector IrEmitterUnnested::ConstructIrArrayForInputs( return param_arrays; } -int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( - const HloInstruction& hlo, const std::vector& output_arrays, - absl::Span reduced_output_dims, - std::vector* output_reduced_shapes, - std::vector* output_in_reduced_shape_arrays) { - int64 num_outputs = 1; - if (hlo.IsMultiOutputFusion()) { - num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); - output_in_reduced_shape_arrays->reserve(num_outputs); - output_reduced_shapes->reserve(num_outputs); - for (int64 i = 0; i < num_outputs; ++i) { - output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( - ShapeUtil::GetSubshape(hlo.shape(), 
{i}).element_type(), - reduced_output_dims)); - output_in_reduced_shape_arrays->push_back( - output_arrays[i].CastToShape((*output_reduced_shapes)[i], &b_)); - } - } else { - output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( - hlo.shape().element_type(), reduced_output_dims)); - output_in_reduced_shape_arrays->push_back( - output_arrays[0].CastToShape((*output_reduced_shapes)[0], &b_)); - } - return num_outputs; -} int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( const HloInstruction& hlo, const std::vector& param_arrays, @@ -2863,338 +3195,531 @@ int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( namespace { -// Reads thread_idx.x and converts it to a (y,x) coordinate, assuming that the -// thread lives within a square tile of size tile_size (so thread blocks are of -// size tile_size * tile_size). -std::tuple CalculateYXCoordinateWithinTile( - llvm::IRBuilder<>* builder, llvm::Value* tile_size, - int64 threads_per_tile) { - // Calculate the starting element coordinate within a tile for the current - // thread, (y, x) from thread_id. - llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, threads_per_tile, - llvm::cast(thread_id)); - thread_id = builder->CreateIntCast(thread_id, tile_size->getType(), - /*isSigned=*/true, "thread.id.x"); - auto x = builder->CreateURem(thread_id, tile_size); - auto y = builder->CreateUDiv(thread_id, tile_size); - return std::make_tuple(y, x); -} - -// Reads block_idx.x, casts it to type index_ty, and adds the assumption that -// it's in the range [0, num_blocks]. 
-llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty, - int64 num_blocks) { - llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, num_blocks, - llvm::cast(block_id)); - return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true, - "block.id.x"); -} - -// Emits code to process up to (tile_size/num_rows) elements in a tile, given -// `emit_elem_function` is the function to emit code to process one element, `y` -// and `x` are the coordinates for the first element to process, and `index` is -// the index for the origin of the tile. Emits bounds check to ensure that each -// processed element is within the boundary defined by `tile_width` and -// `tile_height`. +void EmitFullTile(const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, + llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, + llvm::Type* index_ty, + const std::function& emit_elem_function) { + int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); + int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); + int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); + int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); + for (int64 i = 0; i < tile_size_y; i += num_threads_y) { + IrArray::Index source_idx_y = + tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, i), + KernelMappingScheme::DimY, builder); + llvm::Value* y_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, i), y); + for (int64 j = 0; j < tile_size_x; j += num_threads_x) { + IrArray::Index source_idx = + source_idx_y.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); + emit_elem_function(source_idx, y_loc, x_loc); + } + } +} + +void EmitPartialTile( + const KernelMappingScheme* 
mapping_scheme, + const IrArray::Index& tile_origin_index, const string& loop_name, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, + llvm::Type* index_ty, + const std::function& emit_elem_function) { + int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); + int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); + int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); + + for (int64 j = 0; j < tile_size_x; j += num_threads_x) { + IrArray::Index source_idx = + tile_origin_index.AddOffsetToDim(llvm::ConstantInt::get(index_ty, j), + KernelMappingScheme::DimX, builder); + llvm::Value* x_loc = + builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x); + + ksl->IfReturnVoid( + "x_in_tile", builder->CreateICmpULT(x_loc, tile_width), [&] { + // tile_height_bound = + // ceil(tile_height / num_threads_y) * num_threads_y + llvm::Value* ceiling_of_ratio = builder->CreateUDiv( + builder->CreateAdd(tile_height, llvm::ConstantInt::get( + index_ty, num_threads_y - 1)), + llvm::ConstantInt::get(index_ty, num_threads_y)); + llvm::Value* tile_height_bound = builder->CreateMul( + ceiling_of_ratio, + llvm::ConstantInt::get(index_ty, num_threads_y)); + ksl->ForReturnVoid( + loop_name, /*start=*/llvm::ConstantInt::get(index_ty, 0), + /*end=*/tile_height_bound, + /*step=*/llvm::ConstantInt::get(index_ty, num_threads_y), + [&](llvm::Value* y_indvar) { + llvm::Value* y_loc = builder->CreateAdd(y_indvar, y); + ksl->IfReturnVoid( + "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), + [&] { + emit_elem_function( + source_idx.AddOffsetToDim( + y_indvar, KernelMappingScheme::DimY, builder), + y_loc, x_loc); + }); + }); + }); + } +} + +// Emits code to process up to +// (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile, +// given `emit_elem_function` is the function to emit code to process one +// element, `y` and `x` are the 
intra-tile coordinates for the first element +// to process, and `index` is the index for the origin of the tile. Information +// about tile_size_x/y and num_threads_x/y are stored in `mapping_scheme`. Emits +// bounds check to ensure that each processed element is within the boundary +// defined by `tile_width` and `tile_height`. void EmitTiledElementalCodeWithBoundsCheck( - int64 tile_size, int64 num_rows, const IrArray::Index& index, - const string& loop_name, KernelSupportLibrary* ksl, - llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, - llvm::Value* tile_width, llvm::Value* tile_height, - const std::function& - emit_elem_function) { + const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, const string& loop_name, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, + const std::function& emit_elem_function) { + int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); + int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); llvm::Type* index_ty = tile_width->getType(); - // Emits a constant value with index type. - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - // Adds `addend` to the given `dim` of `index`. 
- auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { - index[dim] = builder->CreateAdd(index[dim], addend); - return index; - }; - - auto emit_full_tile = [&] { - for (int64 i = 0; i < tile_size; i += num_rows) { - auto source_idx = offset_dim(index, index_typed_constant(i), /*dim=*/1); - auto y_loc = builder->CreateAdd(index_typed_constant(i), y); - emit_elem_function(source_idx, y_loc); - } - }; - auto emit_last_row = [&] { - ksl->IfReturnVoid("x_in_tile", builder->CreateICmpULT(x, tile_width), [&] { - // tile_height_upper_bound = - // ceil(tile_height / num_rows) * num_rows - auto tile_height_upper_bound = builder->CreateMul( - builder->CreateUDiv( - builder->CreateAdd(tile_height, - index_typed_constant(num_rows - 1)), - index_typed_constant(num_rows)), - index_typed_constant(num_rows)); - ksl->ForReturnVoid( - loop_name, /*start=*/index_typed_constant(0), - /*end=*/tile_height_upper_bound, - /*step=*/index_typed_constant(num_rows), [&](llvm::Value* y_indvar) { - auto y_loc = builder->CreateAdd(y_indvar, y); - ksl->IfReturnVoid( - "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), [&] { - emit_elem_function(offset_dim(index, y_indvar, /*dim=*/1), - y_loc); - }); - }); - }); - }; ksl->IfReturnVoid( "full_tile", builder->CreateAnd( - builder->CreateICmpEQ(index_typed_constant(tile_size), tile_width), - builder->CreateICmpEQ(index_typed_constant(tile_size), tile_height)), - emit_full_tile, emit_last_row); + builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_x), + tile_width), + builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_y), + tile_height)), + [&] { + EmitFullTile(mapping_scheme, tile_origin_index, builder, y, x, index_ty, + emit_elem_function); + }, + [&] { + EmitPartialTile(mapping_scheme, tile_origin_index, loop_name, ksl, + builder, y, x, tile_height, tile_width, index_ty, + emit_elem_function); + }); } } // namespace -// Emits a kernel for the given hlo instruction using a tiled 0-2-1 
transpose -// algorithm to improve the memory access patterns for the input parameters -// which have a shape that is a 0-2-1 transpose of the output tensors. -// -// For the purpose of tiling, the output tensors have a logical shape of three -// components 0-2-1 while the relevant input parameters have a logical shape of -// three components 0-1-2 in the order major to minor. The x- and y- dimensions -// of the tensors are tiled in square tiles of edge length `kTileSize`. Each -// thread block of `kTileSize` x `kNumRows` threads transposes one tile: each -// thread copies kTileSize/kNumRows elements from the input to a shared memory -// tile, then the otherwise "regular hlo kernel" reads from the shared memory -// instead of the original input. +// Emits code to process a tensor element in a tile for the given kCopy HLO that +// performs a 0-2-1 transpose. // -// This is similar to the following CUDA algorithm in TensorFlow: -// https://goo.gl/MStRV6. -// -// `kTileSize` should usually be same as warp size. We currently choose 32 for -// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. -// -// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient -// to launch fewer blocks so each transposes many tiles. -LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( - HloInstruction* hlo, absl::Span reduced_output_dims, - absl::Span tiled_param_ids) { - // Parameters for the tiling algorithm. - constexpr int64 kTileSize = 32; - constexpr int64 kNumRows = 4; - constexpr int64 kThreadsPerTile = kTileSize * kNumRows; +// index: The index for the first output element in the normalized tensor. The +// normalized tensor is the resulting tensor after collapsing contiguous +// dimensions that play the same role in the transpose. +// y_loc: The y coordinate within a tile. +// x_loc: The x coordinate within a tile. +// kernel_info: Other information to support the kernel code generation. 
+void IrEmitterUnnested::EmitTileElementForCopy( + HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc) { + llvm_ir::TiledParameterInfo* tiled_param_info = + kernel_info->GetTiledParameterInfo(); + // TODO(jlebar): Add AA metadata to this load. + llvm::Instruction* load_from_shmem_buffer = + Load(GEP(tiled_param_info->GetBufferForParameter(0), + {b_.getInt64(0), x_loc, y_loc}), + "output_element"); + llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo); + Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout( + hlo->shape().element_type(), + kernel_info->GetKernelMappingScheme()->GetDimensionsInElements()); + // When the output_reduced_shape is a 0-2-1 transpose of the input shape, + // the 0-2-1 transpose is achieved through EmitWriteArrayElement. + output_array.CastToShape(output_reduced_shape, &b_) + .EmitWriteArrayElement(index, load_from_shmem_buffer, &b_); +} - // Construct IrArrays for the inputs and outputs. +// Emits code to process a tensor element in a tile for the given kLoop fusion +// HLO containing parameters that are 0-2-1 transpose of its outputs. +// +// index: The index for the first output element in the normalized tensor, that +// is the resulting tensor after collapsing contiguous dimensions that play +// the same role in the transpose. +// kernel_info: Other information to support the kernel code generation. +// y_loc: The y coordinate within a tile. +// x_loc: The x coordinate within a tile. 
+void IrEmitterUnnested::EmitTileElementForFusion( + HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc) { + llvm_ir::TiledParameterInfo* tiled_param_info = + kernel_info->GetTiledParameterInfo(); std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); - int64 num_outputs = output_arrays.size(); - std::vector param_arrays = ConstructIrArrayForInputs(*hlo); - int64 num_params = param_arrays.size(); + GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, + GetNestedComputer()); + FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo), + &elem_emitter); + tiled_param_info->set_y(y_loc); + tiled_param_info->set_x(x_loc); + fused_emitter.SetTiledParameterInfo(tiled_param_info); + TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter)); + IrArray::Index untiled_index = + kernel_info->GetKernelMappingScheme()->GetUnnormalizedIndex( + index, output_arrays[0].GetShape()); + const llvm_ir::ElementGenerator& output_generator = + fused_emitter.GetRootGenerator(); + llvm::Value* output_value = output_generator(untiled_index).ValueOrDie(); + if (hlo->IsMultiOutputFusion()) { + DCHECK(output_value->getType()->isStructTy()); + DCHECK_EQ(output_value->getType()->getStructNumElements(), + output_arrays.size()); + for (int64 i = 0; i < output_arrays.size(); ++i) { + output_arrays[i].EmitWriteArrayElement( + untiled_index, ExtractValue(output_value, i), &b_); + } + } else { + output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_); + } +} +// Emits a block of tiles, given a function object to emit one tile. 
+void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, + const KernelCodegenInfo* kernel_info, + KernelSupportLibrary& ksl, + llvm::Type* index_ty) { + KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); + absl::Span dims_in_tile = mapping_scheme->GetDimensionsInTiles(); + absl::Span dims_in_block = + mapping_scheme->GetDimensionsInBlocks(); + absl::Span block_sizes = mapping_scheme->GetBlockSizes(); + auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + // Emit all the tiles for a given dimension in a tile block. + auto emit_tiles_for_block_dim = + [&](const string& loop_name, const IrArray::Index& starting_tile, + int dim_id, + const std::function + emit_next_block_dim) { + if (block_sizes[dim_id] == 1) { + emit_next_block_dim(starting_tile); + } else { + llvm::Value* starting_tile_index_for_dim = starting_tile[dim_id]; + llvm::Value* block_size_for_dim = + index_typed_constant(block_sizes[dim_id]); + llvm::Value* block_id_for_dim = + b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim); + llvm::Value* last_block_for_dim = + index_typed_constant(dims_in_block[dim_id] - 1); + llvm::Value* last_block_size_for_dim = index_typed_constant( + dims_in_tile[dim_id] - + (dims_in_block[dim_id] - 1) * block_sizes[dim_id]); + llvm::Value* num_tiles_in_block = + Select(ICmpEQ(last_block_for_dim, block_id_for_dim), + last_block_size_for_dim, block_size_for_dim); + + ksl.ForReturnVoid( + loop_name, + /*start=*/index_typed_constant(0), + /*end=*/num_tiles_in_block, + /*step=*/1, [&](llvm::Value* block_dim_induction_var) { + IrArray::Index tile_index = starting_tile.AddOffsetToDim( + block_dim_induction_var, dim_id, &b_); + emit_next_block_dim(tile_index); + }); + } + }; + + absl::Span reduced_dims = + mapping_scheme->GetDimensionsInElements(); + const bool block_contains_multi_tiles = + mapping_scheme->GetNumberOfTilesInOneBlock() > 1; + + // Emit the tile with a 
given tile_index, by calculating the tight bounds for + // each dimension of the tile and then calling emit_one_tile. + auto emit_one_tile_for_tile_index = [&](const IrArray::Index& tile_index) { + std::vector output_tile_bounds(3); + for (int i = KernelMappingScheme::DimY; i < KernelMappingScheme::DimTot; + ++i) { + int64 tile_size_for_dim = mapping_scheme->GetTileSizeForDimension(i); + // Only last row or column may not have full size. + llvm::Value* is_last_row = + ICmpEQ(tile_index[i], index_typed_constant(dims_in_tile[i] - 1)); + int64 partial_row_size = + reduced_dims[i] - (dims_in_tile[i] - 1) * tile_size_for_dim; + output_tile_bounds[i] = + Select(is_last_row, index_typed_constant(partial_row_size), + index_typed_constant(tile_size_for_dim), "tile_bound"); + } + + IrArray::Index tile_origin = + mapping_scheme->GetElementIndexForTileOrigin(tile_index); + emit_one_tile(tile_origin, output_tile_bounds, block_contains_multi_tiles); + }; + + const IrArray::Index starting_block = + mapping_scheme->EmitBlockIndex(index_ty); + const IrArray::Index starting_tile_for_dim_z = + mapping_scheme->GetTileIndexForBlockOrigin(starting_block); + + // Emit the three dimensional block of tiles. + emit_tiles_for_block_dim( + "block_dim_z", starting_tile_for_dim_z, KernelMappingScheme::DimZ, + [&](const IrArray::Index& starting_tile_for_dim_y) { + emit_tiles_for_block_dim( + "block_dim_y", starting_tile_for_dim_y, KernelMappingScheme::DimY, + [&](const IrArray::Index& starting_tile_for_dim_x) { + emit_tiles_for_block_dim("block_dim_x", starting_tile_for_dim_x, + KernelMappingScheme::DimX, + emit_one_tile_for_tile_index); + }); + }); +} + +// Emits a kernel for the hlo instruction using the given kernel mapping scheme. +// +// unnested_hlo: The unnested hlo instruction for which the kernel is generated. +// Currently, these hlo instructions are supported: kLoop fusion, kCopy. 
+// tiled_param_ids: The IDs for the parameters that are 0-2-1 transpose of +// other tensors with the same dimensions and need to be tiled and tranposed. +// mapping_scheme: The tiling scheme to use. +// kernel_generator: Contains function objects for code generation, such as +// element generator, block prologue and epilogue generators. +// kernel_info: Represent other information to support the code generation +// of the tiled kernel for the hlo. +LaunchDimensions IrEmitterUnnested::EmitKernel( + HloInstruction* unnested_hlo, absl::Span tiled_param_ids, + const KernelCodeGenerator& kernel_generator, + KernelCodegenInfo* kernel_info) { + KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); + + std::vector param_arrays = ConstructIrArrayForInputs(*unnested_hlo); + int64 num_params = param_arrays.size(); // Allocate shared memory buffers to store the tiled inputs. std::vector param_shmem_buffers(num_params, nullptr); for (int64 id : tiled_param_ids) { - const HloInstruction* param = hlo->operand(id); - // Add 1 to the minor dimension to reduce shared memory bank conflicts. 
- llvm::Type* tile_type = llvm::ArrayType::get( - llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType( - param->shape().element_type(), module_), - kTileSize + 1), - kTileSize); - const int kNVPTXSharedMemoryAddrSpace = 3; - auto* tile_base_ptr = new llvm::GlobalVariable( - *b_.GetInsertBlock()->getParent()->getParent(), tile_type, - /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, - llvm::UndefValue::get(tile_type), - llvm_ir::AsStringRef(IrName(hlo, StrCat("tile", id))), nullptr, - llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace); - param_shmem_buffers[id] = tile_base_ptr; + const HloInstruction* param = unnested_hlo->operand(id); + param_shmem_buffers[id] = + mapping_scheme->GetSharedMemoryBufferForElementType( + llvm_ir::PrimitiveTypeToIrType(param->shape().element_type(), + module_), + IrName(unnested_hlo, StrCat("tile", id))); VLOG(3) << "Added shmem buffer for parameter " << id << ": " - << llvm_ir::DumpToString(*tile_base_ptr); - } - - // The 0-2-1 shape of the tiling scheme is the reduced shape of the HLO result - // for the purpose of tiling. Calculate the logical output dimensions in the - // tile from the reduced output dimensions. 
- std::vector output_dims_in_tiles = std::vector( - reduced_output_dims.begin(), reduced_output_dims.end()); - CHECK_EQ(output_dims_in_tiles.size(), 3); - for (int i = 1; i < 3; ++i) { - output_dims_in_tiles[i] = - CeilOfRatio(output_dims_in_tiles[i], kTileSize); + << llvm_ir::DumpToString(*param_shmem_buffers[id]); } - const int64 num_tiles = - absl::c_accumulate(output_dims_in_tiles, 1, std::multiplies()); - LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile); - llvm::Type* index_ty = - GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_); + CHECK_EQ(mapping_scheme->GetThreadsPerTile() % kWarpSize, 0); + LaunchDimensions launch_dimensions = LaunchDimensions( + mapping_scheme->GetNumberOfBlocks(), mapping_scheme->GetThreadsPerTile()); + llvm::Type* index_ty = GetIndexTypeForKernel( + unnested_hlo, launch_dimensions.launch_bound(), &b_); auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { return llvm::ConstantInt::get(index_ty, c); }; - // Cast each output IrArray to its corresponding reduced shape and keep the - // reduced shape live during IR emission. - std::vector output_in_reduced_shape_arrays; - std::vector output_reduced_shapes; - CHECK_EQ(ConstructOutputReducedShapeAndCastOutputIrArrayToShape( - *hlo, output_arrays, reduced_output_dims, &output_reduced_shapes, - &output_in_reduced_shape_arrays), - num_outputs); + // For multioutput fusion, one thread needs to output a tuple with pointers to + // all the individual outputs. We could do this at any point in the kernel, + // but we do it at the beginning in the hopes of reducing register pressure, + // since we touch threadIdx.x and blockIdx.x at the beginning of the kernel + // *anyway*. 
+ if (unnested_hlo->IsMultiOutputFusion()) { + TF_CHECK_OK(KernelSupportLibrary(&b_).If( + "emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), + ConstructIrArrayForOutputs(*unnested_hlo), &b_, + module_); + return Status::OK(); + })); + } // For each tiled parameter, cast its input IrArray to the corresponding // reduced shape and keep the reduced shape live during IR emission. std::vector param_in_reduced_shape_arrays; std::vector param_reduced_shapes; - CHECK_EQ(ConstructInputReducedShapeAndCastInputIrArrayToShape( - *hlo, param_arrays, param_shmem_buffers, reduced_output_dims, - ¶m_reduced_shapes, ¶m_in_reduced_shape_arrays), - num_params); + absl::Span reduced_dims = + mapping_scheme->GetDimensionsInElements(); + int num_shapes = ConstructInputReducedShapeAndCastInputIrArrayToShape( + *unnested_hlo, param_arrays, param_shmem_buffers, reduced_dims, + ¶m_reduced_shapes, ¶m_in_reduced_shape_arrays); + DCHECK_EQ(num_shapes, num_params); // Calculate the starting element coordinate within a tile for the current // thread, (y, x) from thread_id. llvm::Value* x; llvm::Value* y; - std::tie(y, x) = CalculateYXCoordinateWithinTile( - &b_, index_typed_constant(kTileSize), kThreadsPerTile); - - // Calculate the index for the current output tile from block_id. - const IrArray::Index output_tile_index( - GetBlockIdx(&b_, index_ty, num_tiles), - ShapeUtil::MakeShapeWithDescendingLayout(PRED /*arbitrary*/, - output_dims_in_tiles), - &b_); - - // Output tile origin is the index for the first element of the current output - // tile. - const IrArray::Index output_tile_origin = [&] { - IrArray::Index index = output_tile_index; - for (int i = 1; i < 3; ++i) { - index[i] = Mul(output_tile_index[i], index_typed_constant(kTileSize), - "tile_origin." 
+ std::to_string(i)); - } - return index; - }(); + std::tie(y, x) = mapping_scheme->EmitThreadYXCoordinate(index_ty); - // Calculate the input tile origin from the output tile origin. - const IrArray::Index input_tile_origin( - Permute({0, 2, 1}, output_tile_origin.multidim())); - - // Calculate the current output tile bounds in each of the logical dimensions. - std::vector output_tile_bounds(3); - for (int i = 1; i < 3; ++i) { - // Only last row or column may not have full size. - output_tile_bounds[i] = - Select(ICmpEQ(output_tile_index[i], - index_typed_constant(output_dims_in_tiles[i] - 1)), - index_typed_constant(reduced_output_dims[i] - - (output_dims_in_tiles[i] - 1) * kTileSize), - index_typed_constant(kTileSize), "kTileSize"); - } + kernel_info->SetLaneId( + mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x + : nullptr); KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); - // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck. auto emit_tiled_elemental_code_with_bounds_check = [&](const IrArray::Index& index, const string& loop_name, - llvm::Value* tile_width, llvm::Value* tile_height, - const std::function& - emit_elem_function) { - EmitTiledElementalCodeWithBoundsCheck( - kTileSize, kNumRows, index, loop_name, &ksl, &b_, y, x, tile_width, - tile_height, emit_elem_function); + llvm::Value* tile_height, llvm::Value* tile_width, + const std::function& emit_elem_function) { + EmitTiledElementalCodeWithBoundsCheck(mapping_scheme, index, loop_name, + &ksl, &b_, y, x, tile_height, + tile_width, emit_elem_function); }; - // Adds `addend` to the given `dim` of `index`. 
- auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { - index[dim] = Add(index[dim], addend); - return index; - }; - const IrArray::Index input_index = - offset_dim(offset_dim(input_tile_origin, x, /*dim=*/2), y, /*dim=*/1); - - // Copy input parameter values to shared memory buffers: - // tile[y, x] = input[index] - emit_tiled_elemental_code_with_bounds_check( - input_index, "input", output_tile_bounds[1], output_tile_bounds[2], - [&](const IrArray::Index& index, llvm::Value* y_loc) { - for (int64 id : tiled_param_ids) { - IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; - llvm::Value* shmem_buffer = param_shmem_buffers[id]; - // TODO(jlebar): Add AA metadata to this store. Tile buffers are - // global variables, so LLVM can't infer much about it. - Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, - "input_element"), - GEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); - } - }); + auto emit_one_tile = [&](const IrArray::Index& output_tile_origin, + absl::Span output_tile_bounds, + bool block_contains_multi_tiles) { + // Calculate the input tile origin from the output tile origin. + const IrArray::Index input_tile_origin( + Permute({0, 2, 1}, output_tile_origin.multidim())); + + const IrArray::Index input_index = + input_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) + .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); + + // Copy input parameter values to shared memory buffers: + // tile[y, x] = input[index] + // Note that tile_width and tile_height are flipped here because we are + // reading a transposed tile. 
+ emit_tiled_elemental_code_with_bounds_check( + input_index, "input", output_tile_bounds[2], output_tile_bounds[1], + [&](const IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + for (int64 id : tiled_param_ids) { + IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id]; + llvm::Value* shmem_buffer = param_shmem_buffers[id]; + // TODO(jlebar): Add AA metadata to this store. Tile buffers are + // global variables, so LLVM can't infer much about it. + Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, + "input_element"), + GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc})); + } + }); - // Wait for all threads to reach this point, lest we copy a value from tile to - // output before the other thread copies it from input to tile. - // This is `__syncthreads` in CUDA. - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + // If shared memory transpose is needed, wait for all threads to reach this + // point, lest we copy a value from tile to output before the other thread + // copies it from input to tile. This is `__syncthreads` in CUDA. + if (!tiled_param_ids.empty()) { + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + } - llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); + llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); + kernel_info->SetTiledParamInfo(&tiled_param_info); - const IrArray::Index output_index = - offset_dim(offset_dim(output_tile_origin, x, /*dim=*/2), y, /*dim=*/1); + const IrArray::Index output_index = + output_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_) + .AddOffsetToDim(y, KernelMappingScheme::DimY, &b_); - // Write to output[index] by emitting code like normal, except that values for - // the tiled parameters are read from the shmem buffers. 
- if (hlo->opcode() == HloOpcode::kCopy) { - emit_tiled_elemental_code_with_bounds_check( - output_index, "output", output_tile_bounds[2], output_tile_bounds[1], - [&](const IrArray::Index& index, llvm::Value* y_loc) { - // TODO(jlebar): Add AA metadata to this load. - llvm::Instruction* load_from_shmem_buffer = - Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), - "output_element"); - output_in_reduced_shape_arrays[0].EmitWriteArrayElement( - index, load_from_shmem_buffer, &b_); - }); - } else { - CHECK_EQ(hlo->opcode(), HloOpcode::kFusion); + // Write to output[index] by emitting code like normal, except that values + // for the tiled parameters are read from the shmem buffers. emit_tiled_elemental_code_with_bounds_check( - output_index, "output", output_tile_bounds[2], output_tile_bounds[1], - [&](const IrArray::Index& index, llvm::Value* y_loc) { - GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, - GetNestedComputer()); - FusedIrEmitter fused_emitter(param_arrays, &elem_emitter); - tiled_param_info.set_y(y_loc); - fused_emitter.SetTiledParameterInfo(&tiled_param_info); - TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter)); - IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex( - index, output_reduced_shapes[0], output_arrays[0].GetShape(), - &b_); - const llvm_ir::ElementGenerator& output_generator = - fused_emitter.GetRootGenerator(); - llvm::Value* output_value = - output_generator(untiled_index).ValueOrDie(); - if (hlo->IsMultiOutputFusion()) { - CHECK(output_value->getType()->isStructTy()); - CHECK_EQ(output_value->getType()->getStructNumElements(), - output_in_reduced_shape_arrays.size()); - for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) { - output_in_reduced_shape_arrays[i].EmitWriteArrayElement( - index, ExtractValue(output_value, i), &b_); - } - } else { - output_in_reduced_shape_arrays[0].EmitWriteArrayElement( - index, output_value, &b_); - } + output_index, "output", 
output_tile_bounds[1], output_tile_bounds[2], + [&](const IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc) { + kernel_generator.GetTileElementGenerator()(unnested_hlo, index, + kernel_info, y_loc, x_loc); }); + // If a tile block contains multiple tiles and shared memory buffers are + // used, we need to wait for all threads to finish using the shared memory + // buffer for the current tile before we move on to process the next tile + // and overwrite the shared memory buffers. + if (block_contains_multi_tiles && !tiled_param_ids.empty()) { + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_); + } + }; + + const BlockPrologueGenerator& block_prologue_generator = + kernel_generator.GetBlockPrologueGenerator(); + if (block_prologue_generator) { + block_prologue_generator(unnested_hlo, kernel_info); } - // For multioutput fusion, emit a tuple with all the individual outputs. - if (hlo->IsMultiOutputFusion()) { - llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), output_arrays, &b_, module_); + EmitBlock(std::move(emit_one_tile), kernel_info, ksl, index_ty); + + const BlockEpilogueGenerator& block_epilogue_generator = + kernel_generator.GetBlockEpilogueGenerator(); + if (block_epilogue_generator) { + block_epilogue_generator(unnested_hlo, kernel_info); } return launch_dimensions; } +// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose +// algorithm to improve the memory access patterns for the input parameters +// with a shape that is a 0-2-1 transpose of the output tensor shape. +// +// For the purpose of tiling, the output tensors have a logical shape of three +// components 0-2-1 while the relevant input parameters have a logical shape +// of three components 0-1-2 in the order major to minor. The x- and y- +// dimensions of the tensors are tiled in square tiles with an edge length +// `kTileSize`. 
Each thread block of `kTileSize` x `kNumRows` threads +// transposes one tile: each thread copies kTileSize/kNumRows elements from +// the input to a shared memory tile, then the otherwise "regular HLO kernel" +// reads from the shared memory instead of the original input. +// +// This is similar to the following CUDA algorithm in TensorFlow: +// https://goo.gl/MStRV6. +// +// `kTileSize` should usually be same as warp size. We currently choose 32 for +// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. +// +// TODO(b/33320379): Here each block transposes 1 tile. It may be more +// efficient to launch fewer blocks so each transposes many tiles. +LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( + HloInstruction* hlo, absl::Span reduced_output_dims, + absl::Span tiled_param_ids) { + constexpr int kNumRows = 4; + KernelMappingScheme mapping_scheme( + reduced_output_dims, /*tile_size_y=*/kWarpSize, + /*tile_size_x=*/kWarpSize, /*req_block_sizes=*/{1, 1, 1}, + /*num_threads_y=*/kNumRows, + /*num_threads_x=*/kWarpSize, &b_); + TileElementGenerator element_generator; + if (hlo->opcode() == HloOpcode::kCopy) { + element_generator = [&](HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc) { + EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc); + }; + } else { + DCHECK_EQ(hlo->opcode(), HloOpcode::kFusion); + element_generator = [&](HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc) { + EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc); + }; + } + KernelCodegenInfo kernel_info(&mapping_scheme); + KernelCodeGenerator kernel_generator(std::move(element_generator)); + return EmitKernel(hlo, tiled_param_ids, kernel_generator, &kernel_info); +} + +namespace { +// Returns true to indicate it is safe to use the tile based shared memory +// transpose 
implementation to implement the kernel for the instruction. +// +// An instruction is not safe for such an implementation if it can change the +// element order of a tensor without changing the dimension of the tensor, and +// the instruction has a corresponding elemental_ir_emitter. +bool IsInstructionSafeForTileBasedTranspose(const HloInstruction* hlo) { + auto is_safe_for_tile_based_transpose = [&](const HloInstruction* instr) { + HloOpcode opcode = instr->opcode(); + CHECK_NE(opcode, HloOpcode::kFusion); + return (opcode != HloOpcode::kReverse && opcode != HloOpcode::kGather); + }; + + if (hlo->opcode() == HloOpcode::kFusion) { + return absl::c_all_of(hlo->fused_instructions_computation()->instructions(), + is_safe_for_tile_based_transpose); + } + + return is_safe_for_tile_based_transpose(hlo); +} +} // namespace + bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { HloOpcode opcode = hlo->opcode(); CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy); @@ -3206,8 +3731,8 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { ? ShapeUtil::GetSubshape(hlo->shape(), {0}) : hlo->shape(); - // If the output_shape is reduced to 021 shape, find all the parameters of the - // hlo that are in the corresponding 012 shape. + // If the output_shape is reduced to 021 shape, find all the parameters of + // the HLO that are in the corresponding 012 shape. std::vector params_012; optional> reduced_dims_021; for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); @@ -3239,10 +3764,14 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { return false; } + if (!IsInstructionSafeForTileBasedTranspose(hlo)) { + return false; + } + // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the - // elements are of size 4 bytes), and CUDA has an architectural limit of 48kb - // shared memory per SM. 
(This is increased to 96kb in Volta, but we don't - // use this, in part because it eats into our L1 cache space.) + // elements are of size 4 bytes), and CUDA has an architectural limit of + // 48kb shared memory per SM. (This is increased to 96kb in Volta, but we + // don't use this, in part because it eats into our L1 cache space.) // // For correctness we need to ensure that we don't make more than 48kb worth // of shmem tiles per block. And for performance, we'd probably like to use @@ -3250,9 +3779,9 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { // gpu core. // // We say without benchmarks that we want at least 3 threads/block, - // corresponding to 3 shmem tiles if the elements are 32 bits wide. We choose - // which params get the shmem transpose treatment arbitrarily; it's not clear - // if there's a Right Choice. + // corresponding to 3 shmem tiles if the elements are 32 bits wide. We + // choose which params get the shmem transpose treatment arbitrarily; it's + // not clear if there's a Right Choice. // // This is only sound if tiled transposes are the only place where we use // shared memory in fusions. 
If in the future other fusible ops use shared @@ -3274,12 +3803,13 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { } VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString(); - thunk_sequence_->emplace_back( - BuildKernelThunk(hlo, /*implements_whole_instruction=*/true)); + std::unique_ptr kernel_thunk = + BuildKernelThunk(hlo, /*implements_whole_instruction=*/true); const LaunchDimensions launch_dimensions = EmitHlo021Tile(hlo, *reduced_dims_021, params_012); - UpdateLaunchDimensions(launch_dimensions, LastThunk(), + UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), ir_emitter_context_->llvm_module()); + AddThunkToThunkSequence(std::move(kernel_thunk)); return true; } @@ -3305,10 +3835,10 @@ Status IrEmitterUnnested::EmitConstantGlobals() { } // These globals will be looked up by name by GpuExecutable so we need to - // give them an external linkage. Not all of their uses are visible in the - // LLVM IR (e.g. TupleThunk) so we can't give then a linkage that merely - // preserves their names (like available_externally), we also need to ensure - // that they stick around even if they're "unused". + // give them an external linkage. Not all of their uses are visible in + // the LLVM IR (e.g. TupleThunk) so we can't give then a linkage that + // merely preserves their names (like available_externally), we also need + // to ensure that they stick around even if they're "unused". // // We may have to be more more clever here in the future if we notice that // we're keeping around too many globals because of their linkage. 
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index bd5db7205155dc6b15ddea069e172bbd8f419996..e09ed657a812be6ab4859a0e365a51c45a37bfed 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -17,7 +17,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h" namespace xla { @@ -46,6 +48,94 @@ namespace gpu { // class IrEmitterUnnested : public IrEmitter { public: + // Parameter block_contains_multi_tiles indicates whether a tile block + // consists of multiple tiles or not. If the tile block contains only one + // tile, there is no need to use atomic operation to accumulate a local result + // to a global result to implement reduction. + using TileGenerator = + std::function output_tile_bounds, + bool block_contains_multi_tiles)>; + // KernelCodegenInfo records the common information to support the code + // generation for a kernel to process tensor elements by blocks. A block of + // tensor elements may contain one or multiple tiles. The code generators that + // generate code for tile elements or block prologue/epilogue refer to this + // class in their prototypes. If the implementations of such code generators + // require other information that are specific to the HLO instructions, the + // implementations need to define and use derived classes of this class. 
+ class KernelCodegenInfo { + public: + explicit KernelCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme) + : mapping_scheme_(mapping_scheme), + tiled_param_info_(nullptr), + lane_id_(nullptr) {} + + void SetLaneId(llvm::Value* v) { lane_id_ = v; } + void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) { + CHECK_EQ(tiled_param_info_, nullptr); + tiled_param_info_ = tiled_param_info; + } + + llvm::Value* GetLaneId() const { return lane_id_; } + llvm_ir::KernelMappingScheme* GetKernelMappingScheme() const { + return mapping_scheme_; + } + llvm_ir::TiledParameterInfo* GetTiledParameterInfo() const { + return tiled_param_info_; + } + + private: + llvm_ir::KernelMappingScheme* mapping_scheme_; + llvm_ir::TiledParameterInfo* tiled_param_info_; + llvm::Value* lane_id_; + }; + + // A function object to prepare for the code generation for a tile block. + using BlockPrologueGenerator = + std::function; + // A function object to finalize the code generation for a tile block. + using BlockEpilogueGenerator = + std::function; + // A function object to generate code to process one element in a tile. + // + // hlo: the instruction for which the code is generated for. + // index: the index for the first output element of the current thread. + // y_loc: The y coordinate within a tile. + // x_loc: The x coordinate within a tile. + // kernel_info: Other information to support the kernel code generation. + using TileElementGenerator = std::function; + + // KernelCodeGenerator records the code generator objects that generate code + // for tile elements or tile block prologue/epilogue. 
+ class KernelCodeGenerator { + public: + explicit KernelCodeGenerator( + TileElementGenerator tile_element_generator, + BlockPrologueGenerator block_prologue_generator = {}, + BlockEpilogueGenerator block_epilogue_generator = {}) + : tile_element_generator_(std::move(tile_element_generator)), + block_prologue_generator_(std::move(block_prologue_generator)), + block_epilogue_generator_(std::move(block_epilogue_generator)) {} + + const TileElementGenerator& GetTileElementGenerator() const { + return tile_element_generator_; + } + const BlockPrologueGenerator& GetBlockPrologueGenerator() const { + return block_prologue_generator_; + } + const BlockEpilogueGenerator& GetBlockEpilogueGenerator() const { + return block_epilogue_generator_; + } + + private: + TileElementGenerator tile_element_generator_; + BlockPrologueGenerator block_prologue_generator_; + BlockEpilogueGenerator block_epilogue_generator_; + }; + IrEmitterUnnested(const HloModuleConfig& hlo_module_config, const HloComputation* hlo_computation, IrEmitterContext* ir_emitter_context); @@ -76,11 +166,12 @@ class IrEmitterUnnested : public IrEmitter { Status HandleInfeed(HloInstruction* xla_infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleRng(HloInstruction* random) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; - Status HandleAfterAll(HloInstruction* gen_token) override; + Status HandleAfterAll(HloInstruction* after_all) override; Status EmitTargetElementLoop( const HloInstruction& hlo, @@ -96,10 +187,10 @@ class IrEmitterUnnested : public IrEmitter { Status EmitConstantGlobals(); private: - // Builds the appropriate thunk for the instruction hlo and returns the owning - // pointer to it. 
The caller needs to make sure `inst` outlives the lifetime - // of the returned Thunk object. - std::unique_ptr BuildThunk(const HloInstruction* hlo); + // Add a owning Thunk object to the thunk sequence. + void AddThunkToThunkSequence(std::unique_ptr thunk) { + thunk_sequence_->emplace_back(std::move(thunk)); + } // Builds the prototype of the IR kernel for `inst` and adds it to the module. // This kernel takes as arguments pointers to the given buffer allocations. @@ -124,8 +215,8 @@ class IrEmitterUnnested : public IrEmitter { // [height x width], but can be bitcast to [height x width] with "height" // being the major dimension. Status EmitColumnReduction( - int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, + KernelThunk* kernel_thunk, int64 height, int64 width, + HloInstruction* reduce, const Shape& input_shape, absl::Span input_gens, absl::Span init_value_gens, absl::Span reducers, @@ -139,8 +230,8 @@ class IrEmitterUnnested : public IrEmitter { // [depth x height x width], but can be bitcast to [depth x height x width] // with "depth" being the most major dimension. Status EmitRowReduction( - int64 depth, int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, + KernelThunk* kernel_thunk, int64 depth, int64 height, int64 width, + HloInstruction* reduce, const Shape& input_shape, absl::Span input_gens, absl::Span init_value_gens, absl::Span reducers, @@ -150,7 +241,8 @@ class IrEmitterUnnested : public IrEmitter { // Emits code that reduces a tensor of arbitrary rank to a scalar. 
Status EmitReductionToScalar( - HloInstruction* reduce, const Shape& input_shape, + KernelThunk* kernel_thunk, HloInstruction* reduce, + const Shape& input_shape, absl::Span input_gens, absl::Span init_value_gens, absl::Span reducers, @@ -175,7 +267,8 @@ class IrEmitterUnnested : public IrEmitter { // // Prerequisite: `IsReductionToVector(*reduce)` Status EmitReductionToVector( - HloInstruction* reduce, const Shape& input_shape, + KernelThunk* kernel_thunk, HloInstruction* reduce, + const Shape& input_shape, absl::Span input_gens, absl::Span init_value_gens, absl::Span dimensions_to_reduce, @@ -184,6 +277,14 @@ class IrEmitterUnnested : public IrEmitter { absl::Span> extra_output_gens); + // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in + // the process. `scatter` may be fused, scatter indices are taken from + // `scatter_indices_gen`, updates from`updates_gen`. The output buffer is + // expected to have the operand values in it already. + Status EmitScatter(Thunk* thunk, HloInstruction* scatter, + const llvm_ir::ElementGenerator& scatter_indices_gen, + const llvm_ir::ElementGenerator& updates_gen); + // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel // for the hlo instruction. bool CheckAndEmitHloWithTile021(HloInstruction* hlo); @@ -193,22 +294,32 @@ class IrEmitterUnnested : public IrEmitter { LaunchDimensions EmitHlo021Tile(HloInstruction* hlo, absl::Span reduced_output_dims, absl::Span tiled_param_ids); + // Emits a kernel for an unnested HLO instruction. + LaunchDimensions EmitKernel(HloInstruction* unnested_hlo, + absl::Span param_ids, + const KernelCodeGenerator& kernel_generator, + KernelCodegenInfo* kernel_info); + void EmitBlock(const TileGenerator& emit_one_tile, + const KernelCodegenInfo* kernel_info, + KernelSupportLibrary& ksl, llvm::Type* index_ty); + // Emits code to process a tensor element in a tile for the given kCopy HLO + // that performs a 0-2-1 transpose. 
+ void EmitTileElementForCopy(HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc); + // Emits code to process a tensor element in a tile for the given kLoop fusion + // HLO containing parameters that are 0-2-1 transpose of its outputs. + void EmitTileElementForFusion(HloInstruction* hlo, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + llvm::Value* y_loc, llvm::Value* x_loc); // Generates the IrArray for each input of an hlo and returns a vector that // constains such IrArrays. std::vector ConstructIrArrayForInputs( const HloInstruction& hlo); - // For each output of the `hlo` instruction, constructs the reduced shape for - // the output with the given `reduced_output_dims` and cast the original - // output IrArray element in `output_arrays` to the reduced shape. Returns - // the number of outputs. - int ConstructOutputReducedShapeAndCastOutputIrArrayToShape( - const HloInstruction& hlo, - const std::vector& output_arrays, - absl::Span reduced_output_dims, - std::vector* output_reduced_shapes, - std::vector* output_in_reduced_shape_arrays); // For each input of the `hlo` instruction, checks its value in // `param_buffers` to find out whether the input has a reduced shape. 
If the // input has a reduced shape, constructs the reduced shape for the input and diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index 8751e3a9c2a4c8da46d3ecd8437629450d4a2ba2..bd53b90b42d8e657a3ee58e7ca03fb60522aae28 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -177,13 +177,6 @@ std::unique_ptr GetTargetMachine( } TargetOptions target_options = InitTargetOptionsFromCodeGenFlags(); - llvm_ir::SetTargetOptions( - /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_gpu_enable_fast_math(), - &target_options); - - // Enable FMA synthesis. - target_options.AllowFPOpFusion = FPOpFusion::Fast; // Set the verbose assembly options. target_options.MCOptions.AsmVerbose = false; @@ -206,8 +199,7 @@ std::unique_ptr GetTargetMachine( } return absl::WrapUnique(target->createTargetMachine( triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, - Optional(RelocModel), Optional(CMModel), - codegen_opt_level)); + getRelocModel(), getCodeModel(), codegen_opt_level)); } // Adds the standard LLVM optimization passes, based on the speed optimization @@ -401,8 +393,16 @@ StatusOr CompileModuleToPtx(llvm::Module* module, int32 opt_level = hlo_module_config.debug_options().xla_backend_optimization_level(); - CHECK_GE(opt_level, 2) - << "The XLA GPU backend doesn't support unoptimized code generation"; + if (opt_level < 2) { + LOG(ERROR) << std::string(80, '*'); + LOG(ERROR) << "The XLA GPU backend doesn't support unoptimized code " + "generation but "; + LOG(ERROR) << "--xla_backend_optimization_level is set to " << opt_level + << "!"; + LOG(ERROR) << "(Supported configuration is " + "--xla_backend_optimization_level >= 2.)"; + LOG(ERROR) << std::string(80, '*'); + } AddOptimizationPasses(opt_level, /*size_level=*/0, 
target_machine.get(), &module_passes, @@ -453,18 +453,21 @@ void GPUBackendInit(const HloModuleConfig& hlo_module_config) { // * 3-6 gives similar results as 2; // * >6 start hurting the performance of at least dot product kernels. // - // TODO(jingyue): The current threshold only considers the numbr of IR + // TODO(jingyue): The current threshold only considers the number of IR // instructions which do not accurately reflect the true cost. We need a // better cost model. FeedLLVMWithFlags({"-bonus-inst-threshold=2"}); - // TODO(b/22073864): Increase limit when scan memory dependency. - // This helps to reduce more redundant load instructions. + // Increase limit when scanning memory dependencies. This helps to reduce + // more redundant load instructions. // // The specific value is currently large enough for s3d in shoc benchmark, // which contains a lot of load instructions and many arithmetic instructions // between those loads. FeedLLVMWithFlags({"-memdep-block-scan-limit=500"}); + // Use div.approx -- it matters for some float-division heavy benchmarks. 
+ FeedLLVMWithFlags({"-nvptx-prec-divf32=0"}); + llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config); // Initialize the NVPTX target; it's the only target we link with, so call its diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 835924024b7b7de79624a369a69b07d72ac751ab..01fddcede64d1bb02ab89db5fc9524893c2d47a4 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -41,50 +41,7 @@ GpuMultiOutputFusion::GpuMultiOutputFusion() : MultiOutputFusion(INT64_MAX) {} bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, HloInstruction* instr2) { - auto get_element_instr = - [&](const HloInstruction* instr) -> const HloInstruction* { - const HloInstruction* element_instr = instr; - if (instr->opcode() == HloOpcode::kFusion) { - auto fused_expression_root = instr->fused_expression_root(); - if (instr->IsMultiOutputFusion()) { - // If possible, we want to pick a reduce operand of the fusion root, - // because it has the most constraints. - for (const auto* inst : fused_expression_root->operands()) { - if (IsReductionToVector(*inst)) { - return inst; - } - } - return fused_expression_root->operands()[0]; - } else { - element_instr = fused_expression_root; - } - } - return element_instr; - }; - - auto get_element_shape = [&](const HloInstruction* element_instr) { - // Special handling of kReduce instructions -- the fusion - // applies to the first operand. - if (IsReductionToVector(*element_instr)) { - return element_instr->operand(0)->shape(); - } - return element_instr->shape(); - }; - - // The shapes in all tuple operands should agree, unless it is a reduce. - // In that case, the operand of the reduce needs to have the same shape - // as the other tuple operands, but also we need to compare the output - // shapes of the reduces. 
- auto* element_instr_1 = get_element_instr(instr1); - auto* element_instr_2 = get_element_instr(instr2); - if (element_instr_1->opcode() == HloOpcode::kReduce && - element_instr_2->opcode() == HloOpcode::kReduce && - !ShapeUtil::Equal(element_instr_1->shape(), element_instr_2->shape())) { - return false; - } - // The elementwise output shapes must be the same (including layout). - return ShapeUtil::EqualIgnoringFpPrecision( - get_element_shape(element_instr_1), get_element_shape(element_instr_2)); + return ShapesCompatibleForMultiOutputFusion(*instr1, *instr2); } bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { @@ -140,6 +97,18 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, return false; } + // The emitter only supports in-place DUS for fusions with a single DUS at the + // root. Don't sibling fuse DUS for now. + // TODO(b/119178699): Multi-output fusing DUS can improve performance if we + // share the input and output buffers and add support to the emitter. + if (instr1->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice || + (instr2->opcode() == HloOpcode::kFusion && + instr2->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice)) { + return false; + } + // Do this check last, as it may be expensive. return !GpuInstructionFusion::FusionWouldBeTooLarge(instr1, instr2); } @@ -180,6 +149,12 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { VLOG(3) << producer->name() << " is not fusible."; continue; } + // Never multi-output fuse constants. To the extent that we want to fuse + // constants, that should be handled by the regular fusion pass. 
+ if (producer->opcode() == HloOpcode::kConstant) { + VLOG(3) << producer->name() << " is a constant."; + continue; + } const bool is_loop_fusion = producer->opcode() == HloOpcode::kFusion && producer->fusion_kind() == HloInstruction::FusionKind::kLoop; @@ -187,7 +162,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { VLOG(3) << producer->name() << " is not a loop fusion."; continue; } - if (!ShapesCompatibleForFusion(producer, consumer)) { + if (!ShapesCompatibleForMultiOutputFusion(*producer, *consumer)) { VLOG(3) << producer->name() << " has an incompatible shape."; continue; } diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 8a6e5327e082791ff857a89e840c6a4f045f0edb..d16c87ba5c63aa582753fe949e9e39ee2d8b81e5 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -505,7 +505,7 @@ TEST_F(MultiOutputFusionTest, p1.1 = f16[2,2,2]{2,1,0} parameter(1) c0 = f16[] constant(0) broadcast = f16[2,2,2]{2,1,0} broadcast(f16[] c0), dimensions={} - greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + greater-than = pred[2,2,2]{2,1,0} greater-than(f16[2,2,2]{2,1,0} p1.1, f16[2,2,2]{2,1,0} broadcast) p0.1 = f16[2,2,2]{2,1,0} parameter(0) ROOT select = f16[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f16[2,2,2]{2,1,0} p0.1, f16[2,2,2]{2,1,0} broadcast) } @@ -580,7 +580,7 @@ TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) { // ... // where each of the (pi * pj)'s is represented as a fusion node so that // multi-output fusion will pay attention to it. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {10, 100}); @@ -621,5 +621,39 @@ TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) { } } +TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) { + auto module = ParseHloString(R"(HloModule dus_mof + fusion.1 { + p.0 = f16[50,96,1024]{2,1,0} parameter(0) + p.1 = s32[1]{0} parameter(1) + p.2 = f16[1,96,1024]{2,1,0} parameter(2) + c.0 = s32[] constant(0) + pad = s32[3]{0} pad(p.1, c.0), padding=0_2 + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + } + + fusion.2 { + p.0 = f16[50,96,1024]{2,1,0} parameter(0) + p.1 = s32[1]{0} parameter(1) + p.2 = f16[1,96,1024]{2,1,0} parameter(2) + c.0 = s32[] constant(0) + pad = s32[3]{0} pad(p.1, c.0), padding=0_2 + ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + } + + ENTRY entry { + p.00 = f16[50,96,1024]{2,1,0} parameter(0) + p.01 = f16[50,96,1024]{2,1,0} parameter(1) + p.1 = s32[1]{0} parameter(2) + p.2 = f16[1,96,1024]{2,1,0} parameter(3) + + f1 = f16[50,96,1024] fusion(p.00, p.1, p.2), kind=kLoop, calls=fusion.1 + f2 = f16[50,96,1024] fusion(p.01, p.1, p.2), kind=kLoop, calls=fusion.2 + ROOT tuple = (f16[50,96,1024],f16[50,96,1024]) tuple(f1, f2) + })") + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index b4ae2e42c7c34774b86d0bf69eef4dba390c0cc5..f3e17d888242a36c268dcbfa0d6530f80cedceb0 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -38,9 +38,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" @@ -54,18 +56,18 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" -#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" -#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" @@ -75,7 +77,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" -#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -128,6 +129,7 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { << potential_libdevice_dir; } + LOG(WARNING) << "Unable to find libdevice dir. 
Using '.'"; // Last resort: maybe in the current folder. return "."; } @@ -172,15 +174,16 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); + pipeline.AddPass(); + // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. pipeline.AddPass(); - pipeline.AddPass(); - - pass.AddPass( - /*is_layout_sensitive=*/false, + AlgebraicSimplifierOptions options( [](const Shape&, const Shape&) { return false; }); + options.set_enable_permutation_sort_replacement(true); + pass.AddPass(options); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -204,21 +207,22 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { // Convert convolutions into CustomCalls to cudnn, then canonicalize them - // (PadInsertion). + // (CudnnConvPaddingLegalization). HloPassPipeline pipeline("conv_canonicalization"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); if (IsVoltaOrLater(*stream_exec)) { - pipeline.AddPass(); - // PadForTensorCores leaves behind unnecessary tuple/get-tuple-element - // pairs that TupleSimplifier fixes. + pipeline.AddPass(); + // CudnnConvPadForTensorCores leaves behind unnecessary + // tuple/get-tuple-element pairs that TupleSimplifier fixes. pipeline.AddPass(); } - // CudnnConvolutionRewriter, PadInsertion and PadForTensorCores may add - // instructions which can be simplified by constant folding. + // CudnnConvRewriter, CudnnConvPaddingLegalization and + // CudnnConvPadForTensorCores may add instructions which can be simplified + // by constant folding. 
pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -239,21 +243,27 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { HloPassPipeline pipeline("post-layout_assignment"); - pipeline.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after + * fixing the ticket. */ + pipeline.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. - pipeline.AddPass>( - /*is_layout_sensitive=*/true, + AlgebraicSimplifierOptions options( /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { return true; }); + options.set_is_layout_sensitive(true); + options.set_enable_permutation_sort_replacement(true); + pipeline.AddPass>(options); // Choose the fastest algorithm for each conv. // // We pick the algorithm before fusion so we can generate better HLO. After - // CudnnConvolutionRewriter, our convolutions are CustomCalls which return a + // CudnnConvRewriter, our convolutions are CustomCalls which return a // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of // scratch: // @@ -271,12 +281,12 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // The new tuple and gte instructions then be simplified away, because // nobody is expected to use the scratch value. // - // However, if we were to run CudnnConvolutionAlgorithmPicker after fusion + // However, if we were to run CudnnConvAlgorithmPicker after fusion // the gte(customcall, 0) would probably already be into a fusion node. We // can't simplify across HloComputation boundaries, so in this case we // wouldn't be able to simplify away the new_tuple bits. 
- pipeline.AddPass( - stream_exec, device_allocator, compiler); + pipeline.AddPass(stream_exec, device_allocator, + compiler); // Clean up new_tuple described above. pipeline.AddPass(); @@ -286,8 +296,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { HloPassFix fusion("fusion"); - fusion.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + // We try to split variadic ops with many parameters into several such ops + // to avoid exceeding the parameter space. + fusion.AddPass(); + /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after + * fixing the ticket. */ + fusion.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); @@ -298,8 +315,11 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); + /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after + * fixing the ticket. */ reduce_pipeline.AddInvariantChecker( - /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false); + /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -325,8 +345,12 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after + * fixing the ticket. 
*/ + pipeline.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -401,7 +425,7 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { "prefers >= 9.2.88). Compilation of XLA kernels below will likely " "fail.\n\nYou do not need to update CUDA; cherry-picking the ptxas " "binary is sufficient."; - } else if ((vmaj < 9 || vmin < 2 || vdot < 88)) { + } else if (std::make_tuple(vmaj, vmin, vdot) < std::make_tuple(9, 2, 88)) { LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "." << vdot @@ -455,13 +479,15 @@ void WarnIfBadDriverJITVersion() { // Compiles the given PTX string using ptxas and returns the resulting machine // code (i.e. a cubin) as a byte array. StatusOr> CompilePtx(const string& ptx, int cc_major, - int cc_minor) { + int cc_minor, + bool disable_ptx_optimizations) { tracing::ScopedActivity activity("Compile PTX", /*is_expensive=*/true); const string ptxas_path = tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); - VLOG(2) << "Using ptxas at " << ptxas_path; + VLOG(2) << "Checking ptxas at " << ptxas_path; auto env = tensorflow::Env::Default(); TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); + VLOG(2) << "Using ptxas at " << ptxas_path; WarnIfBadPtxasVersion(ptxas_path); @@ -494,6 +520,9 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } + if (disable_ptx_optimizations) { + ptxas_args.push_back("-O0"); + } ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); @@ -527,14 +556,17 @@ StatusOr> NVPTXCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { // We dump the 
post-optimization HLO in RunBackend so no need to dump it here. - VLOG(2) << "*** HLO Before Optimization"; - XLA_VLOG_LINES(2, module->ToString()); + VLOG(3) << "*** HLO Before Optimization"; + XLA_VLOG_LINES(3, module->ToString()); XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses"); tracing::ScopedActivity activity("HLO Transforms", module->name(), /*is_expensive=*/true); TF_RETURN_IF_ERROR( OptimizeHloModule(module.get(), stream_exec, device_allocator, this)); + + TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); + return std::move(module); } @@ -545,8 +577,6 @@ StatusOr> NVPTXCompiler::RunBackend( TF_RET_CHECK(stream_exec != nullptr); - TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); - llvm::LLVMContext llvm_context; std::string buffer; llvm::raw_string_ostream error(buffer); @@ -586,8 +616,8 @@ StatusOr> NVPTXCompiler::RunBackend( // include headers, so no need for us to print them ourselves. XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); XLA_VLOG_LINES(2, buffer_assignment->ToString()); - VLOG(2) << "*** HLO After Optimization"; - XLA_VLOG_LINES(2, module->ToString()); + VLOG(3) << "*** HLO After Optimization"; + XLA_VLOG_LINES(3, module->ToString()); const string xla_dump_optimized_hlo_proto_to = module->config().debug_options().xla_dump_optimized_hlo_proto_to(); if (!xla_dump_optimized_hlo_proto_to.empty()) { @@ -617,10 +647,10 @@ StatusOr> NVPTXCompiler::RunBackend( string ir_module_string_before_opt; const bool embed_ir_in_executable = module->config().debug_options().xla_embed_ir_in_executable(); - if (VLOG_IS_ON(2) || embed_ir_in_executable) { + if (VLOG_IS_ON(3) || embed_ir_in_executable) { ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module); - VLOG(2) << "LLVM module before optimizations:"; - XLA_VLOG_LINES(2, ir_module_string_before_opt); + VLOG(3) << "LLVM module before optimizations:"; + XLA_VLOG_LINES(3, ir_module_string_before_opt); } const string& ir_dump_directory = @@ -664,6 
+694,8 @@ StatusOr> NVPTXCompiler::RunBackend( } libdevice_dir = cached_libdevice_dir_; } + VLOG(2) << "Libdevice dir = " << libdevice_dir << "\n"; + int cc_major, cc_minor; if (!stream_exec->GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor)) { @@ -690,10 +722,10 @@ StatusOr> NVPTXCompiler::RunBackend( if (user_post_optimization_hook_) { TF_CHECK_OK(user_post_optimization_hook_(llvm_module)); } - VLOG(2) << "LLVM module after optimizations:"; - XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(llvm_module)); - VLOG(2) << "PTX:"; - XLA_VLOG_LINES(2, ptx); + VLOG(3) << "LLVM module after optimizations:"; + XLA_VLOG_LINES(3, llvm_ir::DumpModuleToString(llvm_module)); + VLOG(3) << "PTX:"; + XLA_VLOG_LINES(3, ptx); // Write PTX to IR dump directory, if IR dumping was requested. if (!ir_dump_directory.empty()) { @@ -711,14 +743,15 @@ StatusOr> NVPTXCompiler::RunBackend( } } - const std::vector cubin = - CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); + const std::vector cubin = CompilePtxOrGetCachedResult( + ptx, cc_major, cc_minor, + module->config().debug_options().xla_gpu_disable_ptxas_optimizations()); auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), hlo_schedule->ThunkLaunchOrder()); - VLOG(2) << "Printing the thunk schedule..."; - XLA_VLOG_LINES(2, thunk_schedule->ToString()); + VLOG(3) << "Printing the thunk schedule..."; + XLA_VLOG_LINES(3, thunk_schedule->ToString()); std::unique_ptr profile_index_map; std::unique_ptr profile_printer; @@ -729,8 +762,8 @@ StatusOr> NVPTXCompiler::RunBackend( stream_exec->GetDeviceDescription().memory_bandwidth()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); profile_index_map = absl::make_unique(*module); - profile_printer = - CreateHloProfilePrinterData(*profile_index_map, cost_analysis); + profile_printer = CreateHloProfilePrinterData( + *profile_index_map, cost_analysis, entry_computation->name()); } auto* gpu_executable = 
new GpuExecutable( @@ -744,9 +777,9 @@ StatusOr> NVPTXCompiler::RunBackend( return std::unique_ptr(gpu_executable); } -std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, - int cc_major, - int cc_minor) { +std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( + const string& ptx, int cc_major, int cc_minor, + bool disable_ptx_optimizations) { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompilePtxOrGetCachedResult"); tracing::ScopedActivity activity("PTX->CUBIN", /*is_expensive=*/true); bool inserted; @@ -774,8 +807,8 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, if (inserted) { CHECK(!cache_value->compilation_done); if (!ptx.empty()) { - StatusOr> maybe_cubin = - CompilePtx(*cache_ptx, cc_major, cc_minor); + StatusOr> maybe_cubin = CompilePtx( + *cache_ptx, cc_major, cc_minor, disable_ptx_optimizations); if (maybe_cubin.ok()) { cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); VLOG(2) << "Compiled PTX size:" << ptx.size() @@ -788,7 +821,7 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, // binaries are not available. We don't want to spam logs with // identical warnings in this case. - // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N + // TODO(jlebar): we should implement a LOG_FIRST_N and LOG_EVERY_N // for more general usage. 
static std::atomic warning_done(false); log_warning = !warning_done.exchange(true); @@ -820,9 +853,8 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, } StatusOr>> -NVPTXCompiler::CompileAheadOfTime( - std::vector> module, - const AotCompilationOptions& options) { +NVPTXCompiler::CompileAheadOfTime(std::unique_ptr module_group, + const AotCompilationOptions& options) { return Unimplemented( "not yet implemented: NVPTXCompiler::CompileAheadOfTime"); } diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index c4a0b727cd3d9ae0af61c1752c1608cd4fb65d2d..be5e31a50112686841e6f18b76f382a56e61bafc 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -59,7 +59,7 @@ class NVPTXCompiler : public LLVMCompiler { DeviceMemoryAllocator* device_allocator) override; StatusOr>> - CompileAheadOfTime(std::vector> module, + CompileAheadOfTime(std::unique_ptr module_group, AotCompilationOptions const& options) override; se::Platform::Id PlatformId() const override; @@ -97,8 +97,9 @@ class NVPTXCompiler : public LLVMCompiler { // Tries to compile the given ptx string to cubin. Returns a vector with the // compiled cubin. If compilation was unsuccessful, returns an empty vector. - std::vector CompilePtxOrGetCachedResult(const string& ptx, - int cc_major, int cc_minor); + std::vector CompilePtxOrGetCachedResult( + const string& ptx, int cc_major, int cc_minor, + bool disable_ptx_optimizations); // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} // -> cubin so we don't recompile the same ptx twice. 
This is important for diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc deleted file mode 100644 index e3869b5c368957571219a39600214140022a7318..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc +++ /dev/null @@ -1,253 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/window_util.h" - -namespace xla { -namespace gpu { - -// We want the input/output feature counts of an f16 conv to be factors of 8, -// because without this cudnn can't use tensor cores on the conv. -static constexpr int64 kDesiredNumFeaturesFactor = 8; - -// We won't pad a conv if doing so increases the total number of bytes in the -// lhs, rhs, or result by more than this amount. -// -// TODO(jlebar): This number was tuned experimentally. It represents a -// compromise on our current benchmarks; it speeds some up significantly, and -// doesn't slow any down. 
But we can observe by changing this value that -// there's additional room for speedups. Achieving those speedups without also -// slowing other things down will likely require a more sophisticated heuristic, -// possibly some form of auto-tuning. -// -// This value should be >= 4/3, otherwise the "dims of size 3 padded up to 4" -// special case inside PadShape won't fire. -static constexpr double kMaxBytesTouchedIncrease = 1.35; - -// Pads the given dimensions in the given shape up to a multiple of -// kDesiredNumFeaturesFactor. -static Shape PadShape(Shape s, absl::Span dims) { - for (int64 dim : dims) { - int64 dim_to_pad_size = s.dimensions(dim); - - // Round dim_to_pad_size up to the next multiple of - // kDesiredNumFeaturesFactor. - // - // Special case: dims of size 3 are rounded up to 4, not - // kDesiredNumFeaturesFactor. Empirically (and on the advice of nvidia), - // this helps, but as of writing, it's not supported by anything in the - // cudnn docs. - int64 new_dim_to_pad_size; - if (dim_to_pad_size == 3) { - new_dim_to_pad_size = 4; - } else { - new_dim_to_pad_size = - RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor); - } - - s.set_dimensions(dim, new_dim_to_pad_size); - } - return s; -} - -// Creates and returns an HLO that zero-pads one or more dimensions in the given -// instruction so that its shape is equal to the given shape. -// -// Padding is added to the end of each relevant dimension. -// -// If the instruction already has the given shape, simply returns it without an -// intervening pad. 
-static HloInstruction* PadInstruction(HloInstruction* instr, - const Shape& new_shape) { - HloComputation* comp = instr->parent(); - - const Shape& shape = instr->shape(); - auto* zero = comp->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); - - PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); - - bool added_padding = false; - for (int64 dim = 0; dim < ShapeUtil::Rank(shape); ++dim) { - if (shape.dimensions(dim) == new_shape.dimensions(dim)) { - continue; - } - CHECK_GT(new_shape.dimensions(dim), shape.dimensions(dim)); - pad_config.mutable_dimensions(dim)->set_edge_padding_high( - new_shape.dimensions(dim) - shape.dimensions(dim)); - added_padding = true; - } - - if (!added_padding) { - return instr; - } - return comp->AddInstruction( - HloInstruction::CreatePad(new_shape, instr, zero, pad_config)); -} - -// Pads the input/output feature dimensions of the given cudnn convolution -// custom-call to be multiples of kDesiredNumFeaturesFactor. -static StatusOr PadFeaturesDims(HloInstruction* conv) { - CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0)) - << "conv must use 0 scratch bytes, i.e. this pass must be run " - "before CudnnConvolutionAlgorithmPicker."; - - const auto& target = conv->custom_call_target(); - const auto& dnums = conv->convolution_dimension_numbers(); - auto* lhs = conv->mutable_operand(0); - auto* rhs = conv->mutable_operand(1); - const Shape& result_shape = conv->shape().tuple_shapes(0); - - Shape new_lhs_shape = [&] { - if (target == kCudnnConvForwardCallTarget || - target == kCudnnConvBackwardFilterCallTarget) { - // LHS is "input". - return PadShape(lhs->shape(), {dnums.input_feature_dimension()}); - } - CHECK_EQ(target, kCudnnConvBackwardInputCallTarget); - // LHS is "output". 
- return PadShape(lhs->shape(), {dnums.output_feature_dimension()}); - }(); - - Shape new_rhs_shape = [&] { - if (target == kCudnnConvForwardCallTarget || - target == kCudnnConvBackwardInputCallTarget) { - // RHS is "filter". - return PadShape(rhs->shape(), {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}); - } - CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); - // RHS is "output". - return PadShape(rhs->shape(), {dnums.output_feature_dimension()}); - }(); - - if (ShapeUtil::Equal(lhs->shape(), new_lhs_shape) && - ShapeUtil::Equal(rhs->shape(), new_rhs_shape)) { - VLOG(3) << "No need to pad features of " << conv->ToString(); - return false; - } - - Shape new_result_shape = [&] { - if (target == kCudnnConvForwardCallTarget) { - // Result is "output". - return PadShape(result_shape, {dnums.output_feature_dimension()}); - } - if (target == kCudnnConvBackwardInputCallTarget) { - // Result is "input". - return PadShape(result_shape, {dnums.input_feature_dimension()}); - } - CHECK_EQ(target, kCudnnConvBackwardFilterCallTarget); - // Result is "filter". - return PadShape(result_shape, {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}); - }(); - - // Check that padding wouldn't increase the total bytes read/written by this - // operation too much. 
- auto check_size_increase = [&](const Shape& old_shape, - const Shape& new_shape) { - int64 old_bytes = ShapeUtil::ByteSizeOf(old_shape); - int64 new_bytes = ShapeUtil::ByteSizeOf(new_shape); - if (new_bytes <= old_bytes * kMaxBytesTouchedIncrease) { - return true; - } - VLOG(3) << "Not padding convolution; doing so would change input / result " - "shape from " - << ShapeUtil::HumanString(old_shape) << " to " - << ShapeUtil::HumanString(new_shape) << ", a size increase of " - << new_bytes / static_cast(old_bytes) << "x > " - << kMaxBytesTouchedIncrease << "x: " << conv->ToString(); - return false; - }; - if (!check_size_increase(lhs->shape(), new_lhs_shape) || - !check_size_increase(rhs->shape(), new_rhs_shape) || - !check_size_increase(result_shape, new_result_shape)) { - return false; - } - - // OK, let's do the transformation! - - auto* new_lhs = PadInstruction(lhs, new_lhs_shape); - auto* new_rhs = PadInstruction(rhs, new_rhs_shape); - CHECK(new_lhs != lhs || new_rhs != rhs) - << "We should have had to pad either LHS or RHS."; - - auto add = [&](std::unique_ptr new_instr) { - return conv->parent()->AddInstruction(std::move(new_instr)); - }; - - Shape new_conv_shape = ShapeUtil::MakeTupleShape( - {new_result_shape, ShapeUtil::MakeShape(U8, {0})}); - auto* new_conv = - add(conv->CloneWithNewOperands(new_conv_shape, {new_lhs, new_rhs})); - - // Slice the new conv result if necessary, keeping in mind that new_conv has - // tuple shape (new_result_shape, u8[0]). 
- if (!ShapeUtil::Equal(result_shape, new_result_shape)) { - std::vector start_indices(result_shape.dimensions_size(), 0); - std::vector end_indices(result_shape.dimensions().begin(), - result_shape.dimensions().end()); - std::vector strides(result_shape.dimensions_size(), 1); - - auto* new_conv_result = add( - HloInstruction::CreateGetTupleElement(new_result_shape, new_conv, 0)); - auto* empty_temp_buffer = - add(HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); - auto* sliced_result = add(HloInstruction::CreateSlice( - result_shape, new_conv_result, start_indices, end_indices, strides)); - new_conv = - add(HloInstruction::CreateTuple({sliced_result, empty_temp_buffer})); - } - - VLOG(2) << "Padded features of " << conv->ToString() << ", replaced with " - << new_conv->ToString(); - TF_RETURN_IF_ERROR(conv->parent()->ReplaceInstruction(conv, new_conv)); - return true; -} - -static std::vector GetRelevantConvs(HloComputation* comp) { - std::vector convs; - for (HloInstruction* instr : comp->instructions()) { - if (IsCustomCallToDnnConvolution(*instr) && - instr->operand(0)->shape().element_type() == F16 && - // TODO(timshen): Disable for fused conv for now. Implement it if it's - // needed. 
- Cast(instr)->custom_call_target() != - kCudnnConvBiasActivationForwardCallTarget) { - convs.push_back(instr); - } - } - return convs; -} - -StatusOr PadForTensorCores::Run(HloModule* module) { - bool changed = false; - for (HloComputation* comp : module->MakeNonfusionComputations()) { - for (HloInstruction* conv : GetRelevantConvs(comp)) { - TF_ASSIGN_OR_RETURN(bool result, PadFeaturesDims(conv)); - changed |= result; - } - } - return changed; -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h deleted file mode 100644 index e592a3774ec28605fda912298c74ca7976ff99ac..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ - -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" - -namespace xla { -namespace gpu { - -// Ensures that f16 cudnn convolutions have input/output channel dimensions that -// are multiples of 8, inserting pads/slices as necessary. 
-// -// This is useful primarily for Volta and newer GPUs, where tensor cores can -// only be used if the channel dims are multiples of 8. It's probably the -// opposite of useful on other GPUs, so you should check what GPU you're -// targeting before running this pass. -// -// TODO(jlebar): Also pad dots. -class PadForTensorCores : public HloModulePass { - public: - absl::string_view name() const override { return "pad for tensor cores"; } - - StatusOr Run(HloModule* module) override; -}; - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc deleted file mode 100644 index 5c92b0dcb873b873074704dca8f27d4067b070df..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc +++ /dev/null @@ -1,164 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h" - -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" -#include "tensorflow/compiler/xla/util.h" - -namespace xla { -namespace gpu { -namespace { - -namespace op = xla::testing::opcode_matchers; -using ::testing::_; - -class PadForTensorCoresTest : public HloVerifiedTestBase {}; - -TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { - ParseAndVerifyModule(R"( - HloModule TestModule - - ENTRY TestComputation { - input = f16[10,20,30,41] parameter(0) - filter = f16[2,2,41,40] parameter(1) - ROOT result = (f16[10,20,30,40], u8[0]) custom-call(input, filter), - window={size=2x2}, dim_labels=b01f_01io->b01f, - custom_call_target="__cudnn$convForward" - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); - - SCOPED_TRACE(module().ToString()); - EXPECT_THAT(root, op::CustomCall(kCudnnConvForwardCallTarget, - op::Pad(op::Parameter(0), _), - op::Pad(op::Parameter(1), _))); - EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), - ShapeUtil::MakeShape(F16, {10, 20, 30, 48}))); - EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), - ShapeUtil::MakeShape(F16, {2, 2, 48, 40}))); -} - -TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) { - ParseAndVerifyModule(R"( - HloModule TestModule - - ENTRY TestComputation { - output = f16[10,20,30,41] parameter(0) - filter = f16[2,2,40,41] parameter(1) - ROOT result = (f16[10,20,30,40], u8[0]) custom-call(output, filter), - window={size=2x2}, dim_labels=b01f_01io->b01f, - custom_call_target="__cudnn$convBackwardInput" - 
})"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); - EXPECT_THAT(root, op::CustomCall(kCudnnConvBackwardInputCallTarget, - op::Pad(op::Parameter(0), _), - op::Pad(op::Parameter(1), _))); - EXPECT_TRUE(ShapeUtil::Equal(root->operand(0)->shape(), - ShapeUtil::MakeShape(F16, {10, 20, 30, 48}))); - EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->shape(), - ShapeUtil::MakeShape(F16, {2, 2, 40, 48}))); -} - -TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) { - ParseAndVerifyModule(R"( - HloModule TestModule - - ENTRY TestComputation { - input = f16[10,20,30,40] parameter(0) - filter = f16[2,2,40,41] parameter(1) - ROOT result = (f16[10,20,30,41], u8[0]) custom-call(input, filter), - window={size=2x2}, dim_labels=b01f_01io->b01f, - custom_call_target="__cudnn$convForward" - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Tuple(op::Slice(op::GetTupleElement(op::CustomCall( - kCudnnConvForwardCallTarget, op::Parameter(0), - op::Pad(op::Parameter(1), _)))), - _)); -} - -TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) { - ParseAndVerifyModule(R"( - HloModule TestModule - - ENTRY TestComputation { - output = f16[10,20,30,40] parameter(0) - filter = f16[2,2,41,40] parameter(1) - result = (f16[10,20,30,41], u8[0]) custom-call(output, filter), - window={size=2x2}, dim_labels=b01f_01io->b01f, - custom_call_target="__cudnn$convBackwardInput" - ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0 - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); - EXPECT_THAT(root, op::GetTupleElement(op::Tuple( - op::Slice(op::GetTupleElement(op::CustomCall( - kCudnnConvBackwardInputCallTarget, op::Parameter(0), - op::Pad(op::Parameter(1), _)))), - _))); -} - 
-TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) { - ParseAndVerifyModule(R"( - HloModule TestModule - - ENTRY TestComputation { - input = f16[10,20,30,41] parameter(0) - output = f16[10,20,30,40] parameter(1) - result = (f16[2,2,41,40], u8[0]) custom-call(input, output), - window={size=2x2}, dim_labels=b01f_01io->b01f, - custom_call_target="__cudnn$convBackwardFilter" - ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0 - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); - EXPECT_THAT(root, op::GetTupleElement(op::Tuple( - op::Slice(op::GetTupleElement(op::CustomCall( - kCudnnConvBackwardFilterCallTarget, - op::Pad(op::Parameter(0), _), op::Parameter(1)))), - _))); -} - -TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) { - ParseAndVerifyModule(R"( - HloModule TestModule - - ENTRY TestComputation { - input = f16[10,20,30,40] parameter(0) - output = f16[10,20,30,41] parameter(1) - result = (f16[2,2,40,41], u8[0]) custom-call(input, output), - window={size=2x2}, dim_labels=b01f_01io->b01f, - custom_call_target="__cudnn$convBackwardFilter" - ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0 - })"); - EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie()); - auto* root = module().entry_computation()->root_instruction(); - EXPECT_THAT(root, op::GetTupleElement(op::Tuple( - op::Slice(op::GetTupleElement(op::CustomCall( - kCudnnConvBackwardFilterCallTarget, - op::Parameter(0), op::Pad(op::Parameter(1), _)))), - _))); -} - -} // anonymous namespace -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc deleted file mode 100644 index b42a19e3a2200e917f8040be183b8d79c9e4e161..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ /dev/null @@ -1,414 +0,0 @@ -/* Copyright 
2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" - -#include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/window_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { -namespace gpu { - -namespace { -bool IsForwardConvolutionCanonical(const HloInstruction& conv) { - CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget || - conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget); - return window_util::HasSymmetricPadding(conv.window()) && - !window_util::HasNegativePadding(conv.window()) && - !window_util::HasDilation(conv.window()); -} - -// If the (positive and negative) padding on the input operand of a convolution -// can't be folded into a cuDNN convolution libcall (e.g. uneven padding and -// dilation), returns kPad and/or kSlice instructions that explicitly apply the -// padding; otherwise returns the original input operand. 
When there is both -// positive padding (including dilation) and negative padding, we insert both -// kPad and kSlice. -HloInstruction* MaybePaddedAndSlicedInput( - const Window& conv_window, const ConvolutionDimensionNumbers& conv_dnums, - HloInstruction* input) { - HloComputation* computation = input->parent(); - if (!window_util::HasSymmetricPadding(conv_window) || - window_util::HasBaseDilation(conv_window)) { - // If padding is uneven or has dilation, we insert a kPad instruction that - // applies positive padding and dilation. - // - // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of - // moving all the padding into an explicit pad op, we should keep as much - // padding inside of cudnn as possible, on the assumption that padding - // within cudnn is basically free, whereas a kPad's cost increases as the - // amount of padding increases. - PaddingConfig padding_config = - MakeNoPaddingConfig(input->shape().dimensions_size()); - for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { - int64 dim = conv_dnums.input_spatial_dimensions(i); - padding_config.mutable_dimensions(dim)->set_edge_padding_low( - std::max(0LL, conv_window.dimensions(i).padding_low())); - padding_config.mutable_dimensions(dim)->set_edge_padding_high( - std::max(0LL, conv_window.dimensions(i).padding_high())); - padding_config.mutable_dimensions(dim)->set_interior_padding( - conv_window.dimensions(i).base_dilation() - 1); - } - PrimitiveType element_type = input->shape().element_type(); - HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); - input = MakePadHlo(input, padding, padding_config).ValueOrDie(); - } - - if (window_util::HasNegativePadding(conv_window)) { - // If the window has negative padding, insert a kSlice that explicitly - // applies negative padding. 
- // - // For each dimension, initialize the start index to 0 and the limit index - // to the size of that dimension. - std::vector start_indices(input->shape().dimensions_size(), 0); - std::vector limit_indices(input->shape().dimensions().begin(), - input->shape().dimensions().end()); - std::vector strides(input->shape().dimensions_size(), 1); - for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { - int64 dim = conv_dnums.input_spatial_dimensions(i); - // If dimension "dim" has negative padding, increase the start index or - // decrement the limit index by the amount of negative padding. - start_indices[dim] += - std::max(0LL, -conv_window.dimensions(i).padding_low()); - limit_indices[dim] -= - std::max(0LL, -conv_window.dimensions(i).padding_high()); - } - - input = - MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie(); - } - - return input; -} - -// If the padding on the kernel operand of a convolution can't be folded into a -// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that -// explicitly applies the padding; otherwise returns the original kernel -// operand. -HloInstruction* MaybePaddedKernel(const Window& conv_window, - const ConvolutionDimensionNumbers& conv_dnums, - HloInstruction* kernel) { - if (!window_util::HasWindowDilation(conv_window)) { - return kernel; - } - - // Compute the shape and padding config of the pad to be inserted. 
- PaddingConfig padding_config; - for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) { - padding_config.add_dimensions(); - } - for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) { - int64 dim = conv_dnums.kernel_spatial_dimensions(i); - padding_config.mutable_dimensions(dim)->set_interior_padding( - conv_window.dimensions(i).window_dilation() - 1); - } - - HloComputation* computation = kernel->parent(); - PrimitiveType element_type = kernel->shape().element_type(); - HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); - return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); -} -} // namespace - -bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { - if (IsForwardConvolutionCanonical(*conv)) { - return false; - } - - // Insert slices and/or pads between the convolution and its input and/or - // kernel operand. - HloInstruction* new_input = MaybePaddedAndSlicedInput( - conv->window(), conv->convolution_dimension_numbers(), - conv->mutable_operand(0)); - HloInstruction* new_kernel = - MaybePaddedKernel(conv->window(), conv->convolution_dimension_numbers(), - conv->mutable_operand(1)); - - // Remove the padding from convolution's window field. These paddings are - // made explicit with the inserted pads. - Window new_conv_window = conv->window(); - for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) { - WindowDimension* dim = new_conv_window.mutable_dimensions(i); - - // The size of the kernel may have changed so update the Window to match. - dim->set_size(new_kernel->shape().dimensions( - conv->convolution_dimension_numbers().kernel_spatial_dimensions(i))); - dim->set_padding_low(0); - dim->set_padding_high(0); - dim->set_base_dilation(1); - dim->set_window_dilation(1); - } - - // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract - // out the shape of conv_result. 
- VLOG(1) << "Canonicalizing forward conv"; - std::vector operands(conv->operands().begin(), - conv->operands().end()); - operands[0] = new_input; - operands[1] = new_kernel; - auto new_conv = conv->parent()->AddInstruction( - conv->CloneWithNewOperands(conv->shape(), operands)); - new_conv->set_window(new_conv_window); - VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " - << new_conv->ToString(); - TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv)); - return true; -} - -namespace { -void IncreasePaddingLowBy(int64 delta, WindowDimension* window_dim) { - window_dim->set_padding_low(window_dim->padding_low() + delta); -} - -void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) { - window_dim->set_padding_high(window_dim->padding_high() + delta); -} -} // namespace - -bool PadInsertion::CanonicalizeBackwardFilterConvolution( - HloInstruction* backward_conv) { - CHECK_EQ(backward_conv->custom_call_target(), - kCudnnConvBackwardFilterCallTarget); - if (window_util::HasSymmetricPadding(backward_conv->window())) { - return false; - } - - // A backward filter convolution with uneven padding can be canonicalized to - // one with even padding by padding the activations (input) beforehand. For - // example, - // BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2) - // is equivalent to - // ABCD0 = Pad(ABCD, padding_high=1) - // BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1) - // We choose the lesser of padding_low and padding_high as the new padding. - HloInstruction* input = backward_conv->mutable_operand(0); - Window new_backward_conv_window = backward_conv->window(); - // input_padding_config is the config of the kPad to be inserted. 
- PaddingConfig input_padding_config = - MakeNoPaddingConfig(ShapeUtil::Rank(input->shape())); - ConvolutionDimensionNumbers backward_conv_dnums = - backward_conv->convolution_dimension_numbers(); - for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { - int64 padding_low = backward_conv->window().dimensions(i).padding_low(); - int64 padding_high = backward_conv->window().dimensions(i).padding_high(); - if (padding_low < 0 || padding_high < 0) { - // TODO(b/32744257): The following canonicalization wouldn't remove - // negative padding in a backward convolution, and would therefore cause - // cuDNN convolution (which doesn't support negative padding) to fail. - return false; - } - // Compute the new, even padding for the backward conv operation. - int64 new_conv_padding = std::min(padding_low, padding_high); - int64 dim = backward_conv_dnums.input_spatial_dimensions(i); - input_padding_config.mutable_dimensions(dim)->set_edge_padding_low( - padding_low - new_conv_padding); - input_padding_config.mutable_dimensions(dim)->set_edge_padding_high( - padding_high - new_conv_padding); - - // Since we move some padding from the backward convolution to the kPad, we - // need to accordingly reduce the padding amount of the backward convolution - // and its inner forward convolution. - auto* new_dim = new_backward_conv_window.mutable_dimensions(i); - new_dim->set_padding_low(new_conv_padding); - new_dim->set_padding_high(new_conv_padding); - } - - // Create a new backward convolution replacing the old one. - HloComputation* computation = backward_conv->parent(); - HloInstruction* output = backward_conv->mutable_operand(1); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(input->shape().element_type()))); - HloInstruction* padded_input = - MakePadHlo(input, padding, input_padding_config).ValueOrDie(); - - // The shape of the backward_conv CustomCall is a tuple (conv_result, - // scratch_buffer). 
Extract out the shape of conv_result. - HloInstruction* new_backward_conv = - computation->AddInstruction(backward_conv->CloneWithNewOperands( - backward_conv->shape(), {padded_input, output})); - new_backward_conv->set_window(new_backward_conv_window); - - VLOG(1) << "Canonicalizing backward filter conv"; - VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " - << new_backward_conv->ToString(); - - TF_CHECK_OK( - computation->ReplaceInstruction(backward_conv, new_backward_conv)); - return true; -} - -bool PadInsertion::CanonicalizeBackwardInputConvolution( - HloInstruction* backward_conv) { - if (window_util::HasSymmetricPadding(backward_conv->window())) { - return false; - } - - Window new_backward_conv_window = backward_conv->window(); - ConvolutionDimensionNumbers backward_conv_dnums = - backward_conv->convolution_dimension_numbers(); - - // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory). - // Get the shape of conv_result. - Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); - - Shape new_backward_conv_shape = backward_conv_shape; - for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { - int64 padding_low = backward_conv->window().dimensions(i).padding_low(); - int64 padding_high = backward_conv->window().dimensions(i).padding_high(); - if (padding_low < 0 || padding_high < 0) { - // TODO(b/32744257): The following canonicalization wouldn't remove - // negative padding in a backward convolution, and would therefore cause - // cuDNN convolution (which doesn't support negative padding) to fail. - return false; - } - // If the backward convolution has uneven padding on the activations, we - // move some padding on the larger end to "internal" padding, so that the - // backward convolution produces larger activations which get sliced later. 
- // - // For example, suppose we have a non-canonical HLO - // [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1)) - // where the amount of padding low is larger, we can canonicalize it to - // [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1)) - // [A] = Slice([B A]) - if (padding_low > padding_high) { - IncreasePaddingLowBy(padding_high - padding_low, - new_backward_conv_window.mutable_dimensions(i)); - } else if (padding_low < padding_high) { - IncreasePaddingHighBy(padding_low - padding_high, - new_backward_conv_window.mutable_dimensions(i)); - } - // Decreasing the padding by X *increases* the size of our output by X. - int64 dim = backward_conv_dnums.output_spatial_dimensions(i); - new_backward_conv_shape.set_dimensions( - dim, new_backward_conv_shape.dimensions(dim) + - std::abs(padding_low - padding_high)); - } - - // Create a new backward convolution replacing the old one. - HloComputation* computation = backward_conv->parent(); - HloInstruction* output = backward_conv->mutable_operand(0); - HloInstruction* filter = backward_conv->mutable_operand(1); - - HloInstruction* new_backward_conv_call = - computation->AddInstruction(backward_conv->CloneWithNewOperands( - ShapeUtil::MakeTupleShape( - {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}), - {output, filter})); - new_backward_conv_call->set_window(new_backward_conv_window); - - // The CustomCall created above returns a tuple (conv_result, scratch_memory). - // Extract out the two elements. - HloInstruction* new_backward_conv = - computation->AddInstruction(HloInstruction::CreateGetTupleElement( - new_backward_conv_shape, new_backward_conv_call, 0)); - HloInstruction* new_backward_conv_scratch = - computation->AddInstruction(HloInstruction::CreateGetTupleElement( - new_backward_conv_call->shape().tuple_shapes(1), - new_backward_conv_call, 1)); - - // Slice the new backward convolution. - // - // Initialize start_indices and limit_indices as no slicing. 
- std::vector start_indices(new_backward_conv->shape().dimensions_size(), - 0LL); - std::vector limit_indices( - new_backward_conv->shape().dimensions().begin(), - new_backward_conv->shape().dimensions().end()); - std::vector strides(new_backward_conv->shape().dimensions_size(), 1LL); - for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { - int64 padding_low = backward_conv->window().dimensions(i).padding_low(); - int64 padding_high = backward_conv->window().dimensions(i).padding_high(); - int64 dim = backward_conv_dnums.output_spatial_dimensions(i); - if (padding_low > padding_high) { - // If the amount of low padding (of the old backward convolution) is - // larger, we internally pad the low end of the activations and slice - // internal padding out here. - start_indices[dim] += padding_low - padding_high; - } else if (padding_low < padding_high) { - // If the amount of high padding is larger, we slice out the internal - // padding on the high end. - limit_indices[dim] -= padding_high - padding_low; - } - } - - // Replace the old backward convolution with the slice. 
- Shape slice_shape = - ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices, - limit_indices, strides) - .ConsumeValueOrDie(); - CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape)) - << ShapeUtil::HumanString(slice_shape) << " vs " - << ShapeUtil::HumanString(backward_conv_shape); - - HloInstruction* slice = computation->AddInstruction( - HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv, - start_indices, limit_indices, strides)); - HloInstruction* new_tuple = computation->AddInstruction( - HloInstruction::CreateTuple({slice, new_backward_conv_scratch})); - - VLOG(1) << "Canonicalizing backward input conv"; - VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " - << new_tuple->ToString(); - - TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple)); - return true; -} - -StatusOr PadInsertion::RunOnComputation(HloComputation* computation) { - bool changed = false; - std::vector convs; - for (auto* instr : computation->instructions()) { - if (IsCustomCallToDnnConvolution(*instr)) { - convs.push_back(instr); - } - } - for (HloInstruction* instruction : convs) { - const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget || - target == kCudnnConvBiasActivationForwardCallTarget) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - changed |= CanonicalizeBackwardFilterConvolution(instruction); - } else if (target == kCudnnConvBackwardInputCallTarget) { - changed |= CanonicalizeBackwardInputConvolution(instruction); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instruction->ToString(); - } - } - return changed; -} - -StatusOr PadInsertion::Run(HloModule* module) { - bool changed = false; - for (HloComputation* computation : module->MakeNonfusionComputations()) { - TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); - changed |= result; 
- } - return changed; -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h deleted file mode 100644 index 25cdf64c4cf01300869044d3e4d7c34c85626a5a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_ - -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" - -namespace xla { -namespace gpu { - -// An HLO pass that canonicalizes convolution instructions for GPU codegen. It -// inserts Pad instructions before Convolution instructions with uncanonicalized -// padding, so that they can be lowered to cuDNN convolution. -class PadInsertion : public HloModulePass { - public: - absl::string_view name() const override { return "pad insertion"; } - - StatusOr Run(HloModule* module) override; - - private: - StatusOr RunOnComputation(HloComputation* computation); - // Returns if any changes are made to the parent computation. 
- bool CanonicalizeForwardConvolution(HloInstruction* conv); - bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv); - bool CanonicalizeBackwardInputConvolution(HloInstruction* backward_conv); -}; - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 5b6cf2c04d05378a363232e33a6df6432cd6848e..4775baf44aecfe6adaf2bf0d2791595436635b16 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -122,7 +122,7 @@ std::unique_ptr AssignStreams(const HloModule& module) { auto stream_assignment = absl::make_unique(); const HloComputation& computation = *module.entry_computation(); std::unique_ptr reachability = - computation.ComputeReachability(); + HloReachabilityMap::Build(&computation); std::vector seen_gemms; // The execution of different RNG Hlo instructions in the same module updates // a common global variable. To avoid a race condition, we simply assign all diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index c4f43cc9a614283acb376b5f98e4976615b590ad..31a5d7a8c04e9863830e2026fc73cd7ded8c322e 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -21,16 +21,16 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class StreamAssignmentTest : public HloVerifiedTestBase { +class StreamAssignmentTest : public HloTestBase { protected: - std::unique_ptr CreateNewModule() { + std::unique_ptr CreateNewVerifiedModule() { HloModuleConfig config; auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); @@ -55,7 +55,7 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { HloInstruction* dot2 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(dot2)); std::unique_ptr assignment = AssignStreams(*module); @@ -76,7 +76,7 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(add)); std::unique_ptr assignment = AssignStreams(*module); @@ -120,7 +120,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { HloInstruction* d40 = builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build(d40)); std::unique_ptr assignment = AssignStreams(*module); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index a7255335672a3622d122e9fc5ebfab236a5ba895..d798b31643782eb25bba08227e29903ec0e7a597 100644 --- 
a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -37,7 +37,7 @@ cc_library( hdrs = ["gpu_codegen_test.h"], tags = tf_cuda_tests_tags(), deps = [ - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service/gpu:gpu_executable", "//tensorflow/compiler/xla/tests:filecheck", @@ -211,15 +211,13 @@ tf_cc_test( ) tf_cc_test( - name = "cudnn_fused_convolution_rewriter_test", - srcs = ["cudnn_fused_convolution_rewriter_test.cc"], + name = "gpu_atomic_test", + srcs = ["gpu_atomic_test.cc"], tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc deleted file mode 100644 index 5632cac1862e21825888d94ab1eee5e1c9fd6800..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc +++ /dev/null @@ -1,283 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "absl/strings/str_replace.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace gpu { -namespace { - -class CudnnFusedConvolutionRewriterTest : public HloTestBase { - protected: - string GetOptimizedHlo(absl::string_view hlo_string) { - return backend() - .compiler() - ->RunHloPasses(ParseHloString(hlo_string, GetModuleConfigForTest()) - .ConsumeValueOrDie(), - backend().default_stream_executor(), - backend().memory_allocator()) - .ConsumeValueOrDie() - ->ToString(); - } - - void TestMatchWithAllTypes(absl::string_view hlo_string) { - for (absl::string_view type : {"f16", "f32", "f64"}) { - const string hlo_with_new_type = - absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); - const string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type); - EXPECT_EQ(absl::string_view::npos, - optimized_hlo_string.find("__cudnn$convForward")) - << optimized_hlo_string; - EXPECT_NE(absl::string_view::npos, - optimized_hlo_string.find("__cudnn$convBiasActivationForward")) - << optimized_hlo_string; - EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01})) - << optimized_hlo_string; - } - } - - void TestNotMatchWithAllTypes(absl::string_view hlo_string) { - for (absl::string_view type : {"f16", "f32", "f64"}) { - const string hlo_with_new_type = - absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); - string optimized_hlo = GetOptimizedHlo(hlo_with_new_type); - EXPECT_NE(absl::string_view::npos, - optimized_hlo.find("__cudnn$convForward")) - << optimized_hlo; - EXPECT_EQ(absl::string_view::npos, - optimized_hlo.find("__cudnn$convBiasActivationForward")) - << optimized_hlo; - } - } -}; - -TEST_F(CudnnFusedConvolutionRewriterTest, TestConvOnly) { - // max(0, conv(x, w)); - TestMatchWithAllTypes(R"( - HloModule Test - - ENTRY Test { - zero = 
TYPE[] constant(0) - zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={} - - input = TYPE[1,17,9,9] parameter(0) - filter = TYPE[3,3,17,32] parameter(1) - - conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 - ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv) - })"); -} - -TEST_F(CudnnFusedConvolutionRewriterTest, TestBias) { - // max(0, conv(x, w) + bias); - TestMatchWithAllTypes(R"( - HloModule Test - - ENTRY Test { - zero = TYPE[] constant(0) - zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} - - input = TYPE[1,3,3,64] parameter(0) - filter = TYPE[3,3,64,64] parameter(1) - bias = TYPE[64] parameter(2) - - conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 - broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} - add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) - ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) - })"); -} - -TEST_F(CudnnFusedConvolutionRewriterTest, TestSideInputOnly) { - // max(0, conv(x, w) + side_input); - TestMatchWithAllTypes(R"( - HloModule Test - - ENTRY Test { - zero = TYPE[] constant(0) - zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} - - input = TYPE[1,3,3,64] parameter(0) - filter = TYPE[3,3,64,64] parameter(1) - side_input = TYPE[1,3,3,64] parameter(2) - - conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 - add1 = TYPE[1,3,3,64] add(conv, side_input) - ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) - })"); -} - -TEST_F(CudnnFusedConvolutionRewriterTest, TestBiasAndSideInput) { - // max(0, conv(x, w) + side_input + bias); - TestMatchWithAllTypes(R"( - HloModule Test - - ENTRY Test { - zero = TYPE[] constant(0) - zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} - - input = TYPE[1,3,3,64] parameter(0) - filter = TYPE[3,3,64,64] parameter(1) - side_input = TYPE[1,3,3,64] 
parameter(2) - bias = TYPE[64] parameter(3) - - conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 - broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} - add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) - add2 = TYPE[1,3,3,64] add(add1, side_input) - ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) - })"); -} - -TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConv) { - // max(0, 0.999994934 * conv(x, w)); - TestMatchWithAllTypes(R"( - HloModule Test - - ENTRY Test { - zero = TYPE[] constant(0) - zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={} - alpha_conv_scalar = TYPE[] constant(0.999994934) - - input = TYPE[1,17,9,9] parameter(0) - filter = TYPE[3,3,17,32] parameter(1) - - conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 - alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={} - scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv) - ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv) - })"); -} - -TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndSideInput) { - // max(0, conv(x, w) + 0.899994934 * side_input); - TestMatchWithAllTypes(R"( - HloModule Test - - ENTRY Test { - zero = TYPE[] constant(0) - zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} - alpha_side_input_scalar = TYPE[] constant(0.899994934) - alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} - - input = TYPE[1,3,3,64] parameter(0) - filter = TYPE[3,3,64,64] parameter(1) - side_input = TYPE[1,3,3,64] parameter(2) - - conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 - scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) - add1 = TYPE[1,3,3,64] add(conv, scaled_side_input) - ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) - })"); -} - 
-TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndScaledSideInput) { - // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input); - TestMatchWithAllTypes(R"( - HloModule Test - - ENTRY Test { - zero = TYPE[] constant(0) - zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} - alpha_conv_scalar = TYPE[] constant(0.999994934) - alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} - alpha_side_input_scalar = TYPE[] constant(0.899994934) - alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} - - input = TYPE[1,3,3,64] parameter(0) - filter = TYPE[3,3,64,64] parameter(1) - side_input = TYPE[1,3,3,64] parameter(2) - - conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 - scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv) - scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) - add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input) - ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) - })"); -} - -TEST_F(CudnnFusedConvolutionRewriterTest, - TestScaledConvAndScaledSideInputWithBias) { - // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias); - TestMatchWithAllTypes(R"( - HloModule Test - - ENTRY Test { - zero = TYPE[] constant(0) - zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} - alpha_conv_scalar = TYPE[] constant(0.999994934) - alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} - alpha_side_input_scalar = TYPE[] constant(0.899994934) - alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} - - input = TYPE[1,3,3,64] parameter(0) - filter = TYPE[3,3,64,64] parameter(1) - side_input = TYPE[1,3,3,64] parameter(2) - bias = TYPE[64] parameter(3) - - conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 - scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv) - 
scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) - broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} - add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias) - add2 = TYPE[1,3,3,64] add(add1, scaled_side_input) - ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) - })"); -} - -TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchMaxZeroOnly) { - // max(0.1, conv(x, w)) shouldn't match. - TestNotMatchWithAllTypes(R"( - HloModule Test - - ENTRY Test { - point_one = TYPE[] constant(0.1) - point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={} - - input = TYPE[1,17,9,9] parameter(0) - filter = TYPE[3,3,17,32] parameter(1) - - conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 - ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv) - })"); -} - -TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchBroadcastedBiasOnly) { - // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match. 
- TestNotMatchWithAllTypes(R"( - HloModule Test - - ENTRY Test { - zero = TYPE[] constant(0) - zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} - - input = TYPE[1,3,3,64] parameter(0) - filter = TYPE[3,3,64,64] parameter(1) - side_input1 = TYPE[1,3,3,64] parameter(2) - side_input2 = TYPE[1,3,3,64] parameter(3) - - conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 - add1 = TYPE[1,3,3,64] add(conv, side_input2) - add2 = TYPE[1,3,3,64] add(add1, side_input1) - ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) - })"); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b18c4c63714b4b3c06d7fa85f4a7a75b8e9ae12 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_atomic_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class GpuAtomicTest : public GpuCodegenTest {}; + +TEST_F(GpuAtomicTest, TestStore) { + const char* hlo_string = R"( + HloModule TensorFlowScatterV1 + + update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) + } + + ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + } +)"; + + CompileAndVerifyIr(hlo_string, R"( +CHECK: store atomic{{.*}}unordered, align 4 +)"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc index 79e77d4c4d649020cf52ac25c220c3f90e8469b9..9e3ff8750b88d08bcbc1aae3faead5aecfa19848 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/core/platform/logging.h" @@ -23,9 +23,10 @@ limitations under the License. 
namespace xla { namespace gpu { -std::unique_ptr GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) { +std::unique_ptr GpuCodegenTest::CreateNewUnverifiedModuleWithFTZ( + bool ftz) { HloModuleConfig config; - auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + auto debug_options = GetDebugOptionsFromFlags(); debug_options.set_xla_gpu_ftz(ftz); debug_options.set_xla_gpu_max_kernel_unroll_factor(1); // TODO(b/38354253): Change tests to use Parameters instead of Constants. diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h index e4a3573babb7ed746504c1466f85b582aa4d044f..d917320e36363c4fa7e4c0055e8f3345cbc610a2 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h @@ -26,9 +26,9 @@ namespace gpu { // Tests that verify IR or PTX emitted by the GPU backend is as expected. class GpuCodegenTest : public LlvmIrGenTestBase { protected: - // Like HloTestBase::CreateNewModule(), with a flag for configuring the ftz - // option. - std::unique_ptr CreateNewModuleWithFTZ(bool ftz); + // Like HloTestBase::CreateNewVerifiedModule(), with a flag for configuring + // the ftz option. + std::unique_ptr CreateNewUnverifiedModuleWithFTZ(bool ftz); // Compiles the given HLO module to PTX and verifies the PTX matches the given // FileCheck pattern. (See http://llvm.org/docs/CommandGuide/FileCheck.html). 
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index 780539c164277f14c2bd964024f7c3ca179f4ada..a1ed8499040359fe7265a7317b0577a990a2234c 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -46,7 +46,7 @@ TEST_F(GpuCopyTest, UseMemcpy) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); // There should not be any kernel prefixed "copy". diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc index 177b94934c7f519172508b5cc6e088f908401193..5e524faab18947f5793dc2ae34e9329a446d4235 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc @@ -39,7 +39,7 @@ class GpuFtzTest : public GpuCodegenTest { /* parameter_number=*/1, param_shape, "y")); builder.AddInstruction(HloInstruction::CreateBinary(param_shape, op, x, y)); - auto hlo_module = CreateNewModuleWithFTZ(ftz_); + auto hlo_module = CreateNewUnverifiedModuleWithFTZ(ftz_); hlo_module->AddEntryComputation(builder.Build()); return hlo_module; } @@ -54,7 +54,7 @@ class GpuFtzTest : public GpuCodegenTest { /* parameter_number=*/0, param_shape, "x")); builder.AddInstruction(HloInstruction::CreateUnary(param_shape, op, x)); - auto hlo_module = CreateNewModuleWithFTZ(ftz_); + auto hlo_module = CreateNewUnverifiedModuleWithFTZ(ftz_); hlo_module->AddEntryComputation(builder.Build()); return hlo_module; } @@ -75,16 +75,16 @@ class GpuFtzDisabledTest : public GpuFtzTest { // Check that we emit mul.ftz.f32 when in ftz mode, and plain mul.f32 otherwise. 
TEST_F(GpuFtzEnabledTest, MultiplyFtz) { CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( - CHECK-NOT: mul.f32 - CHECK: mul.ftz.f32 - CHECK-NOT: mul.f32 + CHECK-NOT: mul.rn.f32 + CHECK: mul.rn.ftz.f32 + CHECK-NOT: mul.rn.f32 )"); } TEST_F(GpuFtzDisabledTest, MultiplyFtz) { CompileAndVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply), R"( - CHECK-NOT: mul.ftz.f32 - CHECK: mul.f32 - CHECK-NOT: mul.ftz.f32 + CHECK-NOT: mul.rn.ftz.f32 + CHECK: mul.rn.f32 + CHECK-NOT: mul.rn.ftz.f32 )"); } diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index a06576df7b874745236a8d9075355a01ec42e777..6814be779e0b02c38e3bc7008f036b845d88cb6f 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -51,7 +51,7 @@ TEST_F(GpuIndexTest, CompatibleUseLinearIndex) { builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {5, 7, 2}), HloOpcode::kGe, param_x, param_y)); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); // Check the optimized IR as the unoptimized IR contains dead udiv and urem. 
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index 15d1e269cc22b88f5269175084f20600f165011c..a302b582ede3723acd118d2e4a4bb3efdf7a4d0b 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -193,6 +193,33 @@ TEST_F(GpuKernelTilingTest, /*match_optimized_ir=*/true); } +TEST_F(GpuKernelTilingTest, FusionTransposeWithReverseNotTiled) { + const char *const kHloString = R"( + HloModule FusionTransposeWithReverseNotTiled + fused_computation.1 { + arg0 = f32[128,64]{1,0} parameter(0) + copy0 = f32[128,64]{0,1} copy(arg0) + ROOT reverse0 = f32[128,64]{0,1} reverse(copy0), dimensions={0} + } + + ENTRY reverse_break_assumption { + param0 = f32[128,64]{1,0} parameter(0) + ROOT fusion0 = f32[128,64]{0,1} fusion(param0), kind=kLoop, + calls=fused_computation.1 + })"; + + // Check that a call to llvm.nvvm.barrier0 is not generated. 
+ auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK-NOT: tail call void @llvm.nvvm.barrier0() +; CHECK: } +)", + /*match_optimized_ir=*/true); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc index 6a9ecd9dae7c9ddde0b56d8615e4a39fb3df0af9..3019215c015a4e0aa094a62424d650ced0de2a0e 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -48,7 +48,7 @@ TEST_F(GpuLdgTest, LdgForParamRead) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyPtx(std::move(hlo_module), R"( @@ -73,7 +73,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) { builder.AddInstruction(HloInstruction::CreateTuple({add, square})); std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyPtx(std::move(hlo_module), R"( @@ -95,7 +95,7 @@ TEST_F(GpuLdgTest, LdgForNonParamRead) { // reduce in the foreseeable future. But if that turns out to be wrong, I give // you, future reader, permission to delete this test. 
TEST_F(GpuLdgTest, NoLdgWhenSharingBuffer) { - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloComputation* reduce_computation; diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index 15198865bda98f9718342d5a444a20305f923b48..ca0a78034d7dc83d17ad72202914d95f37ac122b 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -47,7 +47,7 @@ TEST_F(GpuNoAliasTest, Concat) { std::unique_ptr computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyIr(std::move(hlo_module), diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc index 0f2d5568cafc9db0f5f067437fdd5e2e775ad2c8..4636f1d9d20b8c213ffadec427b3820a89c68a7f 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -85,7 +85,7 @@ TEST_F(GpuUnrollingTest, UnrollFourTimes) { TEST_F(GpuUnrollingTest, UnrollDefaultTimes) { // The default unrolling factor is 4. 
HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); auto hlo_module = ParseHloString(kAddModule, config).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index 141f3219387940a08ef22cbcc0be0971a14c2cd6..6b2d76764a077dc6cfa3f9ddc6e525ab330323be 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -45,7 +45,7 @@ void ThunkSchedule::AddDependenciesOnTransitiveOperands( ThunkSchedule::ThunkSchedule( std::unique_ptr thunks, std::unique_ptr stream_assignment, - const std::vector& hlo_total_order) + const std::vector& hlo_total_order) : thunks_(std::move(thunks)), stream_assignment_(std::move(stream_assignment)) { std::unordered_map hlo_to_thunk; @@ -53,7 +53,7 @@ ThunkSchedule::ThunkSchedule( InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get()); } - for (const HloInstruction* hlo : hlo_total_order) { + for (HloInstruction* hlo : hlo_total_order) { if (hlo_to_thunk.count(hlo)) { thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo)); } diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h index d3352994f845a535233612a17e19107511ce0622..43b628a1baf0e79a3197f3cfad3547991642eaed 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h @@ -46,7 +46,7 @@ class ThunkSchedule { public: ThunkSchedule(std::unique_ptr thunks, std::unique_ptr stream_assignment, - const std::vector& hlo_total_order); + const std::vector& hlo_total_order); // Returns the total order of executing all the thunks. 
const std::vector& TotalOrder() const { return thunk_total_order_; } diff --git a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..c552c2925497f1c4808d74a615d35cdbeeba1858 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc @@ -0,0 +1,106 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h" + +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace gpu { + +namespace { +// The parameter space on the GPU device is limited. We pick an arbitrary low +// constant here to try to prevent exceeding this parameter space. For a proper +// fix, we would have to take into account which parameters share a buffer, and +// how big these buffers are. 
+constexpr int32 kMaxParameters = 128; + +StatusOr SplitConcatenate(HloInstruction* concat, HloComputation* comp) { + auto operands = concat->operands(); + std::vector operands_to_split(operands.begin(), + operands.end()); + while (operands_to_split.size() > 1) { + std::vector new_operands; + absl::Span operands_span(operands_to_split); + for (int64 offset = 0; offset < operands_to_split.size(); + offset += kMaxParameters) { + // Check if there is a remainder of operands that does not completely fill + // one "batch" of exactly 'kMaxParameters' operands. If there are only + // less than 'kMaxParameters' operands left, then we still put them into a + // concat together. Otherwise, we spare them for another round so that + // they can be put together into a concat with some of the newly created + // concats. + if (offset > 0 && offset + kMaxParameters > operands_to_split.size()) { + new_operands.insert(new_operands.end(), + operands_to_split.begin() + offset, + operands_to_split.end()); + } else { + Shape new_shape = concat->shape(); + int64 concat_dimension_size = 0; + for (int64 i = 0; + i < kMaxParameters && offset + i < operands_to_split.size(); ++i) { + concat_dimension_size += + operands_to_split[i + offset]->shape().dimensions( + concat->concatenate_dimension()); + } + new_shape.set_dimensions(concat->concatenate_dimension(), + concat_dimension_size); + auto new_concat = comp->AddInstruction(concat->CloneWithNewOperands( + new_shape, operands_span.subspan(offset, kMaxParameters))); + new_operands.push_back(new_concat); + } + } + operands_to_split = new_operands; + } + TF_RETURN_IF_ERROR(comp->ReplaceInstruction(concat, operands_to_split[0])); + return true; +} + +std::vector GetRelevantVariadicOps(HloComputation* comp) { + std::vector ops; + for (HloInstruction* instr : comp->instructions()) { + if (instr->opcode() == HloOpcode::kConcatenate && + instr->operand_count() > kMaxParameters) { + ops.push_back(instr); + } + } + return ops; +} + +} // namespace + 
+StatusOr VariadicOpSplitter::Run(HloModule* module) { + bool changed = false; + for (HloComputation* comp : module->MakeNonfusionComputations()) { + for (HloInstruction* op : GetRelevantVariadicOps(comp)) { + // TODO(b/112613927): Handle also other ops than concatenate. + TF_ASSIGN_OR_RETURN(bool result, SplitConcatenate(op, comp)); + changed |= result; + } + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h new file mode 100644 index 0000000000000000000000000000000000000000..7673ad0d48a04508987025dac84b60e396e3d7dc --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h @@ -0,0 +1,38 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { +namespace gpu { + +// Splits variadic ops with many operands into pieces such that we don't exceed +// the parameter space on the GPU. Currently only concatenate ops are split up. 
+class VariadicOpSplitter : public HloModulePass { + public: + absl::string_view name() const override { return "variadic-op-splitter"; } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3d00ac4dc7b57664a317157c093d7ffaa01b4fd6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace gpu { +namespace { +using match::Concatenate; + +class VariadicOpSplitterTest : public HloTestBase {}; + +TEST_F(VariadicOpSplitterTest, DontSplit) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + p0 = f16[30,41] parameter(0) + p1 = f16[30,41] parameter(1) + ROOT result = f16[60, 41] concatenate(p0, p1), dimensions={0} + })") + .ValueOrDie(); + EXPECT_FALSE(VariadicOpSplitter().Run(module.get()).ValueOrDie()); +} + +TEST_F(VariadicOpSplitterTest, SplitInto2) { + auto builder = HloComputation::Builder(TestName()); + auto operand = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({42}))); + std::vector concat_operands(255, operand); + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(S32, {255}), concat_operands, 0)); + auto module = CreateNewVerifiedModule(); + auto entry_computation = module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(VariadicOpSplitter().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(Match(entry_computation->root_instruction(), + Concatenate().WithNumOperands(128).WithOperand( + 0, Concatenate().WithNumOperands(128)))); +} + 
+TEST_F(VariadicOpSplitterTest, SplitInto3) { + auto builder = HloComputation::Builder(TestName()); + auto operand = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({42}))); + std::vector concat_operands(256, operand); + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(S32, {256}), concat_operands, 0)); + auto module = CreateNewVerifiedModule(); + auto entry_computation = module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(VariadicOpSplitter().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(Match(entry_computation->root_instruction(), + Concatenate(Concatenate().WithNumOperands(128), + Concatenate().WithNumOperands(128)))); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 9a61f8ac5a62e38e687a93890eb33481a01d51c8..2dce7749bbd8da2673ae607eee3d731d9917e8fe 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -29,7 +29,7 @@ namespace { class WhileTransformerTest : public HloTestBase { protected: WhileTransformerTest() - : module_(CreateNewModule()), + : module_(CreateNewVerifiedModule()), induction_variable_shape_(ShapeUtil::MakeShape(S32, {})), data_shape_(ShapeUtil::MakeShape(F32, {8})), condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {} @@ -69,8 +69,10 @@ class WhileTransformerTest : public HloTestBase { auto data = builder.AddInstruction(HloInstruction::CreateGetTupleElement( data_shape_, loop_state, data_tuple_index)); // Use 'induction_variable' in computation with no path to output tuple. 
+ auto cast = builder.AddInstruction(HloInstruction::CreateBitcastConvert( + ShapeUtil::MakeShape(F32, {}), induction_variable)); auto update = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, induction_variable, {})); + HloInstruction::CreateBroadcast(data_shape_, cast, {})); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); // Create output Tuple. diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index e30e7667f3015bc7bfe67c65147a5016332780f7..dc40b9446ad1bffcb757543e52fc9ab20de6d52e 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -30,16 +30,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {}; +class MinimumMemoryForSequenceTest : public HloTestBase {}; TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); @@ -86,7 +86,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - HloSchedule schedule(module); + HloSchedule schedule(module.get()); schedule.set_sequence(cond_computation, {cond_param, cond_iter, cond_data, cond_lt}); schedule.set_sequence(body_computation, {body_param}); @@ -258,7 +258,7 @@ 
class HeapSimulatorTracker { // Constructor for testing a single entry computation. HeapSimulatorTracker( const string& name, std::unique_ptr computation, - const std::vector& instruction_sequence) { + const std::vector& instruction_sequence) { HloModuleConfig config; module_ = absl::make_unique(name, config); module_->AddEntryComputation(std::move(computation)); @@ -286,7 +286,7 @@ class HeapSimulatorTracker { // Similar to the single entry computation constructor above, but runs the // simulation over the entire module. void RunWholeModule( - const std::vector& full_module_sequence) { + const std::vector& full_module_sequence) { points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); @@ -294,7 +294,7 @@ class HeapSimulatorTracker { HloSchedule schedule(module_.get()); absl::flat_hash_map reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { - const HloInstruction* instruction = full_module_sequence[i]; + HloInstruction* instruction = full_module_sequence[i]; schedule.GetOrCreateSequence(instruction->parent()) .push_back(instruction); reverse_position[instruction] = full_module_sequence.size() - i; @@ -351,7 +351,7 @@ class HeapSimulatorTracker { HeapSimulator::Result result_; }; -class HeapSimulatorTest : public HloVerifiedTestBase { +class HeapSimulatorTest : public HloTestBase { protected: HeapSimulatorTest() {} ~HeapSimulatorTest() override {} diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 1ea26ddd5b9ee01eaeb812b32539c7820d3d5dda..414c63271245315f037d04924c9291a9cd5b7a77 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. 
-// Next ID: 56 +// Next ID: 58 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -51,7 +51,7 @@ message HloInstructionProto { string name = 1; string opcode = 2; - xla.Shape shape = 3; + xla.ShapeProto shape = 3; xla.OpMetadata metadata = 7; @@ -132,7 +132,7 @@ message HloInstructionProto { string custom_call_opaque = 53; // Shape of outfeed request. - xla.Shape outfeed_shape = 29; + xla.ShapeProto outfeed_shape = 29; // Describes the dimension numbers used for a dot operation xla.DotDimensionNumbers dot_dimension_numbers = 30; @@ -184,6 +184,13 @@ message HloInstructionProto { // Sharding for kDomain instructions. xla.OpSharding domain_entry_sharding = 54; xla.OpSharding domain_exit_sharding = 55; + + // For custom call this indicates that the layouts are constrained. If + // constrain_layout is true then the 'shape' field must contain a layout, and + // 'operand_shapes_with_layout' must contain a shape with layout for each + // operand. + bool constrain_layout = 56; + repeated xla.ShapeProto operand_shapes_with_layout = 57; } // Serialization of HloComputation. @@ -198,7 +205,8 @@ message HloComputationProto { repeated HloInstructionProto instructions = 2; // The program shape (with layout) of this computation. - xla.ProgramShape program_shape = 4; + + xla.ProgramShapeProto program_shape = 4; // The id of this computation. int64 id = 5; @@ -218,6 +226,67 @@ message HloScheduleProto { map sequences = 1; } +message HloInputOutputAliasProto { + // The following proto describes a pair of aliased an input + // (described by parameter number and a ShapeIndex of the parameter) + // and an output (described by a ShapeIndex of the root + // instruction). For example: + // + // entry = { + // output_shape_index={1}, + // parameter_number=0, + // parameter_shape_index={1, 2}, + // } + // + // This entry indicates that the first paremter's {1, 2} element is + // aliased with the {1} element of the root instruction. 
+ message AliasEntryProto { + // ShapeIndex of the root hlo. + repeated int64 output_shape_index = 1; + // Number of the parameter in entry computation. + int64 parameter_number = 2; + // ShapeIndex of the parameter instruction. + repeated int64 parameter_shape_index = 3; + } + + repeated AliasEntryProto entries = 1; +} + +message DynamicParameterBindingProto { + // A list of bindings which indicates that the `target_dim_num` in + // the subshape `target_param_index` of parameter `target_param_num` + // is a dynamic dimension and its real dynamic size is represented + // by `dynamic_param_index` in parameter `dynamic_param_num`. + // + // As an example, imagine we have a program: + // + // ENTRY main { + // a = f32[] parameter(0) + // b = f32[10] parameter(1) + // ROOT root = (f32[], f32[10]) tuple(%a, %b) + // } + // + // Let's say 'b' (param index 1) is a dynamic shape whose input has + // an upperbound of 10 and real size is determined at runtime.'a' + // represents the real size of b's first dimension. + // + // In this case, the fields are set in the following way: + // dynamic_param_num = 1 + // dynamic_param_index = {} + // target_param_num = 0 + // target_param_index = {} + // target_param_dim = 0 + message Binding { + int64 dynamic_param_num = 1; + repeated int64 dynamic_param_index = 2; + int64 target_param_num = 3; + repeated int64 target_param_index = 4; + int64 target_param_dim_num = 5; + } + + repeated Binding entries = 1; +} + // Serialization of HloModule. message HloModuleProto { string name = 1; @@ -228,14 +297,19 @@ message HloModuleProto { // callees appear before their callers. repeated HloComputationProto computations = 3; - // The program shape (with layout) of the entry computation. - xla.ProgramShape program_shape = 4; + // The host program shape (with layout) of the entry computation. + xla.ProgramShapeProto host_program_shape = 4; // The id of this module. int64 id = 5; // The schedule for this module. 
HloScheduleProto schedule = 7; + + // Describes alias information between inputs and outputs. + HloInputOutputAliasProto input_output_alias = 8; + + DynamicParameterBindingProto dynamic_parameter_binding = 9; } // Serialization of LogicalBuffer. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index c3da12e273c77793647981f8653649155aac9483..cf8e6594cbe5ffd28ca75dd5006e8817f1e8581c 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -59,8 +59,9 @@ class BufferValueMap { // construction process. using BufferNumber = int64; - explicit BufferValueMap(const HloDataflowAnalysis& dataflow) - : dataflow_(dataflow) { + explicit BufferValueMap(HloModule* module, + const HloDataflowAnalysis& dataflow) + : module_(module), dataflow_(dataflow) { buffers_.reserve(dataflow_.values().size()); value_to_buffer_number_.reserve(dataflow_.values().size()); for (const HloValue* value : dataflow_.values()) { @@ -171,6 +172,42 @@ class BufferValueMap { return value_to_buffer_number_.at(&value); } + void ComputeInputOutputAliasedBuffers( + const HloValue& value, std::vector* aliased_buffers) { + // Get parameter value from an aliased_input object. + const auto get_parameter_value = + [this](const std::pair& aliased_input) + -> const HloValue& { + int64 param_number = aliased_input.first; + const ShapeIndex& param_index = aliased_input.second; + return dataflow_.GetUniqueValueAt( + module_->entry_computation()->parameter_instruction(param_number), + param_index); + }; + + // If the value shows up in a root instruction, alias it with parameter + // intruction. 
+ for (const HloPosition& pos : value.positions()) { + if (pos.instruction == module_->entry_computation()->root_instruction()) { + ShapeIndex output_index = pos.index; + + auto aliased_input = + module_->input_output_alias_config().GetAliasedParameter( + output_index); + if (aliased_input) { + aliased_buffers->push_back( + GetBufferForValue(get_parameter_value(*aliased_input))); + } + } + } + + // If the value is parameter instruction itself, alias it with itself. + if (value.instruction()->opcode() == HloOpcode::kParameter && + value.instruction()->parent() == module_->entry_computation()) { + aliased_buffers->push_back(GetBufferForValue(value)); + } + } + void ComputeWhileAliasedBuffers(const HloValue& value, std::vector* aliased_buffers) { VLOG(3) << "Compute kWhile aliases"; @@ -278,6 +315,7 @@ class BufferValueMap { VLOG(2) << "Use of value " << value.ToShortString() << ": " << use; } std::vector aliased_buffers; + ComputeInputOutputAliasedBuffers(value, &aliased_buffers); ComputeWhileAliasedBuffers(value, &aliased_buffers); ComputeConditionalAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. @@ -288,6 +326,8 @@ class BufferValueMap { return aliased_buffers; } + HloModule* module_; + // Dataflow analysis used to construct the buffer map. 
const HloDataflowAnalysis& dataflow_; @@ -461,7 +501,7 @@ StatusOr> HloAliasAnalysis::Run( /*bitcast_defines_value=*/false, fusion_can_share_buffer)); - BufferValueMap buffer_map(alias_analysis->dataflow_analysis()); + BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis()); buffer_map.MergeAliasedBuffers(); // Create a vector of HloBuffers, one for each set of values in the diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 0cd0ab36fcf832af9a71ab5837c94f9b39bc4bf3..7e6150e94153cd15463725e862ce1b8593f2c991 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" @@ -39,17 +39,17 @@ namespace { using ::testing::UnorderedElementsAre; -class HloAliasAnalysisTest : public HloVerifiedTestBase { +class HloAliasAnalysisTest : public HloTestBase { protected: - HloAliasAnalysisTest() : HloVerifiedTestBase() { - module_ = CreateNewModule(); + HloAliasAnalysisTest() : HloTestBase() { + module_ = CreateNewVerifiedModule(); } // Run alias analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. 
HloAliasAnalysis& RunAnalysis() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); - analysis_ = HloAliasAnalysis::Run(module_, + analysis_ = HloAliasAnalysis::Run(module_.get(), /*fusion_can_share_buffer=*/nullptr) .ConsumeValueOrDie(); return *analysis_; @@ -93,7 +93,7 @@ class HloAliasAnalysisTest : public HloVerifiedTestBase { // never occurs, but HLO graphs with interference can be explicitly // constructed. bool AnyValuesInSameBufferInterfere() { - DependencyHloOrdering ordering(module_); + DependencyHloOrdering ordering(module_.get()); for (const HloBuffer& buffer : analysis_->buffers()) { for (const HloValue* value_a : buffer.values()) { for (const HloValue* value_b : buffer.values()) { @@ -110,7 +110,7 @@ class HloAliasAnalysisTest : public HloVerifiedTestBase { return false; } - HloModule* module_; + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); @@ -217,6 +217,181 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) { EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } +TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) { + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + + auto negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0)); + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1)); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1})); + module_->AddEntryComputation(builder.Build()); + 
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + // Cannot alias an output twice. + ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); +} + +TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) { + // parameter 0 aliased with output 1 and parameter 1 aliased with output 0. + // + // (p0 , p1) + // \ / + // \ / + // alias X + // / \ + // / \ + // (p0 , p1) + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + module_->AddEntryComputation(builder.Build()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0})); + + // Cannot alias an output twice. 
+ ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + // Every Ops in this graph are aliased with each other. + EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + EXPECT_EQ(analysis.GetUniqueBufferAt(gte0), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); + + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{0})); + EXPECT_EQ(analysis.GetUniqueBufferAt(gte1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{1})); +} + +TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) { + // Test a simple single while instruction can be aliased with input and output + // of the computation. + // + // body((F32[], F32[]) %tuple_param): + // %add = Add(%tuple_param{0}, %tuple_param{1}) + // return Tuple(%tuple_param{0}, %add) + // + // condition((F32[], F32[]) %tuple_param): + // return Constant(false) + // + // entry: + // %param1 = param1 + // %while = While(%param1, body, condition) + // %while_1 = GTE(%while, 0) + // %while_2 = GTE(%while, 1) + // %negate_1 = Negate(%while_1) + // %negate_2 = Negate(%while_2) + // return Tuple(negate_1, negate_2) + // + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Element 0 passes transparently through the body. 
+ auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); + auto body_tuple = body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_0, add})); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); + + // Condition computation trivially returns a constant "false". + auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(tuple_shape, condition, body, param)); + auto while_element_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 0)); + auto while_element_2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 1)); + auto negate_1 = builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, while_element_1)); + auto negate_2 = builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, while_element_2)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({negate_1, 
negate_2})); + module_->AddEntryComputation(builder.Build()); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0})); + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1})); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_THAT( + GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})), + UnorderedElementsAre(GetValueDefinedAt(param, {1}), + GetValueDefinedAt(xla_while, /*index=*/{1}), + GetValueDefinedAt(body_param, {1}), + GetValueDefinedAt(cond_param, {1}), + GetValueDefinedAt(add), + GetValueDefinedAt(negate_2))); + + EXPECT_THAT( + analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(), + UnorderedElementsAre( + HloPosition{param, {1}}, HloPosition{xla_while, {1}}, + HloPosition{while_element_2, {}}, HloPosition{body_param, {1}}, + HloPosition{body_element_1, {}}, HloPosition{add, {}}, + HloPosition{body_tuple, {1}}, HloPosition{tuple, {1}}, + HloPosition{cond_param, {1}}, HloPosition{negate_2, {}})); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); +} + TEST_F(HloAliasAnalysisTest, SingleCall) { // Test a single call of a subcomputation. The subcomputation adds its two // array-shaped parameters. 
@@ -463,7 +638,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { module_->AddEntryComputation(builder.Build()); FlattenCallGraph flattener; - TF_ASSERT_OK(flattener.Run(module_).status()); + TF_ASSERT_OK(flattener.Run(module_.get()).status()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -837,7 +1012,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) { const HloAliasAnalysis& analysis = RunAnalysis(); - DependencyHloOrdering ordering(module_); + DependencyHloOrdering ordering(module_.get()); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } @@ -879,13 +1054,13 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { { // Dependency ordering should interfere because the negate and while are // unordered. - DependencyHloOrdering ordering(module_); + DependencyHloOrdering ordering(module_.get()); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } // For a sequential order, if there is interference iff the negate is after // the while. - HloSchedule schedule(module_); + HloSchedule schedule(module_.get()); schedule.set_sequence(body, {body_param, body_root}); schedule.set_sequence(condition, {cond_param, cond_root}); { diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index c2041c466708fd8c88d34f14fbc0064905f594a9..ff122b529bdcdcc69d2245136e19101902dbf957 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -215,7 +216,7 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( if (removed.count(item) != 0 || item->user_count() != 0 || item == root_instruction() || !IsRemovable(item) || - item->HasSideEffect()) { + (item->HasSideEffect() && item != instruction)) { continue; } for (int i = 0; i < item->operand_count(); ++i) { @@ -321,7 +322,7 @@ void HloComputation::ComputeInstructionPostOrder( // Add the operands to the stack in reverse order so the first operand is // processed first. This will produce a more natural ordering and a nicer - // result for thigns like HLO stringification. + // result for things like HLO stringification. 
const auto& operands = current->operands(); for (int64 i = operands.size() - 1; i >= 0; --i) { dfs_stack.emplace_back(operands[i]); @@ -498,7 +499,7 @@ HloComputationProto HloComputation::ToProto() const { proto.add_instructions()->Swap(&instruction_proto); } proto.set_root_id(root_instruction()->unique_id()); - *proto.mutable_program_shape() = ComputeProgramShape(); + *proto.mutable_program_shape() = ComputeProgramShape().ToProto(); return proto; } @@ -710,6 +711,8 @@ bool HloComputation::operator==(const HloComputation& other) const { return eq(root_instruction(), other.root_instruction()); } +uint64 HloComputation::Hash() const { return root_instruction()->Hash(); } + Status HloComputation::ReplaceWithNewInstruction( HloInstruction* old_instruction, std::unique_ptr new_instruction) { @@ -739,72 +742,6 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, return RemoveInstructionAndUnusedOperands(old_instruction); } -std::unique_ptr HloComputation::ComputeReachability() - const { - const auto& all = MakeInstructionPostOrder(); - auto result = absl::make_unique(all); - auto channel_dependency_map = ComputeChannelDependencies(); - - std::vector inputs; - for (const HloInstruction* hlo : all) { - inputs.assign(hlo->operands().begin(), hlo->operands().end()); - inputs.insert(inputs.end(), hlo->control_predecessors().begin(), - hlo->control_predecessors().end()); - - switch (hlo->opcode()) { - case HloOpcode::kRecvDone: { - auto it = channel_dependency_map.find(hlo->channel_id()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); - } - break; - } - case HloOpcode::kCrossReplicaSum: { - auto all_reduce_id = hlo->all_reduce_id(); - if (all_reduce_id) { - auto it = channel_dependency_map.find(all_reduce_id.value()); - if (it != channel_dependency_map.end()) { - absl::c_copy(it->second, std::back_inserter(inputs)); - } - } - break; - } - default: - break; - } - - 
result->FastSetReachabilityToUnion(inputs, hlo); - } - return result; -} - -void HloComputation::UpdateReachabilityThroughInstruction( - const HloInstruction* instruction, HloReachabilityMap* reachability_map) { - std::queue worklist; - worklist.push(instruction); - - std::vector inputs; - - while (!worklist.empty()) { - const HloInstruction* item = worklist.front(); - worklist.pop(); - - inputs.assign(item->operands().begin(), item->operands().end()); - inputs.insert(inputs.end(), item->control_predecessors().begin(), - item->control_predecessors().end()); - - if (reachability_map->SetReachabilityToUnion(inputs, item)) { - // Add immediate successors to worklist. - for (const HloInstruction* user : item->users()) { - worklist.push(user); - } - for (const HloInstruction* succ : item->control_successors()) { - worklist.push(succ); - } - } - } -} - std::vector HloComputation::CollectUnreachableRoots() const { std::vector unreachable_roots; for (auto* instruction : instructions()) { @@ -860,7 +797,7 @@ Status HloComputation::AcceptWithOperandOrder( template Status HloComputation::AcceptOrdered( DfsHloVisitorBase* visitor, - const std::vector& order) const { + const std::vector& order) const { VLOG(3) << "Accepting visitor with order."; for (HloInstruction* root : CollectUnreachableRoots()) { TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) @@ -890,9 +827,9 @@ Status HloComputation::AcceptOrdered( // Explicit instantiations. 
template Status HloComputation::AcceptOrdered( - DfsHloVisitor*, const std::vector&) const; + DfsHloVisitor*, const std::vector&) const; template Status HloComputation::AcceptOrdered( - ConstDfsHloVisitor*, const std::vector&) const; + ConstDfsHloVisitor*, const std::vector&) const; Status HloComputation::Accept( const std::function& visitor_func) { @@ -911,14 +848,46 @@ std::unique_ptr HloComputation::Clone( return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - /*extras=*/{}, context, suffix); + context, suffix); +} + +std::unique_ptr HloComputation::CloneWithReplacementPairs( + std::pair> r1, + HloCloneContext* context, const string& suffix) { + std::unordered_map> + replacements; + replacements.emplace(std::move(r1)); + return CloneWithReplacements(std::move(replacements), context, suffix); +} + +std::unique_ptr HloComputation::CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + HloCloneContext* context, const string& suffix) { + std::unordered_map> + replacements; + replacements.emplace(std::move(r1)); + replacements.emplace(std::move(r2)); + return CloneWithReplacements(std::move(replacements), context, suffix); +} + +std::unique_ptr HloComputation::CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + std::pair> r3, + HloCloneContext* context, const string& suffix) { + std::unordered_map> + replacements; + replacements.emplace(std::move(r1)); + replacements.emplace(std::move(r2)); + replacements.emplace(std::move(r3)); + return CloneWithReplacements(std::move(replacements), context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - absl::Span extras, HloCloneContext* context, - const string& suffix) { + HloCloneContext* context, const string& suffix) { std::unique_ptr context_ptr; if (context == nullptr) { context_ptr = absl::make_unique(parent(), suffix); @@ -939,18 +908,50 @@ std::unique_ptr HloComputation::CloneWithReplacements( }; VLOG(1) << "Cloning " << name() << " 
--> " << suffix << "\n"; + + // We want to do a postorder walk over [replace(i) for i in instructions_]. + // We can't reuse MakeInstructionPostOrder() for this, because that will + // generate a postorder of plain instructions_, and our replacements may + // change the postorder! + // + // The postorder we want here is simpler than what MakeInstructionPostOrder() + // does -- we only care about operand dependencies -- so let's just do it + // ourselves. std::vector postorder; - for (HloInstruction* instr : extras) { - postorder.push_back(instr); - } - for (HloInstruction* instr : MakeInstructionPostOrder()) { - if (HloInstruction* replacement = replace(instr)) { - postorder.push_back(replacement); + absl::flat_hash_map visited; + for (const auto& instr : instructions_) { + std::vector dfs_stack; + HloInstruction* new_instr = replace(instr.get()); + if (!new_instr) { + continue; + } + dfs_stack.push_back(new_instr); + + while (!dfs_stack.empty()) { + auto* cur = dfs_stack.back(); + auto it = visited.find(cur); + if (it != visited.end()) { + dfs_stack.pop_back(); + if (it->second == kVisited) { + continue; + } + CHECK_EQ(it->second, kVisiting); + postorder.push_back(cur); + it->second = kVisited; + continue; + } + + visited.insert({cur, kVisiting}); + for (HloInstruction* operand : cur->operands()) { + HloInstruction* new_operand = replace(operand); + if (new_operand) { + dfs_stack.emplace_back(new_operand); + } + } } } std::vector> instructions; - std::unique_ptr new_instr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { @@ -960,9 +961,8 @@ std::unique_ptr HloComputation::CloneWithReplacements( << operand->ToString() << ", used by " << instr->ToString(); new_operands.push_back(context->GetInstruction(replaced_operand)); } - new_instr = - instr->CloneWithNewOperands(instr->shape(), new_operands, context); - instructions.push_back(std::move(new_instr)); + instructions.push_back( + 
instr->CloneWithNewOperands(instr->shape(), new_operands, context)); } Builder builder(name() + "." + suffix); for (auto& instr : instructions) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index d87ab4bda162a74421e8906e07cfcb97e2128fe4..c584e4c7ca5770533f28352b0df9dadd9dbe1860 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" @@ -128,9 +127,10 @@ class HloComputation { // users. Instruction is deallocated with this call. Status RemoveInstruction(HloInstruction* instruction); - // Remove an instruction from the computation and also transitively any - // operand that has no users post removing an instruction. The instruction - // must have no users. Instruction is deallocated with this call. + // Remove an instruction (including side effecting ones) from the computation + // and also transitively any operand that has no side effect and no users post + // removing an instruction. The instruction must have no users. Instruction is + // deallocated with this call. Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction); // Set the root of the computation to the given instruction. The instruction @@ -214,19 +214,6 @@ class HloComputation { // this order, definitions of values always appear before their uses. std::vector MakeInstructionPostOrder() const; - // Computes and returns the reachability between HLO instructions in the - // computation. 
The returned HloReachabilityMap is constructed such that - // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a - // directed path (from producer to consumer) from 'a' to 'b'. Both data - // dependencies (operands) and control dependencies are considered for - // reachability. Trivially an instruction is reachable from itself. - std::unique_ptr ComputeReachability() const; - - // Updates the given reachability map after the immediate predecessor set - // (operands and control predecessors) of 'instruction' has changed. - void UpdateReachabilityThroughInstruction( - const HloInstruction* instruction, HloReachabilityMap* reachability_map); - int64 instruction_count() const { return instruction_iterators_.size(); } // Creates and returns a list of the embedded computations called by this @@ -277,6 +264,12 @@ class HloComputation { // Return whether `*this` and `other` are functionally equivalent. bool operator==(const HloComputation& other) const; + // Generates a hash value of an HLO computation. Hash considers + // information on opcode, shape, operands, and typically a root instruction. + // This function returns the same hash value for equivalent HLO computations, + // with respect to HloInstruction::Identical() method. + uint64 Hash() const; + // Replaces old instruction with newly created instruction. Removes old // instruction from computation. Updates uses and root instruction. Status ReplaceWithNewInstruction( @@ -314,7 +307,7 @@ class HloComputation { // be a topological sort of all instructions in the computation. template Status AcceptOrdered(DfsHloVisitorBase* visitor, - const std::vector& order) const; + const std::vector& order) const; // Same as Accept() above, but the visitor is given as a function. Status Accept(const std::function& visitor_func); @@ -332,14 +325,38 @@ class HloComputation { // the map's value to replace that instruction in the cloned computation. 
// // If replacements maps a key to nullptr, we remove that instruction from the - // new computation. - // If additional instructions are used by instructions in replacement map, - // they must be passed in post-order in the extras span. + // new computation. If an element of `replacements` references an instruction + // that's not already in the computation, it's cloned and added to the new + // computation. + // + // All relevant instructions are cloned, *including* unique_ptr in the + // `replacements` map. std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - absl::Span extras, HloCloneContext* context = nullptr, - const string& suffix = "clone"); + HloCloneContext* context = nullptr, const string& suffix = "clone"); + + // Convenience overloads for CloneWithReplacements. You want to do + // + // CloneWithReplacements({{a, std::move(b)}, {c, std::move(d)}}) // ERROR + // + // but that doesn't work because std::initializer_list is not movable. These + // overloads let you do + // + // CloneWithReplacementPairs({a, std::move(b)}, {c, std::move(d)}); // OK + // + std::unique_ptr CloneWithReplacementPairs( + std::pair> r1, + HloCloneContext* context = nullptr, const string& suffix = "clone"); + std::unique_ptr CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + HloCloneContext* context = nullptr, const string& suffix = "clone"); + std::unique_ptr CloneWithReplacementPairs( + std::pair> r1, + std::pair> r2, + std::pair> r3, + HloCloneContext* context = nullptr, const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of @@ -354,6 +371,14 @@ class HloComputation { // channel complete). bool IsRemovable(const HloInstruction* instruction); + // Returns a map from channel-id to directed dependencies of the channel + // instructions. 
For send&recv pairs it means the send instruction and for + // cross-replica-sum the union of the dependencies for all participating + // instructions. + using ChannelDependencyMap = + absl::flat_hash_map>; + ChannelDependencyMap ComputeChannelDependencies() const; + // Returns true if this computation has a side effect. A computation has a // side effect if it contains one or more instructions with a side effect. bool HasSideEffect() const; @@ -409,14 +434,6 @@ class HloComputation { // Internal helper to collect unreachable roots. std::vector CollectUnreachableRoots() const; - // Returns a map from channel-id to directed dependencies of the channel - // instructions. For send&recv pairs it means the send instruction and for - // cross-replica-sum the union of the dependencies for all participating - // instructions. - using ChannelDependencyMap = - absl::flat_hash_map>; - ChannelDependencyMap ComputeChannelDependencies() const; - enum VisitState { kVisiting, kVisited }; void ComputeInstructionPostOrder( const HloComputation::ChannelDependencyMap& channel_dependency_map, diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 2aaaef1d36d58bcce18db4aa37ff05ea352e484b..0361c87428f6e4c031d95492a5bc782ad388e5b5 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -20,19 +20,19 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -namespace op = xla::testing::opcode_matchers; - namespace xla { namespace { +namespace m = match; using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; @@ -65,7 +65,7 @@ class HloComputationTest : public HloTestBase { }; TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto negate_computation = module->AddEntryComputation(CreateNegateComputation()); EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty()); @@ -73,7 +73,7 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { // Create computation which calls one other computation. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto negate_computation = module->AddEmbeddedComputation(CreateNegateComputation()); auto map_computation = @@ -85,7 +85,7 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) { // Create computations with a diamond-shaped callgraph. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto negate_computation = module->AddEmbeddedComputation(CreateNegateComputation()); auto map1_computation = @@ -119,7 +119,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); } @@ -134,7 +134,7 @@ TEST_F(HloComputationTest, PostOrderSimple) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant, negate1, negate2)); @@ -151,7 +151,7 @@ TEST_F(HloComputationTest, PostOrderTrace) { builder.AddInstruction(HloInstruction::CreateTrace("foobar", negate1)); auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Trace instructions should be at the end of the sort. 
EXPECT_THAT(computation->MakeInstructionPostOrder(), @@ -170,7 +170,7 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto constant4 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), UnorderedElementsAre(constant1, constant2, constant3, constant4)); @@ -192,7 +192,7 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { r0f32_, HloOpcode::kAdd, constant2, constant3)); auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant3)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto post_order = computation->MakeInstructionPostOrder(); EXPECT_EQ(6, post_order.size()); @@ -217,7 +217,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { constant2, constant3)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1, constant3)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Visitor which keeps track of which instructions have been visited. 
class TestVisitor : public DfsHloVisitorWithDefault { @@ -257,11 +257,11 @@ TEST_F(HloComputationTest, DeepCopyArray) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); - EXPECT_THAT(copy, op::Copy(constant)); + EXPECT_THAT(copy, GmockMatch(m::Copy(m::Op().Is(constant)))); } TEST_F(HloComputationTest, DeepCopyTuple) { @@ -274,12 +274,13 @@ TEST_F(HloComputationTest, DeepCopyTuple) { auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); - EXPECT_THAT(tuple_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::Copy(op::GetTupleElement(tuple)))); + EXPECT_THAT(tuple_copy, GmockMatch(m::Tuple( + m::Copy(m::GetTupleElement(m::Op().Is(tuple))), + m::Copy(m::GetTupleElement(m::Op().Is(tuple)))))); EXPECT_EQ(0, tuple_copy->operand(0)->operand(0)->tuple_index()); EXPECT_EQ(1, tuple_copy->operand(1)->operand(0)->tuple_index()); } @@ -297,7 +298,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) { ShapeTree indices_to_copy(constant->shape(), /*init_value=*/true); EXPECT_THAT(computation->DeepCopyInstruction(constant, &indices_to_copy) .ValueOrDie(), - op::Copy(constant)); + GmockMatch(m::Copy(m::Op().Is(constant)))); } { @@ -330,10 +331,11 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::Copy(op::GetTupleElement(tuple)))); - 
EXPECT_THAT(deep_copy, op::Tuple(copies_added.element({0}), - copies_added.element({1}))); + EXPECT_THAT(deep_copy, GmockMatch(m::Tuple( + m::Copy(m::GetTupleElement(m::Op().Is(tuple))) + .Is(copies_added.element({0})), + m::Copy(m::GetTupleElement(m::Op().Is(tuple))) + .Is(copies_added.element({1}))))); } { @@ -346,8 +348,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::GetTupleElement(tuple), - op::GetTupleElement(tuple))); + EXPECT_THAT(deep_copy, + GmockMatch(m::Tuple(m::GetTupleElement(m::Op().Is(tuple)), + m::GetTupleElement(m::Op().Is(tuple))))); EXPECT_TRUE(copies_added.element({}) == nullptr); EXPECT_TRUE(copies_added.element({0}) == nullptr); EXPECT_TRUE(copies_added.element({1}) == nullptr); @@ -363,8 +366,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added) .ValueOrDie(); - EXPECT_THAT(deep_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), - op::GetTupleElement(tuple))); + EXPECT_THAT(deep_copy, GmockMatch(m::Tuple( + m::Copy(m::GetTupleElement(m::Op().Is(tuple))), + m::GetTupleElement(m::Op().Is(tuple))))); EXPECT_TRUE(copies_added.element({}) == nullptr); EXPECT_TRUE(copies_added.element({0}) != nullptr); EXPECT_TRUE(copies_added.element({1}) == nullptr); @@ -376,12 +380,12 @@ TEST_F(HloComputationTest, DeepCopyToken) { // copied. auto builder = HloComputation::Builder(TestName()); auto token = builder.AddInstruction(HloInstruction::CreateToken()); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); // No copy should be added. 
- EXPECT_THAT(copy, op::AfterAll()); + EXPECT_THAT(copy, GmockMatch(m::AfterAll())); } TEST_F(HloComputationTest, DeepCopyTokenTuple) { @@ -393,14 +397,15 @@ TEST_F(HloComputationTest, DeepCopyTokenTuple) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({token, constant})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); // Only the array (second tuple element) should be copied. The token is passed // through transparently. - EXPECT_THAT(copy, op::Tuple(op::GetTupleElement(tuple), - op::Copy(op::GetTupleElement(tuple)))); + EXPECT_THAT(copy, GmockMatch(m::Tuple( + m::GetTupleElement(m::Op().Is(tuple)), + m::Copy(m::GetTupleElement(m::Op().Is(tuple)))))); } TEST_F(HloComputationTest, CycleDetection) { @@ -412,7 +417,7 @@ TEST_F(HloComputationTest, CycleDetection) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, negate, negate)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Add a control dependency to create a cycle. 
ASSERT_IS_OK(add->AddControlDependencyTo(negate)); @@ -440,16 +445,18 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { r0f32_, HloOpcode::kAdd, dead_negate, dead_negate)); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); - EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Negate(m::Op().Is(constant)))); EXPECT_EQ(negate, computation->root_instruction()); ASSERT_IS_OK(computation->RemoveInstructionAndUnusedOperands(dead_add)); EXPECT_EQ(2, computation->instruction_count()); - EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Negate(m::Op().Is(constant)))); EXPECT_EQ(negate, computation->root_instruction()); } @@ -466,7 +473,7 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { HloInstruction::CreateParameter(0, r0f32_, "param0")); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build(/*root_instruction=*/add)); @@ -484,107 +491,6 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add)); } -TEST_F(HloComputationTest, Reachability) { - // Test reachability of a non-trivial computation: - // - // const1 const2 - // | | - // | +-------+ - // | | | - // add .. negate - // | . | - // | .... exp - // | | - // +---+ +-+---+ - // | | | - // multiply copy - // - // There is a control dependency from 'add' to 'exp'. 
- auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); - auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - r0f32_, HloOpcode::kAdd, constant1, constant2)); - auto negate = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant2)); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, negate)); - auto mul = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, add, exp)); - auto copy = builder.AddInstruction( - HloInstruction::CreateUnary(r0f32_, HloOpcode::kCopy, exp)); - - auto module = CreateNewModule(); - auto computation = - module->AddEntryComputation(builder.Build(/*root_instruction=*/mul)); - - TF_CHECK_OK(add->AddControlDependencyTo(exp)); - auto reachability = computation->ComputeReachability(); - - EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); - EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant1, add)); - EXPECT_FALSE(reachability->IsReachable(constant1, negate)); - EXPECT_TRUE(reachability->IsReachable(constant1, exp)); - EXPECT_TRUE(reachability->IsReachable(constant1, mul)); - EXPECT_TRUE(reachability->IsReachable(constant1, copy)); - - EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); - EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant2, add)); - EXPECT_TRUE(reachability->IsReachable(constant2, negate)); - EXPECT_TRUE(reachability->IsReachable(constant2, exp)); - EXPECT_TRUE(reachability->IsReachable(constant2, mul)); - EXPECT_TRUE(reachability->IsReachable(constant2, copy)); - - EXPECT_FALSE(reachability->IsReachable(exp, constant1)); - 
EXPECT_FALSE(reachability->IsReachable(exp, constant2)); - EXPECT_FALSE(reachability->IsReachable(exp, add)); - EXPECT_FALSE(reachability->IsReachable(exp, negate)); - EXPECT_TRUE(reachability->IsReachable(exp, exp)); - EXPECT_TRUE(reachability->IsReachable(exp, mul)); - EXPECT_TRUE(reachability->IsReachable(exp, copy)); - - EXPECT_FALSE(reachability->IsReachable(mul, constant1)); - EXPECT_FALSE(reachability->IsReachable(mul, constant2)); - EXPECT_FALSE(reachability->IsReachable(mul, add)); - EXPECT_FALSE(reachability->IsReachable(mul, negate)); - EXPECT_FALSE(reachability->IsReachable(mul, exp)); - EXPECT_TRUE(reachability->IsReachable(mul, mul)); - EXPECT_FALSE(reachability->IsReachable(mul, copy)); - - EXPECT_TRUE(reachability->IsConnected(constant1, copy)); - EXPECT_TRUE(reachability->IsConnected(copy, constant1)); - EXPECT_FALSE(reachability->IsConnected(negate, add)); - EXPECT_FALSE(reachability->IsConnected(add, negate)); - - // Remove the control dependency then update and verify the reachability map - ASSERT_IS_OK(add->RemoveControlDependencyTo(exp)); - computation->UpdateReachabilityThroughInstruction(exp, reachability.get()); - - EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); - EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant1, add)); - EXPECT_FALSE(reachability->IsReachable(constant1, negate)); - EXPECT_FALSE(reachability->IsReachable(constant1, exp)); - EXPECT_TRUE(reachability->IsReachable(constant1, mul)); - EXPECT_FALSE(reachability->IsReachable(constant1, copy)); - - // Change a use within the graph then update and verify the reachability map - ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1)); - computation->UpdateReachabilityThroughInstruction(negate, reachability.get()); - - EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); - EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); - EXPECT_TRUE(reachability->IsReachable(constant2, add)); - 
EXPECT_FALSE(reachability->IsReachable(constant2, negate)); - EXPECT_FALSE(reachability->IsReachable(constant2, exp)); - EXPECT_TRUE(reachability->IsReachable(constant2, mul)); - EXPECT_FALSE(reachability->IsReachable(constant2, copy)); -} - TEST_F(HloComputationTest, Stringification) { const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); @@ -606,7 +512,7 @@ TEST_F(HloComputationTest, Stringification) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); @@ -641,7 +547,7 @@ TEST_F(HloComputationTest, StringificationIndent) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = @@ -677,7 +583,7 @@ TEST_F(HloComputationTest, StringificationCanonical) { 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto options = HloPrintOptions().set_print_metadata(false); @@ -700,27 +606,5 @@ TEST_F(HloComputationTest, StringificationCanonical) { EXPECT_EQ(computation->ToString(options), expected_computation2); } -TEST_F(HloComputationTest, ChannelReachability) { - const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); - HloComputation::Builder builder("ChannelReachability"); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto token0 = 
builder.AddInstruction(HloInstruction::CreateToken()); - auto send = - builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); - auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); - auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); - auto recv = - builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); - auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); - - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build(recv_done)); - auto reachability = computation->ComputeReachability(); - EXPECT_TRUE(reachability->IsReachable(param, recv_done)); - EXPECT_FALSE(reachability->IsReachable(send, recv)); - EXPECT_FALSE(reachability->IsReachable(send_done, recv)); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 4f898ce61c3f36e83e4b13130a404dbb4a2c36c6..5e37883d3d8d5067bab873ac6b5f732e7360c5fa 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -52,8 +52,10 @@ StatusOr HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, and AfterAll operation. - // TODO(b/64407269): Enable Tuple once the timeout issue is resolved. + // Skip Constant, Parameter, Tuple, AfterAll operation. + // Tuple constants are not directly supported by any backends, hence + // folding Tuple is not useful and would in fact be expanded back into + // kTuple by Algebraic Simplifier. // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one // operand in which case constant folding will be impossible and this // special case is not necessary. 
@@ -63,6 +65,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { instruction->opcode() == HloOpcode::kAfterAll) { continue; } + // Skip instructions with non-constant operands. if (!hlo_query::AllOperandsAreConstants(*instruction)) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index e45f905f7152c37a9ab2b41d407310671310c2a3..4f81dc94e577a63c09ae4019e5e8158252c712ce 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -22,22 +22,23 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" -namespace op = xla::testing::opcode_matchers; - namespace xla { namespace { -using HloConstantFoldingTest = HloVerifiedTestBase; +namespace m = xla::match; + +using HloConstantFoldingTest = HloTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); @@ -46,16 +47,17 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); - auto module = 
CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input)))); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); - EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant())); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), 42); } @@ -67,16 +69,17 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input)))); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); - EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant())); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), 42.0f); } @@ -88,16 +91,17 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), op::Convert(input)); 
+ EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input)))); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); - EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_THAT(computation->root_instruction(), GmockMatch(m::Constant())); EXPECT_EQ(computation->root_instruction()->literal().Get({0}), 42); EXPECT_EQ(computation->root_instruction()->literal().Get({1}), 19); } @@ -130,15 +134,15 @@ TEST_F(HloConstantFoldingTest, Concatenate) { Shape shape = ShapeUtil::MakeShape(F32, dimensions); builder.AddInstruction(HloInstruction::CreateConcatenate( shape, operands, test_config.concat_dimension)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); } } @@ -157,15 +161,15 @@ TEST_F(HloConstantFoldingTest, Slice) { Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4}); builder.AddInstruction(HloInstruction::CreateSlice( shape, literal_instruction, slice_start, slice_limits, slice_strides)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, 
op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); } @@ -182,15 +186,15 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { const int64 permutation[] = {1, 2, 0, 4, 3}; builder.AddInstruction( HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Constant()); + EXPECT_THAT(root, GmockMatch(m::Constant())); EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), shape)); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; @@ -219,27 +223,29 @@ const char* const kConstantFoldReduce = R"( })"; TEST_F(HloConstantFoldingTest, ConstantFoldReduce) { - ParseAndVerifyModule(kConstantFoldReduce); + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(kConstantFoldReduce)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(m.get())); EXPECT_TRUE(result); - EXPECT_EQ(6, module() - .entry_computation() + EXPECT_EQ(6, m->entry_computation() ->root_instruction() ->literal() .GetFirstElement()); } TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { - ParseAndVerifyModule(kConstantFoldReduce); - HloInstruction* add = module().computations().begin()->root_instruction(); + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnVerifiedModule(kConstantFoldReduce)); + HloInstruction* add = m->computations().begin()->root_instruction(); LayoutUtil::ClearLayout(add->mutable_shape()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, 
const_folder.Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(m.get())); EXPECT_FALSE(result); - EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Reduce())); } const char* const kConstantFoldLargePad = R"( @@ -259,7 +265,7 @@ TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) { EXPECT_FALSE(result); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Pad(op::Constant(), op::Constant())); + GmockMatch(m::Pad(m::Constant(), m::Constant()))); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index a502fff9a0f1e40065746f2193bf76b1adefdb31..df7d3826dbad1f264a5dc53312c062900155b0f6 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -269,7 +269,7 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) { Status HloCostAnalysis::HandleMap(const HloInstruction* map) { // Compute properties of the mapped function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(map->to_apply())); + ProcessNestedSubcomputation(map->to_apply())); // Compute the cost of all elements for this Map operation. const int64 element_count = ShapeUtil::ElementsIn(map->shape()); @@ -285,7 +285,7 @@ Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { HloComputation* function = reduce->to_apply(); // Compute the cost of the user function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(function)); + ProcessNestedSubcomputation(function)); // Compute the cost of all elements for this Reduce operation. 
// This counts the number of times the reduction function is applied, so it @@ -311,7 +311,7 @@ Status HloCostAnalysis::HandleReduceWindow( auto function = reduce_window->to_apply(); // Compute the properties of the reduction function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(function)); + ProcessNestedSubcomputation(function)); // Compute the cost of all elements for this ReduceWindow operation. For each // output element there are window_size - 1 reductions to perform. @@ -336,9 +336,9 @@ Status HloCostAnalysis::HandleSelectAndScatter( // Compute the properties of the select and scatter function. // Compute the properties of the reduction function. TF_ASSIGN_OR_RETURN(const Properties select_properties, - ProcessSubcomputation(instruction->select())); + ProcessNestedSubcomputation(instruction->select())); TF_ASSIGN_OR_RETURN(const Properties scatter_properties, - ProcessSubcomputation(instruction->scatter())); + ProcessNestedSubcomputation(instruction->scatter())); // Compute the cost of all elements for this operation. For each scatter // source element there are window_size - 1 select computations to perform and @@ -419,6 +419,21 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { } Status HloCostAnalysis::HandleAfterAll(const HloInstruction*) { + // This instruction is used to enforce ordering at compile time. No code is + // emitted. + current_should_compute_bottleneck_time_ = false; + current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; + return Status::OK(); +} + +Status HloCostAnalysis::HandleAddDependency( + const HloInstruction* add_dependency) { + // This instruction is used to enforce ordering at compile time. No code is + // emitted. 
+ current_should_compute_bottleneck_time_ = false; + current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } @@ -574,7 +589,7 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) { Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { TF_ASSIGN_OR_RETURN( current_properties_, - ProcessSubcomputation(fusion->fused_instructions_computation())); + ProcessNestedSubcomputation(fusion->fused_instructions_computation())); // Fusion nodes that produce a tuple also produce the entries in the tuple. // Ignore the memory accessed inside fused ops, since fusion is supposed to @@ -595,7 +610,7 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { Status HloCostAnalysis::HandleCall(const HloInstruction* call) { TF_ASSIGN_OR_RETURN(current_properties_, - ProcessSubcomputation(call->to_apply())); + ProcessUnnestedSubcomputation(call->to_apply())); current_should_compute_bottleneck_time_ = false; return Status::OK(); } @@ -624,13 +639,12 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { // Since the number of iterations of the while node will not always be // something that we can statically analyze, we cannot precisely compute the // cost of a while node. For now compute the cost of a single iteration. - // - // TODO(b/26346211): Improve the cost analysis for while nodes. 
TF_ASSIGN_OR_RETURN(const Properties body_properties, - ProcessSubcomputation(xla_while->while_body())); + ProcessUnnestedSubcomputation(xla_while->while_body())); - TF_ASSIGN_OR_RETURN(const Properties condition_properties, - ProcessSubcomputation(xla_while->while_condition())); + TF_ASSIGN_OR_RETURN( + const Properties condition_properties, + ProcessUnnestedSubcomputation(xla_while->while_condition())); current_properties_.clear(); for (const auto& property : body_properties) { @@ -647,10 +661,12 @@ Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { // Compute the cost of the true and false computations and take the maximum // from those for each property. - TF_ASSIGN_OR_RETURN(const Properties true_computation_properties, - ProcessSubcomputation(conditional->true_computation())); - TF_ASSIGN_OR_RETURN(const Properties false_computation_properties, - ProcessSubcomputation(conditional->false_computation())); + TF_ASSIGN_OR_RETURN( + const Properties true_computation_properties, + ProcessUnnestedSubcomputation(conditional->true_computation())); + TF_ASSIGN_OR_RETURN( + const Properties false_computation_properties, + ProcessUnnestedSubcomputation(conditional->false_computation())); current_properties_ = true_computation_properties; for (const auto& property : false_computation_properties) { if (!tensorflow::gtl::InsertIfNotPresent(¤t_properties_, property)) { @@ -664,12 +680,33 @@ Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { } Status HloCostAnalysis::HandleGather(const HloInstruction* gather) { + // Gather doesn't read the whole input buffer, it's equivalent to a copy the + // size of the output shape and a read of the gather indices. + current_properties_[kBytesAccessedKey] = + GetShapeSize(gather->shape()) * 2 + + GetShapeSize(gather->operand(1)->shape()); // Gather does not issue any flops. 
return Status::OK(); } Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) { - // TODO(b/32945756): Compute the properties of the sub-computation. + current_properties_[kBytesAccessedKey] = + GetShapeSize(scatter->operand(2)->shape()) * 2 + + GetShapeSize(scatter->operand(1)->shape()); + const int64 element_count = + ShapeUtil::ElementsIn(scatter->operand(2)->shape()); + TF_ASSIGN_OR_RETURN(const Properties sub_properties, + ProcessNestedSubcomputation(scatter->to_apply())); + for (const auto& property : sub_properties) { + if (property.first != kBytesAccessedKey) { + current_properties_[property.first] = property.second * element_count; + } + } + return Status::OK(); +} + +Status HloCostAnalysis::HandleGetDimensionSize( + const HloInstruction* /*get_size*/) { return Status::OK(); } @@ -709,10 +746,19 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_); } -StatusOr HloCostAnalysis::ProcessSubcomputation( - HloComputation* computation) { +StatusOr +HloCostAnalysis::ProcessNestedSubcomputation(HloComputation* computation) { + HloCostAnalysis visitor(shape_size_, per_second_rates_); + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + return visitor.properties(); +} + +StatusOr +HloCostAnalysis::ProcessUnnestedSubcomputation(HloComputation* computation) { HloCostAnalysis visitor(shape_size_, per_second_rates_); TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + hlo_properties_.insert(visitor.hlo_properties_.begin(), + visitor.hlo_properties_.end()); return visitor.properties(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 46b4bbeef222e6de581360fc01b293e812f1dedd..33983119c9b00a248c0e8dcc5815c6367192dca3 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -101,12 +101,14 @@ class HloCostAnalysis : public 
ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; + Status HandleAddDependency(const HloInstruction* add_dependency) override; Status HandleAfterAll(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; Status HandleGather(const HloInstruction* gather) override; Status HandleScatter(const HloInstruction* scatter) override; + Status HandleGetDimensionSize(const HloInstruction* get_size) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; @@ -153,7 +155,24 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // Returns the properties computed from visiting the computation rooted at the // given hlo. - StatusOr ProcessSubcomputation(HloComputation* computation); + // + // The difference between ProcessNestedSubcomputation and + // ProcessUnnestedSubcomputation is that we expect to get profile results for + // an unnested subcomputation's individual instructions, while we expect that + // a nested subcomputation is completely subsumed by its parent. + // + // For example, subcomputations inside kFusion and kMap are considered nested, + // while subcomputations inside kWhile and kConditional are considered + // unnested. + // + // Another way of thinking of this is, kFusion is implemented on the GPU + // backend using just one GPU kernel, while kWhile's body is implemented as a + // sequence of kernels, one for each HLO therein. Backends don't necessarily + // need to follow this same implementation strategy, but we assume they do for + // the purposes of this platform-generic cost analysis. 
+ StatusOr ProcessNestedSubcomputation(HloComputation* computation); + StatusOr ProcessUnnestedSubcomputation( + HloComputation* computation); // Utility function to handle all element-wise operations. Status HandleElementwiseOp(const HloInstruction* hlo_instruction); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index d76ce9ecbca67ae3bc3db4ee2452f30ccec5b88b..ff32faf298dd1f04c5b769f2a88f76a7a1e18ae7 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -387,7 +387,7 @@ TEST_F(FusionCostAnalysis, LoopFusion) { HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp)); auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -429,7 +429,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( shape_with_layout, HloOpcode::kAdd, c1, broadcast)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add, broadcast}, HloInstruction::FusionKind::kLoop); @@ -472,7 +472,7 @@ TEST_F(DomainCostAnalysis, DomainCost) { auto domain = builder.AddInstruction( HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEntryComputation(builder.Build()); EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain); @@ -556,5 +556,56 @@ TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { EXPECT_EQ(analysis.bytes_accessed(), 8); 
} +TEST_F(HloCostAnalysisTest, Gather) { + // Test the analysis on a gather. + XlaBuilder builder("gather"); + Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3}); + Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); + + auto operand = Parameter(&builder, 0, operand_shape, "operand"); + auto indices = Parameter(&builder, 1, indices_shape, "indices"); + GatherDimensionNumbers dim_numbers; + dim_numbers.add_offset_dims(1); + dim_numbers.add_collapsed_slice_dims(0); + dim_numbers.add_start_index_map(0); + dim_numbers.set_index_vector_dim(1); + Gather(operand, indices, dim_numbers, {1, 3}); + + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 56); +} + +TEST_F(HloCostAnalysisTest, Scatter) { + // Test the analysis on a scatter. + XlaBuilder builder("scatter"); + Shape operand_shape = ShapeUtil::MakeShape(F32, {3, 3}); + Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); + Shape values_shape = ShapeUtil::MakeShape(F32, {2, 3}); + + auto operand = Parameter(&builder, 0, operand_shape, "operand"); + auto indices = Parameter(&builder, 1, indices_shape, "indices"); + auto values = Parameter(&builder, 2, values_shape, "values"); + ScatterDimensionNumbers dim_numbers; + dim_numbers.set_index_vector_dim(1); + dim_numbers.add_update_window_dims(1); + dim_numbers.add_inserted_window_dims(0); + dim_numbers.add_scatter_dims_to_operand_dims(0); + Scatter(operand, indices, values, add_, dim_numbers); + + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. 
+ HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 4 * (2 + 2 * (2 * 3))); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index e07a196d1154dc0ea45ccd2f15b0b9b56f7c41f8..aaa9ec60eb3c4e0159ed40b37d772e0973d306ec 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -19,22 +19,22 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace { -class HloCreationUtilsTest : public HloVerifiedTestBase { +class HloCreationUtilsTest : public HloTestBase { protected: - HloModule* CreateModuleWithProgramShape( + std::unique_ptr CreateModuleWithProgramShape( PrimitiveType primitive_type, absl::Span input_shape_dims, absl::Span output_shape_dims, HloInstruction** param, HloComputation** entry_computation) { Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); Shape output_shape = ShapeUtil::MakeShape(primitive_type, output_shape_dims); - auto module = CreateNewModule("test"); + auto module = CreateNewVerifiedModule("test"); *entry_computation = module->AddEntryComputation( CreateComputationWithSignature({&input_shape}, output_shape, "entry") .ValueOrDie()); @@ -47,10 +47,9 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{2}, - /*output_shape_dims=*/{2}, - ¶m, &entry_computation); + 
auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{2}, ¶m, + &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed, CollapseFirstNDims(param, 1)); @@ -67,9 +66,8 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, ¶m, + auto module = CreateModuleWithProgramShape( + S32, /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_2_dims_collapsed, @@ -92,10 +90,9 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{2}, - /*output_shape_dims=*/{1, 2}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{1, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended, PrependDegenerateDims(param, 1)); @@ -113,10 +110,9 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 1, 2}, ¶m, - &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{1, 1, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended, PrependDegenerateDims(param, 2)); @@ -134,10 +130,9 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{}, - /*output_shape_dims=*/{1, 1}, - ¶m, 
&entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{1, 1}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended, PrependDegenerateDims(param, 2)); @@ -154,10 +149,9 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{6}, /*output_shape_dims=*/{3, 1, 2}, ¶m, - &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{6}, + /*output_shape_dims=*/{3, 1, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_dim_expanded, ExpandFirstDimIntoNDims(param, {3, 1, 2})); @@ -176,10 +170,9 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{2}, - /*output_shape_dims=*/{6}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{6}, ¶m, + &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zero_padded_param, @@ -197,10 +190,9 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(S32, - /*input_shape_dims=*/{}, - /*output_shape_dims=*/{2, 2}, - ¶m, &entry_computation); + auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zeros, @@ -218,10 +210,9 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { HloInstruction* param; HloComputation* entry_computation; - HloModule* module = CreateModuleWithProgramShape(F32, - /*input_shape_dims=*/{}, - /*output_shape_dims=*/{2, 2}, - ¶m, &entry_computation); + auto module = 
CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + ¶m, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zeros, diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 9b18b0284f63c25934c1b7118dc8973caa62cadc..1eb0260468c4560985027947e89c62cc21139e7e 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" @@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class HloCseTest : public HloVerifiedTestBase { +class HloCseTest : public HloTestBase { protected: HloCseTest() {} }; @@ -59,13 +59,13 @@ TEST_F(HloCseTest, CombineTwoConstants) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); HloInstruction* constant = *computation->instructions().begin(); @@ -89,14 +89,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, 
constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); auto first_operand = add->operand(0); @@ -121,14 +121,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); @@ -171,13 +171,13 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { shape_r0, HloOpcode::kAdd, root, constants[i])); } - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); // CSE will remove both the second float(42.0f) and the corresponding // convert/cast. 
@@ -201,7 +201,7 @@ TEST_F(HloCseTest, NonscalarConstants) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( {common_constant1, common_constant2, uncommon_constant})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); @@ -209,7 +209,7 @@ TEST_F(HloCseTest, NonscalarConstants) { op::Tuple(common_constant1, common_constant2, uncommon_constant)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -233,14 +233,14 @@ TEST_F(HloCseTest, IdenticalInstructions) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2, exp3})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(5, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -250,7 +250,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test two identical while loops with same inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -277,21 +277,21 @@ index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1) f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body - } - )"); + })"; - 
auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); + EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); } // Test two while loops with same conditions, same inputs, but different // bodies TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -327,20 +327,20 @@ index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[] %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body2 - } - )"); + })"; - auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); + EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } // Test two identical while loops with different inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -369,22 +369,21 @@ condition=%condition, body=%body %constant.4 = f32[] constant(1) %constant.5 = f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[] %constant.5) ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.2), condition=%condition.1, body=%body - } 
- - )"); + })"; - auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(8, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); + EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(8, computation->instruction_count()); } // Test two while loops with identical bodies and same inputs, but different // conditions TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -411,13 +410,14 @@ f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body - })"); + })"; - auto computation = module().entry_computation(); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto computation = m->entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); + EXPECT_FALSE(cse.Run(m.get()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } @@ -439,14 +439,14 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(4, 
computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); @@ -470,14 +470,14 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -488,7 +488,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { TEST_F(HloCseTest, FusionInternalCSE) { // Test that we can CSE expressions that live within a fusion node // computation. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); @@ -512,7 +512,7 @@ TEST_F(HloCseTest, FusionInternalCSE) { EXPECT_EQ(5, fused_computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(4, fused_computation->instruction_count()); auto root = fused_computation->root_instruction(); @@ -554,14 +554,14 @@ TEST_F(HloCseTest, IdenticalExpressions) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(8, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2))); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module).ValueOrDie()); + 
EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); auto operand = tuple->operand(0); @@ -586,7 +586,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, rng1, rng2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); @@ -595,7 +595,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { uint32 count_before = computation->instruction_count(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); uint32 count_after = computation->instruction_count(); EXPECT_EQ(count_before, count_after); @@ -607,7 +607,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { // Test that two calls to an impure function are not commoned. RNG // is the source of the impurity. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); // rng_function is an impure function because it does RNG. 
HloComputation* rng_function = nullptr; @@ -649,7 +649,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { VLOG(3) << "before: " << module->ToString(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); VLOG(3) << "after: " << module->ToString(); @@ -659,7 +659,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { } TEST_F(HloCseTest, CompareComputations) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule m add_computation { @@ -680,11 +680,12 @@ TEST_F(HloCseTest, CompareComputations) { r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 ROOT f2 = (f32[],f32[]) tuple(r1, r2) - })"); + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); - HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); + HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_EQ(root->operand(0), root->operand(1)); } @@ -697,19 +698,19 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); } TEST_F(HloCseTest, Domain) { - ParseAndVerifyModule(R"( + const char* const hlo_string = R"( HloModule module ENTRY %entry { %param = f32[] parameter(0), sharding={maximal device=0} @@ -730,11 +731,12 @@ ENTRY %entry { 
domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}} %add = f32[] add(%domain.3, %domain.4) ROOT %sub = f32[] subtract(%add, %domain.5) -})"); +})"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); - const HloInstruction* sub = module().entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(m.get()).ValueOrDie()); + const HloInstruction* sub = m->entry_computation()->root_instruction(); const HloInstruction* add = sub->operand(0); EXPECT_EQ(add->operand(0), add->operand(1)); EXPECT_NE(add->operand(0), sub->operand(1)); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index c22adcdd8dd936eebca3a8f0d85b1254401b5ef4..3ed3d3c11c71dc534f193ba3ffb556b0eb0c80e4 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -126,7 +126,7 @@ bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, const HloValue& HloDataflowAnalysis::GetValueDefinedAt( const HloInstruction* instruction, const ShapeIndex& index) const { - CHECK(ValueIsDefinedAt(instruction, index)); + CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString(); return GetUniqueValueAt(instruction, index); } @@ -466,6 +466,21 @@ bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) { return changed; } +bool HloDataflowAnalysis::UpdateAddDependencyValueSet( + HloInstruction* add_dependency) { + // AddDependency just forwards the value of its zero-th operand. 
+ CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency); + const InstructionValueSet& operand_set = + GetInstructionValueSet(add_dependency->operand(0)); + InstructionValueSet& add_dependency_set = + GetInstructionValueSet(add_dependency); + if (operand_set != add_dependency_set) { + add_dependency_set = operand_set; + return true; + } + return false; +} + bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) { CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement); bool changed = false; @@ -622,6 +637,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( HloInstruction* instruction) { // Recompute from operands. switch (instruction->opcode()) { + case HloOpcode::kAddDependency: + return UpdateAddDependencyValueSet(instruction); case HloOpcode::kBitcast: return UpdateBitcastValueSet(instruction); case HloOpcode::kDomain: @@ -795,6 +812,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { define_all_values(); } break; + case HloOpcode::kAddDependency: case HloOpcode::kWhile: case HloOpcode::kCall: case HloOpcode::kConditional: @@ -1048,6 +1066,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( } if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kScatter || user->opcode() == HloOpcode::kWhile) { // We eliminated other users in BufferLiveness::live_range_strictly_before, // so here we just need to check that the use is at operand index 0. 
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index abac398c04fc4c418d8814a0097db4434bc1cd9c..ece17fc4c3ea0261474df5d53c088dd05016e1e4 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -193,6 +193,7 @@ class HloDataflowAnalysis { bool UpdateSendValueSet(HloInstruction* send); bool UpdateTupleValueSet(HloInstruction* tuple); bool UpdateWhileValueSet(HloInstruction* xla_while); + bool UpdateAddDependencyValueSet(HloInstruction* add_dependency); // Propagate the dataflow through the module. void Propagate(); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 510d6360a1cf94ef06d2ed919a57c7a825886834..f7a1f19a6f52befd58a405d0e406d7d0d37a8e57 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -43,7 +43,7 @@ using ::testing::UnorderedElementsAre; class HloDataflowAnalysisTest : public HloTestBase, public ::testing::WithParamInterface { protected: - HloDataflowAnalysisTest() : module_(CreateNewModule()) {} + HloDataflowAnalysisTest() : module_(CreateNewVerifiedModule()) {} // Run dataflow analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. 
@@ -1877,6 +1877,30 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) { } } +TEST_P(HloDataflowAnalysisTest, AddDependency) { + string module_string = R"( +HloModule AddDependency +ENTRY %AddDependency (p: f32[3]) -> f32[3] { + %p = f32[3] parameter(0) + %token = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr analysis, + HloDataflowAnalysis::Run(*module)); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAddDependency); + + // The after-all and parameter should define a value. Add-dependency should + // not. + EXPECT_EQ(analysis->values().size(), 2); + EXPECT_FALSE(analysis->ValueIsDefinedAt(root)); +} + INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); @@ -1884,7 +1908,7 @@ INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, class HloDataflowAnalysisTestBase : public HloTestBase { protected: void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); computation_ = module_->AddEntryComputation(std::move(computation)); } @@ -2283,6 +2307,44 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { + const char* hlo_text = R"( + HloModule TensorFlowScatterV1 + + update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) + } + + ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + 
inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); + computation_ = module_->entry_computation(); + RunAnalysis(); + + HloInstruction* operand_param = computation_->parameter_instruction(0); + HloInstruction* indices_param = computation_->parameter_instruction(1); + HloInstruction* updates_param = computation_->parameter_instruction(2); + HloInstruction* scatter = computation_->root_instruction(); + + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser( + operand_param, {}, scatter, {})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser( + indices_param, {}, scatter, {})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser( + updates_param, {}, scatter, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto builder = HloComputation::Builder(TestName()); @@ -2308,7 +2370,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { auto values = builder.AddInstruction( HloInstruction::CreateParameter(1, values_shape, "values")); auto sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, + {values})); BuildModuleAndRunAnalysis(builder.Build()); @@ -2437,7 +2500,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); HloComputation* cond_computation = module_->AddEmbeddedComputation(make_cond()); HloComputation* body_computation = @@ -2472,7 +2535,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto add = sub_builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); auto sub_computation = 
module_->AddEmbeddedComputation(sub_builder.Build()); sub_computation->CreateFusionInstruction({add, ones}, HloInstruction::FusionKind::kLoop); diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 3b5cde2996c4195ef458662cd21de85a832d8d55..1fa4259a3e42286cbc911907eea563e6ca6f8611 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -59,7 +59,7 @@ TEST_F(HloDceTest, NoDeadCode) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); @@ -80,7 +80,7 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) { HloInstruction::CreateSend(constant, token, /*channel_id=*/0)); builder.AddInstruction(HloInstruction::CreateTuple({})); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); @@ -110,7 +110,7 @@ TEST_F(HloDceTest, DeadParameters) { builder.AddInstruction(HloInstruction::CreateUnary( live_param->shape(), HloOpcode::kNegate, live_param)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(5, computation->instruction_count()); @@ -150,7 +150,7 @@ TEST_F(HloDceTest, ControlDependencies) { builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Add a control dependency between two instructions. 
@@ -175,7 +175,7 @@ TEST_F(HloDceTest, ControlDependencies) { // Tests that a dead call instruction is removed. TEST_F(HloDceTest, DeadInstructionWithCalledComputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Called computation for the call instruction. @@ -215,7 +215,7 @@ TEST_F(HloDceTest, DeadInstructionWithCalledComputation) { // Tests that a while instruction with an infeed (effectul instruction) in its // body is not removed, even its user count is 0. TEST_F(HloDceTest, CalledComputationWithSideEffect) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Condition computation of a while instruction. @@ -270,7 +270,7 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) { // Tests that a nested call instruction with a side effect is not removed. TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {}); // Nested called computation with a side effect. 
@@ -323,7 +323,7 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { } TEST_F(HloDceTest, RemoveDeadSubcomputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); HloComputation::Builder subcomp_builder("reduction_subcomp"); @@ -364,7 +364,7 @@ TEST_F(HloDceTest, RemoveDeadSubcomputation) { } TEST_F(HloDceTest, KeepUsedSubcomputation) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); HloComputation::Builder builder(TestName()); HloComputation::Builder subcomp_builder("reduction_subcomp"); diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc index 72185698c9bdcbf2bebed7ee82bc4ed082ce6a14..19b5734825df833fd34d634e4c1630dd75e96c4c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -23,23 +23,14 @@ limitations under the License. namespace xla { -class HloDomainIsolator::RunContext { - public: - RunContext(HloModule* module, HloDomainIsolator* isolator) - : module_(module), isolator_(isolator) {} +namespace { - StatusOr Run(); - - private: - HloModule* module_; - HloDomainIsolator* isolator_; -}; - -StatusOr HloDomainIsolator::RunContext::Run() { - hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator"); +StatusOr RunInternal(HloModule* module, + HloDomainIsolator::DomainCreator* creator) { + hlo_graph_dumper::MaybeDumpHloModule(*module, "Before Domain Isolator"); int64 added_domains = 0; - for (HloComputation* computation : module_->computations()) { + for (HloComputation* computation : module->computations()) { // Walk in post order and place all the required kDomain instructions. 
for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { @@ -55,8 +46,7 @@ StatusOr HloDomainIsolator::RunContext::Run() { root = root->mutable_operand(0); } // Check whether a kDomain is necessary between instruction and operand. - HloInstruction* domain = - isolator_->creator_(instruction, root, operand); + HloInstruction* domain = (*creator)(instruction, root, operand); if (domain != nullptr) { VLOG(4) << "New domain: " << domain->ToString(); TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain)); @@ -67,17 +57,19 @@ StatusOr HloDomainIsolator::RunContext::Run() { } VLOG(3) << "Added " << added_domains << " kDomain instructions"; if (added_domains > 0) { - hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Isolator"); + hlo_graph_dumper::MaybeDumpHloModule(*module, "After Domain Isolator"); } return added_domains > 0; } -HloDomainIsolator::HloDomainIsolator(DomainCreator creator) - : creator_(std::move(creator)) {} +} // namespace + +HloDomainIsolator::HloDomainIsolator(DomainCreatorFactory creator_factory) + : creator_factory_(std::move(creator_factory)) {} StatusOr HloDomainIsolator::Run(HloModule* module) { - RunContext run_context(module, this); - return run_context.Run(); + DomainCreator creator = creator_factory_(); + return RunInternal(module, &creator); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index c0bf1b9e16b52d81365db277abeb06defeb12d44..2274c3a96c2bdd1f4dbd454782699ccb0404529d 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -40,17 +40,15 @@ class HloDomainIsolator : public HloModulePass { // Returns nullptr in case no domain separation is necessary. 
using DomainCreator = std::function; - - explicit HloDomainIsolator(DomainCreator creator); + using DomainCreatorFactory = std::function; + explicit HloDomainIsolator(DomainCreatorFactory creator_factory_); absl::string_view name() const override { return "domain_isolator"; } StatusOr Run(HloModule* module) override; private: - class RunContext; - - DomainCreator creator_; + DomainCreatorFactory creator_factory_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 6ca1255edec377cf0738a1ad2596cb06aa1c2c6f..c6d02f9f67bb599e496d20fc2acf2e627ed54438 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -42,18 +42,19 @@ namespace xla { return std::move(domain_map); } -bool HloDomainMap::InSameDomain(HloInstruction* instruction1, - HloInstruction* instruction2) const { +bool HloDomainMap::InSameDomain(const HloInstruction* instruction1, + const HloInstruction* instruction2) const { int64 domain_id1 = GetDomainId(instruction1); int64 domain_id2 = GetDomainId(instruction2); return domain_id1 >= 0 && domain_id1 == domain_id2; } -int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const { +int64 HloDomainMap::GetDomainId(const HloInstruction* instruction) const { return FindOrDefault(instruction_to_domain_, instruction, -1); } -int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const { +int64 HloDomainMap::GetDomainMetadataId( + const HloInstruction* instruction) const { return FindOrDie(domain_metadata_id_, instruction); } @@ -200,7 +201,8 @@ StatusOr> HloDomainMap::CreateDomain( return std::move(domain); } -bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { +bool HloDomainMap::IsDomainInstruction( + const HloInstruction* instruction) const { if (instruction->opcode() != HloOpcode::kDomain) { return false; } diff --git 
a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index c8d581b74677674ed8682ecc1fa022cea890a649..bce7d1aa7cf1822ef1608674e7bf9483c628e4b5 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -58,21 +58,21 @@ class HloDomainMap { } // Checks whether two instructions are within the same domain. - bool InSameDomain(HloInstruction* instruction1, - HloInstruction* instruction2) const; + bool InSameDomain(const HloInstruction* instruction1, + const HloInstruction* instruction2) const; // Checks whether instruction is a kDomain instruction of the kind we are // currently processing. - bool IsDomainInstruction(HloInstruction* instruction) const; + bool IsDomainInstruction(const HloInstruction* instruction) const; // Retrieves the domain identifier of the instruction, or -1 in case // instruction is not found within any domain. - int64 GetDomainId(HloInstruction* instruction) const; + int64 GetDomainId(const HloInstruction* instruction) const; // Returns the unique id of the domain metadata for the domain the given // instruction belongs to. The given instruction must not be a kDomain // instruction since each domain instruction is associated with 2 domains. - int64 GetDomainMetadataId(HloInstruction* instruction) const; + int64 GetDomainMetadataId(const HloInstruction* instruction) const; private: // Map used for representing instruction ordering, i.e. 
@@ -119,8 +119,8 @@ class HloDomainMap { string domain_kind_; std::vector> instruction_domains_; - absl::flat_hash_map instruction_to_domain_; - absl::flat_hash_map domain_metadata_id_; + absl::flat_hash_map instruction_to_domain_; + absl::flat_hash_map domain_metadata_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 43e74d2f6f07bd685ad8683401138a4f06cd2ad2..acdb42128e3d9a1fb912a466c9c2c3cbbe3d3f83 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_domain_remover.h" @@ -22,13 +22,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class HloDomainTest : public HloVerifiedTestBase { +class HloDomainTest : public HloTestBase { protected: bool FindUserViaDomainPath(HloInstruction* instruction, HloInstruction* operand) const { @@ -64,13 +63,6 @@ class HloDomainTest : public HloVerifiedTestBase { } return false; } - - StatusOr ParseModule(absl::string_view hlo_string) { - HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - ParseAndVerifyModule(hlo_string, config); - return &module(); - } }; // Dummy DomainMetadata implementation which create kDomain boundaries around @@ -106,20 +98,22 @@ class OpNameMetadata : public DomainMetadata { }; // Creator function for OpNameMetadata domains. 
-HloInstruction* OpNameDomainCreator(HloInstruction* instruction, - HloInstruction* root, - HloInstruction* operand) { - if (instruction->metadata().op_name() == root->metadata().op_name()) { - return nullptr; +class OpNameDomainCreator { + public: + HloInstruction* operator()(HloInstruction* instruction, HloInstruction* root, + HloInstruction* operand) { + if (instruction->metadata().op_name() == root->metadata().op_name()) { + return nullptr; + } + std::unique_ptr operand_side_metadata = + absl::make_unique(root->metadata().op_name()); + std::unique_ptr user_side_metadata = + absl::make_unique(instruction->metadata().op_name()); + return operand->parent()->AddInstruction(HloInstruction::CreateDomain( + operand->shape(), operand, std::move(operand_side_metadata), + std::move(user_side_metadata))); } - std::unique_ptr operand_side_metadata = - absl::make_unique(root->metadata().op_name()); - std::unique_ptr user_side_metadata = - absl::make_unique(instruction->metadata().op_name()); - return operand->parent()->AddInstruction(HloInstruction::CreateDomain( - operand->shape(), operand, std::move(operand_side_metadata), - std::move(user_side_metadata))); -} +}; Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain, const DomainMetadata* metadata) { @@ -142,31 +136,32 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(ShardingDomainCreator{}); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "c", "a")); - EXPECT_TRUE(HasDomainEdge(module, "c", "b")); - EXPECT_TRUE(HasDomainEdge(module, "d", "a")); - 
EXPECT_TRUE(HasDomainEdge(module, "d", "b")); - EXPECT_FALSE(HasDomainEdge(module, "e", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "c", "a")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "d", "a")); - EXPECT_FALSE(HasDomainEdge(module, "d", "b")); - EXPECT_FALSE(HasDomainEdge(module, "e", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); } TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) { @@ -184,11 +179,12 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(ShardingDomainCreator{}); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(!isolator_changed); } @@ 
-211,26 +207,27 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(ShardingDomainCreator{}); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "b", "a")); - EXPECT_TRUE(HasDomainEdge(module, "f", "e_element")); - EXPECT_FALSE(HasDomainEdge(module, "a", "p0")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "f", "e_element")); + EXPECT_FALSE(HasDomainEdge(module.get(), "a", "p0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "b", "a")); - EXPECT_FALSE(HasDomainEdge(module, "f", "e_element")); + EXPECT_FALSE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "f", "e_element")); } TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { @@ -248,11 +245,12 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(ShardingDomainCreator{}); - TF_ASSERT_OK_AND_ASSIGN(bool 
isolator_changed, isolator.Run(module)); + HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_FALSE(isolator_changed); } @@ -271,15 +269,16 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_FALSE(remover_changed); - HloInstruction* add = FindInstruction(module, "c"); + HloInstruction* add = FindInstruction(module.get(), "c"); ASSERT_NE(add, nullptr); auto device = add->sharding_unique_device(); EXPECT_TRUE(device.has_value()); @@ -302,41 +301,42 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator sharding_isolator(ShardingDomainCreator{}); + HloDomainIsolator sharding_isolator([]() { return ShardingDomainCreator{}; }); TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, - sharding_isolator.Run(module)); + sharding_isolator.Run(module.get())); EXPECT_TRUE(sharding_isolator_changed); - HloDomainIsolator opname_isolator(OpNameDomainCreator); + HloDomainIsolator opname_isolator([]() { return OpNameDomainCreator{}; }); TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, - opname_isolator.Run(module)); + opname_isolator.Run(module.get())); EXPECT_TRUE(opname_isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "c", "a")); - EXPECT_TRUE(HasDomainEdge(module, "c", "b")); - EXPECT_TRUE(HasDomainEdge(module, "d", "a")); - 
EXPECT_TRUE(HasDomainEdge(module, "d", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover sharding_remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, - sharding_remover.Run(module)); + sharding_remover.Run(module.get())); EXPECT_TRUE(sharding_remover_changed); HloDomainRemover opname_remover(OpNameMetadata::KindName(), OpNameDomainNormalizer); TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, - opname_remover.Run(module)); + opname_remover.Run(module.get())); EXPECT_TRUE(opname_remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "c", "a")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "d", "a")); - EXPECT_FALSE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); } TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { @@ -357,16 +357,17 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(ShardingDomainCreator{}); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "infeed.data", "infeed")); - 
EXPECT_FALSE(HasDomainEdge(module, "copy0", "gte0")); - EXPECT_FALSE(HasDomainEdge(module, "copy1", "gte1")); + EXPECT_TRUE(HasDomainEdge(module.get(), "infeed.data", "infeed")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1")); // Inject unassigned tuple/gte within the infeed domain, to simulate the // HLO passes adding unexpected instructions. @@ -382,7 +383,7 @@ ENTRY entry { // \ / // TUPLE // | - HloInstruction* infeed_data = FindInstruction(module, "infeed.data"); + HloInstruction* infeed_data = FindInstruction(module.get(), "infeed.data"); ASSERT_NE(infeed_data, nullptr); auto infeed_data_users = infeed_data->users(); @@ -408,7 +409,7 @@ ENTRY entry { HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); struct Assignment { @@ -444,25 +445,26 @@ ENTRY entry { sharding={maximal device=1} })"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); - HloDomainIsolator isolator(ShardingDomainCreator{}); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "tuple", "param")); - EXPECT_FALSE(HasDomainEdge(module, "gte", "tuple")); + EXPECT_TRUE(HasDomainEdge(module.get(), "tuple", "param")); + EXPECT_FALSE(HasDomainEdge(module.get(), "gte", "tuple")); // Remove %tuple and %gte (tuple simplification) - HloInstruction* gte = FindInstruction(module, "gte"); - HloInstruction* tuple = FindInstruction(module, "tuple"); + HloInstruction* gte = 
FindInstruction(module.get(), "gte"); + HloInstruction* tuple = FindInstruction(module.get(), "tuple"); module->entry_computation()->set_root_instruction(tuple->mutable_operand(0)); TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(gte)); TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(tuple)); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); const HloInstruction* root = module->entry_computation()->root_instruction(); @@ -484,11 +486,11 @@ TEST_F(HloDomainTest, DumpParseNullSharding) { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, domain, domain)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); auto hlo_string = module->ToString(); - ASSERT_TRUE(ParseModule(hlo_string).status().ok()); + ASSERT_TRUE(ParseAndReturnVerifiedModule(hlo_string).status().ok()); } // Tuple inputs are domain instructions. @@ -505,20 +507,21 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); - HloDomainIsolator isolator(ShardingDomainCreator{}); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); // Clear sharding of tpl instruction, in order to test domain sharding // application. 
- auto tpl = FindInstruction(module, "tpl"); + auto tpl = FindInstruction(module.get(), "tpl"); tpl->clear_sharding(); HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); EXPECT_EQ(HloSharding::Tuple(tpl->shape(), {HloSharding::AssignDevice(1), @@ -553,36 +556,37 @@ ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) { ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4) })"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator opname_isolator(OpNameDomainCreator); + HloDomainIsolator opname_isolator([]() { return OpNameDomainCreator{}; }); TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, - opname_isolator.Run(module)); + opname_isolator.Run(module.get())); EXPECT_TRUE(opname_isolator_changed); - EXPECT_TRUE(HasDomainEdge(module, "c", "a")); - EXPECT_TRUE(HasDomainEdge(module, "c", "b")); - EXPECT_TRUE(HasDomainEdge(module, "d", "a")); - EXPECT_TRUE(HasDomainEdge(module, "d", "c")); - EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); HloDomainRemover sharding_remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, - sharding_remover.Run(module)); + sharding_remover.Run(module.get())); EXPECT_TRUE(sharding_remover_changed); HloDomainRemover opname_remover(OpNameMetadata::KindName(), 
OpNameDomainNormalizer); TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, - opname_remover.Run(module)); + opname_remover.Run(module.get())); EXPECT_TRUE(opname_remover_changed); - EXPECT_FALSE(HasDomainEdge(module, "c", "a")); - EXPECT_FALSE(HasDomainEdge(module, "c", "b")); - EXPECT_FALSE(HasDomainEdge(module, "d", "a")); - EXPECT_FALSE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); } // Emulate instructions inserted at top and bottom within nested tuple domain. @@ -601,15 +605,16 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); - HloDomainIsolator isolator(ShardingDomainCreator{}); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; }); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); EXPECT_TRUE(isolator_changed); // Clear sharding of tuple.0 instruction, in order to test domain sharding // application. 
- auto tuple0 = FindInstruction(module, "tuple.0"); + auto tuple0 = FindInstruction(module.get(), "tuple.0"); tuple0->clear_sharding(); // Insert the following instructons above and below tuple.0, to emulate other @@ -653,7 +658,7 @@ ENTRY entry { HloDomainRemover remover(ShardingMetadata::KindName(), ShardingMetadata::NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); EXPECT_TRUE(remover_changed); EXPECT_TRUE(tuple0->has_sharding()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index eec8d242faaa70e84ab5b46904b0a0ea41d5b009..3a7652a8dc856b23c8988c4676916c8199e78860 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" @@ -38,11 +39,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bitmap.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" @@ -189,6 +190,11 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) return Unimplemented( "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); }); + typed_visitors_[TOKEN] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: TOKEN."); + }); } template @@ -391,6 +397,16 @@ StatusOr HloEvaluator::EvaluateDotOp( return Evaluate(cloned_instruction.get()); } +Status HloEvaluator::HandleBitcast(HloInstruction* bitcast) { + const Literal& operand_literal = GetEvaluatedLiteralFor(bitcast->operand(0)); + Literal result(bitcast->shape()); + TF_RET_CHECK(operand_literal.size_bytes() == result.size_bytes()); + memcpy(result.untyped_data(), operand_literal.untyped_data(), + operand_literal.size_bytes()); + evaluated_[bitcast] = std::move(result); + return Status::OK(); +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; @@ -1041,8 +1057,15 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } -Status HloEvaluator::HandleAfterAll(HloInstruction* token) { - evaluated_[token] = LiteralUtil::CreateToken(); +Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) { + evaluated_[after_all] = LiteralUtil::CreateToken(); + return 
Status::OK(); +} + +Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) { + // AddDedendency just forwards its zero-th operand. + evaluated_[add_dependency] = + GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone(); return Status::OK(); } @@ -1228,7 +1251,7 @@ StatusOr EvaluateSortInternal(HloInstruction* sort, TF_RET_CHECK( ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) << "Sort keys and values must have the same dimensions"; - TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort"; + TF_RET_CHECK(sort->operand_count() >= 2) << "Expected key-value sort"; // We need to sort an array of keys and an array of values, where the // sorted order of the values is determined by the keys. The simplest(?) // way to do this is to go to an array-of-pairs representation, sort the @@ -1274,12 +1297,14 @@ StatusOr EvaluateSortInternal(HloInstruction* sort, key_value_vector.push_back( std::make_pair(keys_data[i], values_data[i])); } - std::sort(key_value_vector.begin(), key_value_vector.end(), - [](const kv_pair& a, const kv_pair& b) { - return SafeLess(a.first, b.first); - }); + std::stable_sort(key_value_vector.begin(), key_value_vector.end(), + [](const kv_pair& a, const kv_pair& b) { + return SafeLess(a.first, b.first); + }); std::vector result_keys; - std::vector result_values; + // We use a InlinedVector here because we need to convert it to an + // absl::Span later, and this would not work with std::vector. 
+ absl::InlinedVector result_values; for (const auto& key_value : key_value_vector) { result_keys.push_back(key_value.first); result_values.push_back(key_value.second); @@ -1315,7 +1340,10 @@ template StatusOr EvaluateSortCurried(HloInstruction* sort, const Literal& keys_literal, const Literal& values_literal) { - switch (sort->operand(1)->shape().element_type()) { + switch (values_literal.shape().element_type()) { + case PRED: + return EvaluateSortInternal(sort, keys_literal, + values_literal); case F32: return EvaluateSortInternal(sort, keys_literal, values_literal); @@ -1355,14 +1383,24 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { if (!ShapeUtil::IsTuple(sort->shape())) { return DefaultAction(sort); } else { - auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)), - GetEvaluatedLiteralFor(sort->operand(1))); - if (result.ok()) { - evaluated_[sort] = std::move(result.ValueOrDie()); - return Status::OK(); - } else { - return result.status(); + // This is a really stupid work-around for the fact it's hard to support a + // multi-value sort directly, due to the fact we need to template the + // evaluation function on all of the value types. 
+ std::vector sort_results_backing; + for (int64 i = 0; i < sort->operand_count(); ++i) { + auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)), + GetEvaluatedLiteralFor(sort->operand(i))); + if (!result.ok()) { + return result.status(); + } + sort_results_backing.push_back( + std::move(result.ValueOrDie().DecomposeTuple()[1])); } + std::vector sort_results; + absl::c_transform(sort_results_backing, std::back_inserter(sort_results), + [](const Literal& literal) { return &literal; }); + evaluated_[sort] = LiteralUtil::MakeTuple(sort_results); + return Status::OK(); } } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 07f8d0aad4af0b07303b4e485b3630cc75bcb519..45ed8131dc6b71f706fce45d65b206363dd79ac3 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -144,6 +144,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Operations that are type-agnostic or always return a specific type, such as // HandleIsFinite where boolean is always returned. // + Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleParameter(HloInstruction* parameter) override; Status HandleConstant(HloInstruction* constant) override; @@ -180,7 +182,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleBroadcast(HloInstruction* broadcast) override; - Status HandleAfterAll(HloInstruction* token) override; + Status HandleAfterAll(HloInstruction* after_all) override; + + Status HandleAddDependency(HloInstruction* add_dependency) override; Status HandleSort(HloInstruction* sort) override; @@ -221,16 +225,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const Literal& operand_literal) { const auto shape = instruction->shape(); const auto* operand = instruction->operand(0); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. 
- if (!ShapeUtil::SameDimensions(shape, operand->shape())) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape), - ShapeUtil::HumanString(operand->shape())); - } + TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape())); Literal result(shape); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index cee11a8a2166f96ae801095b6364921ed05d0000..4eaaab20ea0add17d9b49b1b2b97991af0438dcc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -33,8 +34,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -50,9 +52,9 @@ namespace { static std::array use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface, - public HloVerifiedTestBase { + public HloTestBase { protected: - HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) { + HloEvaluatorTest() : HloTestBase(), use_bfloat16_(GetParam()) { evaluator_ = absl::make_unique(); } @@ -60,14 +62,14 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, if (use_bfloat16_) { // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. auto type_converter = HloElementTypeConverter(F32, BF16); - type_converter.Run(&module()).ValueOrDie(); + type_converter.Run(m_.get()).ValueOrDie(); } - return evaluator_->Evaluate(*module().entry_computation(), arg_literals) + return evaluator_->Evaluate(*m_->entry_computation(), arg_literals) .ConsumeValueOrDie(); } - // Evaluate function that takes in a local module instead of using module_ - // that is in HloVerifiedTestBase. Once module_ in HloVerifiedTestBase is + // Evaluate function that takes in a local module instead of using m_ + // that is in HloTestBase. Once m_ in HloTestBase is // removed, this should be the default Evaluate function. 
Literal EvaluateWithModule( HloModule* module, absl::Span arg_literals = {}) { @@ -88,7 +90,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -108,7 +110,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); b.AddInstruction( HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -116,6 +118,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, } bool use_bfloat16_; + std::unique_ptr m_ = CreateNewVerifiedModule(); }; #define XLA_TYPED_TEST_P(test_case_name, test_name, test_type1) \ @@ -135,7 +138,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -156,7 +159,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -181,7 +184,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false))); b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal 
result = Evaluate({}); @@ -322,7 +325,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2")); b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs_instruction, param_rhs2)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(args); @@ -346,7 +349,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { const int64 permutation[] = {1, 2, 0, 4, 3}; b.AddInstruction( HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate({}); @@ -367,7 +370,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( output_literal.shape(), literal_instruction, {1, 2})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate({}); @@ -386,7 +389,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { b.AddInstruction(HloInstruction::CreateBroadcast( output_literal.shape(), literal_instruction, /*broadcast_dimensions=*/{})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate({}); @@ -406,7 +409,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { Shape shape = ShapeUtil::MakeShape(S64, {4, 2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -428,7 +431,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { Shape shape = ShapeUtil::MakeShape(S64, {2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -448,7 +451,7 @@ TEST_P(HloEvaluatorTest, 
ConvertWithSameLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -468,7 +471,7 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -503,7 +506,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { Shape shape = ShapeUtil::MakeShape(S32, {5, 2}); b.AddInstruction(HloInstruction::CreatePad( shape, operand_instruction, padding_value_instruction, padding_config)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -530,7 +533,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}}); b.AddInstruction(HloInstruction::CreatePad( shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -574,7 +577,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -619,7 +622,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -658,7 +661,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { 
b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, rhs_instruction, dot_dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -704,7 +707,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, rhs_instruction, dot_dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -748,7 +751,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, rhs_instruction, dot_dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -802,7 +805,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -857,7 +860,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -941,7 +944,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1019,7 +1022,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { b.AddInstruction(HloInstruction::CreateConvolve( shape, 
lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1079,7 +1082,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1143,7 +1146,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1215,7 +1218,7 @@ TEST_P(HloEvaluatorTest, b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1286,7 +1289,7 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1297,11 +1300,12 @@ TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; +class HloEvaluatorPreciseReduceTest : public HloTestBase {}; // Tests that Reduce doesn't lose precision when adding many numbers (because // it accumulates its result in a double). 
TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder b(TestName()); constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 @@ -1319,12 +1323,12 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m->AddEmbeddedComputation(add_computation.Build()); HloInstruction* reduce_instruction = b.AddInstruction( HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{0}, add_func)); - module().AddEntryComputation(b.Build()); + m->AddEntryComputation(b.Build()); HloEvaluator hlo_eval; Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); @@ -1337,7 +1341,7 @@ void BM_ReducePrecisely(int num_iters) { tensorflow::testing::StopTiming(); HloComputation::Builder b("BM_ReducePrecisely"); HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); HloModule module("BM_ReducePrecisely", config); constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 @@ -1396,14 +1400,14 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m_->AddEmbeddedComputation(add_computation.Build()); Shape shape = ShapeUtil::MakeShape(F32, {2}); b.AddInstruction( HloInstruction::CreateReduce(shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{1}, add_func)); - module().AddEntryComputation(b.Build()); + 
m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1438,7 +1442,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); max_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); - auto max_func = module().AddEmbeddedComputation(max_computation.Build()); + auto max_func = m_->AddEmbeddedComputation(max_computation.Build()); Window window; WindowDimension dim; @@ -1455,7 +1459,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, max_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1463,6 +1467,58 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { + HloComputation::Builder b(TestName()); + + // arg: + // f32[3,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // { 9, 10, 11 }, + // } + auto arg_array = absl::make_unique>(3, 3); + arg_array->FillUnique(1.0f); + auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); + + HloInstruction* arg_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); + + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.f))); + + HloComputation::Builder max_computation("max"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + max_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); + auto max_func = m_->AddEmbeddedComputation(max_computation.Build()); + + Window window; + 
WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(2); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + Shape shape = ShapeUtil::MakeShape(F32, {1, 1}); + b.AddInstruction(HloInstruction::CreateReduceWindow( + shape, arg_instruction, init_value, window, max_func)); + + m_->AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + auto expected = LiteralUtil::CreateR2({{11}}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + TEST_P(HloEvaluatorTest, ReduceWindowAdd) { HloComputation::Builder b(TestName()); @@ -1489,7 +1545,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m_->AddEmbeddedComputation(add_computation.Build()); Window window; WindowDimension dim; @@ -1512,7 +1568,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1542,7 +1598,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + auto add_func = m_->AddEmbeddedComputation(add_computation.Build()); Window window; @@ -1573,7 +1629,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - 
module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1605,7 +1661,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { /*start_indices=*/{0, 2}, /*limit_indices=*/{3, 5}, /*strides=*/{2, 3})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1639,7 +1695,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1675,7 +1731,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1712,7 +1768,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( shape, operand, update, start_indices)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1748,7 +1804,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateGetTupleElement(shape, tuple, 1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1787,7 +1843,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { b.AddInstruction( HloInstruction::CreateGetTupleElement(tuple2->shape(), outer_tuple, 1)); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1825,7 +1881,7 @@ TEST_P(HloEvaluatorTest, Reverse) { const Shape shape = ShapeUtil::MakeShape(F32, 
{4, 3, 2, 1}); b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); - module().AddEntryComputation(b.Build()); + m_->AddEntryComputation(b.Build()); Literal result = Evaluate(); @@ -1914,7 +1970,7 @@ ENTRY main { slice_sizes={1, 3} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); @@ -1938,7 +1994,7 @@ ENTRY main { slice_sizes={3, 1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); @@ -1962,7 +2018,7 @@ ENTRY main { slice_sizes={3, 1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); @@ -1987,7 +2043,7 @@ ENTRY main { slice_sizes={1,1,2} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2014,7 +2070,7 @@ ENTRY main { slice_sizes={1,1,2} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2040,7 +2096,7 @@ ENTRY main { slice_sizes={1,1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR1({1, 1}); @@ -2063,7 +2119,7 @@ ENTRY main { slice_sizes={1,1} } )"; - 
ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); @@ -2087,7 +2143,7 @@ ENTRY main { slice_sizes={1, 0} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); Literal start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{}, {}}), @@ -2109,7 +2165,7 @@ ENTRY main { slice_sizes={1} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR1({0, 1, 2}); Literal start_indices = @@ -2140,7 +2196,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2171,7 +2227,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2204,7 +2260,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); @@ -2236,7 +2292,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = 
LiteralUtil::CreateR1({0, 2}); @@ -2268,7 +2324,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); @@ -2302,7 +2358,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); @@ -2334,7 +2390,7 @@ ENTRY main { index_vector_dim=2 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); @@ -2366,7 +2422,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2403,7 +2459,7 @@ ENTRY main { index_vector_dim=0 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // @@ -2439,7 +2495,7 @@ ENTRY main { index_vector_dim=0 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); @@ -2471,7 +2527,7 @@ ENTRY main { index_vector_dim=0 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = 
LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); Literal scatter_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); @@ -2503,7 +2559,7 @@ ENTRY main { index_vector_dim=1 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); Literal updates = LiteralUtil::CreateR2({{}, {}}); @@ -2533,7 +2589,7 @@ ENTRY main { index_vector_dim=2 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal operand = LiteralUtil::CreateR1({0, 1, 2}); Literal scatter_indices = @@ -2684,7 +2740,7 @@ ENTRY main { ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16 } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal arg = LiteralUtil::CreateR1( {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)}); @@ -2702,7 +2758,7 @@ ENTRY main { ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]} } )"; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); Literal arg = LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, @@ -2711,6 +2767,33 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); } +TEST_P(HloEvaluatorTest, Bitcast) { + // Regression test for b/114735354. 
+ constexpr absl::string_view hlo_text_base = R"( +HloModule Bitcast + +ENTRY main { + param = %s[32,121]{1,0} parameter(0) + ROOT bitcast = %s[121,32,1]{0,1,2} bitcast(%s[32,121]{1,0} param) +} +)"; + string hlo_text; + if (use_bfloat16_) { + hlo_text = absl::StrFormat(hlo_text_base, "bf16", "bf16", "bf16"); + } else { + hlo_text = absl::StrFormat(hlo_text_base, "f32", "f32", "f32"); + } + TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text)); + auto args = MakeFakeArguments(m_.get()).ConsumeValueOrDie(); + Literal actual = Evaluate({&args[0]}); + if (use_bfloat16_) { + EXPECT_TRUE( + absl::c_equal(args[0].data(), actual.data())); + } else { + EXPECT_TRUE(absl::c_equal(args[0].data(), actual.data())); + } +} + INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, ::testing::ValuesIn(use_bf16_params)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index b2d12c94b848e4fd8ae473fdc0e4a9f5fecf6286..b87fc3e34012e75ee07bff6c1e113dce404f83cb 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/base/casts.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/types/optional.h" @@ -27,7 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/shape_inference.h" -#include "tensorflow/core/lib/core/casts.h" namespace xla { @@ -161,9 +161,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { HloOpcodeString(hlo_instruction->opcode())); } - // TODO(b/35950897): many of the stl functions used in the handlers are not - // overloaded for every XLA primitive type. - template ::value>::type* = nullptr> @@ -596,7 +593,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleDivide(HloInstruction* divide) { + Status HandleDivide(HloInstruction* divide) override { return HandleDivide(divide); } @@ -1072,66 +1069,66 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Convolve input feature with kernel. do { + // Find corresponding spatial dimension index for input (lhs). + int64 lhs_linear_spatial_index = 0; + int64 rhs_linear_spatial_index = 0; + for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { + // Spatial dimension number for input (lhs) and output. + const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); + const int64 output_spatial_dim = dnums.output_spatial_dimensions(ki); + + // Calculate lhs (input) index without taking base dilation into + // account. + const auto& window_dim = window.dimensions(ki); + const int64 undilated_index = + out_index[output_spatial_dim] * window_dim.stride() - + window_dim.padding_low() + + rhs_spatial_index[ki] * window_dim.window_dilation(); + // Skip if the lhs (input) index is to be dilated. As an + // optimization, skip this mod if there's no dilation. + if (window_dim.base_dilation() > 1 && + undilated_index % window_dim.base_dilation() != 0) { + goto cnt; + } + + // Calculate the actual lhs (input) index after dilation. As an + // optimization, skip this integer divide if there's no dilation. 
+ int64 lhs_spatial_index; + if (window_dim.base_dilation() > 1) { + lhs_spatial_index = undilated_index / window_dim.base_dilation(); + } else { + lhs_spatial_index = undilated_index; + } + + // Skip if input index is not in bounds. + if (!(lhs_spatial_index >= 0 && + lhs_spatial_index < lhs_shape.dimensions(input_spatial_dim))) { + goto cnt; + } + + lhs_linear_spatial_index += + lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; + rhs_linear_spatial_index += + (window_dim.window_reversal() + ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) + : rhs_spatial_index[ki]) * + rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; + } + for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) { const int64 iz = feature_group_index * input_feature_group_size + rhs_iz; - int64 lhs_linear_index = 0; + int64 lhs_linear_index = lhs_linear_spatial_index; lhs_linear_index += out_index[output_batch_dim] * lhs_dim_multipliers[input_batch_dim]; lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; - int64 rhs_linear_index = 0; + int64 rhs_linear_index = rhs_linear_spatial_index; rhs_linear_index += out_index[output_z_dim] * rhs_dim_multipliers[kernel_output_z_dim]; rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim]; - // Find corresponding spatial dimension index for input (lhs). - for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { - // Spatial dimension number for input (lhs) and output. - const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); - const int64 output_spatial_dim = - dnums.output_spatial_dimensions(ki); - - // Calculate lhs (input) index without taking base dilation into - // account. - const auto& window_dim = window.dimensions(ki); - const int64 undilated_index = - out_index[output_spatial_dim] * window_dim.stride() - - window_dim.padding_low() + - rhs_spatial_index[ki] * window_dim.window_dilation(); - // Skip if the lhs (input) index is to be dilated. 
As an - // optimization, skip this mod if there's no dilation. - if (window_dim.base_dilation() > 1 && - undilated_index % window_dim.base_dilation() != 0) { - goto cnt; - } - - // Calculate the actual lhs (input) index after dilation. As an - // optimization, skip this integer divide if there's no dilation. - int64 lhs_spatial_index; - if (window_dim.base_dilation() > 1) { - lhs_spatial_index = undilated_index / window_dim.base_dilation(); - } else { - lhs_spatial_index = undilated_index; - } - lhs_linear_index += - lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; - - // Skip if input index is not in bounds. - if (!(lhs_spatial_index >= 0 && - lhs_spatial_index < - lhs_shape.dimensions(input_spatial_dim))) { - goto cnt; - } - - rhs_linear_index += - (window_dim.window_reversal() - ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) - : rhs_spatial_index[ki]) * - rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; - } - result_val += static_cast(lhs_literal_data[lhs_linear_index]) * static_cast(rhs_literal_data[rhs_linear_index]); @@ -1556,10 +1553,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const auto& row_data = row_to_sort.data(); std::vector result_data(row_data.begin(), row_data.end()); - std::sort(result_data.begin(), result_data.end(), - [](const NativeT& a, const NativeT& b) { - return SafeLess(a, b); - }); + std::stable_sort(result_data.begin(), result_data.end(), + [](const NativeT& a, const NativeT& b) { + return SafeLess(a, b); + }); Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(), {sort_dim_elements})); sorted_row.PopulateR1(absl::Span(result_data)); @@ -2442,7 +2439,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[reduce_precision], ElementWiseUnaryOp(reduce_precision, [reduce_precision]( ElementwiseT elem) { - uint32_t value_as_int = tensorflow::bit_cast(elem); + uint32_t value_as_int = absl::bit_cast(elem); const uint32_t mantissa_bits = 
reduce_precision->mantissa_bits(); const uint32_t exponent_bits = reduce_precision->exponent_bits(); @@ -2515,7 +2512,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { value_as_int = x_underflows ? x_signed_zero : value_as_int; } - float reduced_result = tensorflow::bit_cast(value_as_int); + float reduced_result = absl::bit_cast(value_as_int); if (std::isnan(elem)) { reduced_result = mantissa_bits > 0 ? elem @@ -2546,12 +2543,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template ::value || - std::is_same::value || - std::is_same::value>::type* = nullptr> + std::is_integral::value || + std::is_floating_point::value>::type* = nullptr> Status HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); - std::vector data(iota->shape().dimensions(iota->iota_dimension())); + // Avoid using std::vector since std::vector does not convert to + // absl::Span. + absl::InlinedVector data( + iota->shape().dimensions(iota->iota_dimension())); std::iota(data.begin(), data.end(), 0); auto result = LiteralUtil::CreateR1(data); @@ -2568,9 +2567,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template ::value || - std::is_same::value || - std::is_same::value)>::type* = nullptr> + !(std::is_integral::value || + std::is_floating_point::value)>::type* = nullptr> Status HandleIota(HloInstruction* iota) { return InvalidArgument("Unsupported type for iota"); } @@ -2613,8 +2611,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector base_index(rank); bool out_of_bound = false; for (int64 i = 0; i < rank; ++i) { - base_index[i] = window_count_index[i] * window.dimensions(i).stride() + - window_index[i] - window.dimensions(i).padding_low(); + base_index[i] = + window_count_index[i] * window.dimensions(i).stride() + + window_index[i] * window.dimensions(i).window_dilation() - + window.dimensions(i).padding_low(); + // We are not in the base area if the dilation placed us out 
of bounds. + if (base_index[i] % window.dimensions(i).base_dilation() != 0) { + out_of_bound = true; + break; + } + // Apply the dilation to the base area. + base_index[i] /= window.dimensions(i).base_dilation(); if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { out_of_bound = true; break; @@ -2713,17 +2720,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const auto shape = instruction->shape(); const auto* lhs = instruction->operand(0); const auto* rhs = instruction->operand(1); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast - // is removed. - if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s: ", - ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), - ShapeUtil::HumanString(rhs->shape())); - } + TF_RET_CHECK(ShapeUtil::SameDimensions(shape, rhs->shape())); + TF_RET_CHECK(ShapeUtil::SameDimensions(lhs->shape(), rhs->shape())); const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); @@ -2747,19 +2745,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const auto* lhs = instruction->operand(0); const auto* rhs = instruction->operand(1); const auto* ehs = instruction->operand(2); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit - // broadcast is removed. 
- if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && - ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s vs %s: ", - ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), - ShapeUtil::HumanString(rhs->shape()), - ShapeUtil::HumanString(ehs->shape())); - } + TF_RET_CHECK(ShapeUtil::SameDimensions(shape, lhs->shape())); + TF_RET_CHECK(ShapeUtil::SameDimensions(lhs->shape(), rhs->shape())); + TF_RET_CHECK(ShapeUtil::SameDimensions(rhs->shape(), ehs->shape())); const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index ce4cad42355ec5881f2ae14f4dd52a0588d51cf7..2df8eb962ae54eb5b9492fdeb274eec309a8288f 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -28,7 +28,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" namespace xla { -HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) { +HloProfileIndexMap::HloProfileIndexMap(const HloModule& module, + absl::Span extra_metrics) { size_t current_profile_index = 0; for (xla::HloComputation* computation : module.MakeComputationPostOrder()) { InsertOrDie(&computation_to_profile_idx_, computation, @@ -40,11 +41,15 @@ HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) { current_profile_index++); } } + for (const string& key : extra_metrics) { + InsertOrDie(&extra_metric_to_profile_idx_, key, current_profile_index++); + } } std::unique_ptr CreateHloProfilePrinterData( const HloProfileIndexMap& hlo_profile_index_map, - const HloCostAnalysis& cost_analysis) { + const HloCostAnalysis& cost_analysis, + const string& entry_computation_name) { using HloComputationInfo = HloProfilePrinterData::HloComputationInfo; using HloInstructionInfo = HloProfilePrinterData::HloInstructionInfo; @@ -105,6 +110,14 @@ std::unique_ptr CreateHloProfilePrinterData( } } + // Add extra metrics if any. + for (const auto& pair : hlo_profile_index_map.extra_metric_to_profile_idx()) { + profile_printer_data->mutable_extra_metrics()->insert( + {pair.first, pair.second}); + } + + profile_printer_data->set_entry_computation(entry_computation_name); + return profile_printer_data; } diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index be989846ef5cd2645da88ac9bbfea9534dd47821..da30e15908328f9aa7fe282656a6d44ab7348195 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EXECUTION_PROFILE_H_ #include +#include #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" @@ -34,7 +35,10 @@ class HloInstruction; class HloProfileIndexMap { public: // Scans `module` to populate this instance of HloProfileIndexMap. - explicit HloProfileIndexMap(const HloModule& module); + explicit HloProfileIndexMap(const HloModule& module) + : HloProfileIndexMap(module, {}) {} + explicit HloProfileIndexMap(const HloModule& module, + absl::Span extra_metrics); HloProfileIndexMap(const HloProfileIndexMap&) = default; HloProfileIndexMap(HloProfileIndexMap&&) = default; @@ -50,6 +54,10 @@ class HloProfileIndexMap { return FindOrDie(computation_to_profile_idx(), &computation); } + size_t GetProfileIndexFor(const string& key) const { + return xla::FindOrDie(extra_metric_to_profile_idx(), key); + } + size_t instruction_count() const { return instruction_to_profile_idx().size(); } @@ -58,8 +66,12 @@ class HloProfileIndexMap { return computation_to_profile_idx().size(); } + size_t extra_metrics_count() const { + return extra_metric_to_profile_idx().size(); + } + size_t total_count() const { - return instruction_count() + computation_count(); + return instruction_count() + computation_count() + extra_metrics_count(); } const std::unordered_map& @@ -72,15 +84,20 @@ class HloProfileIndexMap { return computation_to_profile_idx_; } + const std::unordered_map& extra_metric_to_profile_idx() const { + return extra_metric_to_profile_idx_; + } + private: std::unordered_map instruction_to_profile_idx_; std::unordered_map computation_to_profile_idx_; + std::unordered_map extra_metric_to_profile_idx_; }; // Create an instance of `HloProfilePrinterData`. 
std::unique_ptr CreateHloProfilePrinterData( const HloProfileIndexMap& hlo_profile_index_map, - const HloCostAnalysis& cost_analysis); + const HloCostAnalysis& cost_analysis, const string& entry_computation_name); // Describes how much time each HLO operation took. // @@ -113,6 +130,12 @@ class HloExecutionProfile { total_cycles_executed; } + // Record extra metric. + void set_extra_metrics(const string& metric, uint64 value) { + profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(metric)] = + value; + } + // Returns a version of the execution profile suitable for performance // debugging; e.g. emits cycle counts, execution time at the nominal device // frequency, and the effective throughput given the provided cost_analysis diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index 460ae2b5eca78659f86df1227e6a0a4e57508611..5be9dba3aa49d63c580cd6f5800f608667826b6a 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -54,7 +54,8 @@ TEST_F(HloExecutionProfileTest, Basic) { HloCostAnalysis cost_analysis(shape_size_function); HloProfileIndexMap profile_index_map(*hlo_module); std::unique_ptr profile_printer = - CreateHloProfilePrinterData(profile_index_map, cost_analysis); + CreateHloProfilePrinterData(profile_index_map, cost_analysis, + hlo_module->entry_computation()->name()); HloExecutionProfile execution_profile(profile_printer.get(), &profile_index_map); diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..c919dbd82d3668c477bf37074f1d56f8cb7d9506 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" + +namespace xla { + +namespace { + +StatusOr ReplaceGetSize(HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kGetDimensionSize) { + return false; + } + HloComputation* computation = instr->parent(); + + TF_ASSIGN_OR_RETURN(auto legal_shape, + ShapeInference::InferGetDimensionSizeShape( + instr->operand(0)->shape(), instr->dimension())); + TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)); + TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), U32)); + uint32 size = instr->operand(0)->shape().dimensions(instr->dimension()); + HloInstruction* new_instr = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + return true; +} + +} // namespace + +StatusOr HloGetDimensionSizeRewriter::Run(HloModule* module) { + bool changed = false; + HloProto proto; + *proto.mutable_hlo_module() = module->ToProto(); + for (auto* computation : module->computations()) { + for (auto instruction : 
computation->instructions()) { + TF_ASSIGN_OR_RETURN(bool replaced, ReplaceGetSize(instruction)); + changed = changed || replaced; + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h new file mode 100644 index 0000000000000000000000000000000000000000..30f44c23a835b3bcc935caaa917e040e07c4e703 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GET_DIMENSION_SIZE_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GET_DIMENSION_SIZE_REWRITER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Pass to replace a kGetDimensionSize instruction with a constant instruction. 
+class HloGetDimensionSizeRewriter : public HloModulePass { + public: + absl::string_view name() const override { + return "hlo-get-dimension-size-rewriter"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GET_DIMENSION_SIZE_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a86aebdd5b64240e6e07d8e8050c0c8681cce765 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class HloGetDimensionSizeRewriterTest : public HloTestBase { + protected: + HloGetDimensionSizeRewriterTest() {} +}; + +TEST_F(HloGetDimensionSizeRewriterTest, Ok) { + auto module = ParseHloString(R"( +HloModule _ +ENTRY gds { + p = s32[3,4] parameter(0) + size0 = u32[] get-dimension-size(p), dimensions={0} + size1 = u32[] get-dimension-size(p), dimensions={1} + ROOT mul = u32[] multiply(size0, size1) +})") + .ValueOrDie(); + HloGetDimensionSizeRewriter pass; + EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Multiply(op::Constant(), op::Constant())); +} + +TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) { + auto module = ParseHloString(R"( +HloModule _ +ENTRY gds { + p = s32[3]{0} parameter(0) + ROOT gds = s64[] get-dimension-size(p), dimensions={0} +})") + .ValueOrDie(); + HloGetDimensionSizeRewriter pass; + EXPECT_FALSE(pass.Run(module.get()).ok()); +} + +TEST_F(HloGetDimensionSizeRewriterTest, IllegalDimension) { + auto module = 
ParseHloString(R"( +HloModule _ +ENTRY gds { + p = f32[2,5] parameter(0) + ROOT gds = u32[] get-dimension-size(p), dimensions={2} +})") + .ValueOrDie(); + HloGetDimensionSizeRewriter pass; + EXPECT_FALSE(pass.Run(module.get()).ok()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 13a74fd8a115c5dc9a9518b226dfee4445cc7180..302eca656be53a3cec86ddbf05a7fa3925c5185b 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -38,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" @@ -111,11 +113,6 @@ class NodeFilter { result == kSomeUsersOmitted; } - bool ShowFusionSubcomputation(const HloInstruction* instr) const { - CHECK_EQ(instr->opcode(), HloOpcode::kFusion); - return Show(instr) && !SomeOrAllOperandsOmitted(instr); - } - private: std::function filter_; }; @@ -240,34 +237,28 @@ string HtmlLikeStringSanitize(absl::string_view s) { // it to a short string lets us tell the user what the subcomputation is without // drawing it as a graph. optional MatchTrivialComputation(const HloComputation* computation) { + namespace m = match; + if (computation->instruction_count() != 3) { return nullopt; } - HloInstruction* root = computation->root_instruction(); - if (root->operand_count() != 2) { - return nullopt; - } - - // Check that both of the operands to the root are parameters. 
- const HloInstruction* operand0 = root->operand(0); - const HloInstruction* operand1 = root->operand(1); - if (operand0->opcode() != HloOpcode::kParameter || - operand1->opcode() != HloOpcode::kParameter) { - return nullopt; - } - - // Check that the two operands of root are param0 and param1. All of the - // opcodes we recognize are commutative, so we're OK with either order. - auto n0 = operand0->parameter_number(); - auto n1 = operand1->parameter_number(); - if (!(n0 == 0 && n1 == 1) && !(n1 == 0 && n0 == 1)) { + const HloInstruction *param0, *param1; + if (!Match(root, m::Op() + .WithNumOperands(2) + .WithShape(m::Shape().IsEffectiveScalar()) + .WithBinaryOperandsAnyOrder( + m::Parameter(¶m0, 0) + .WithShape(m::Shape().IsEffectiveScalar()), + m::Parameter(¶m1, 1) + .WithShape(m::Shape().IsEffectiveScalar())))) { return nullopt; } - // If the params are reversed, check that the operation being performed is - // commutative. - if (n0 == 1) { + // If the params are reversed (i.e. operand0 is param1 and operand1 is + // param0), check that the operation being performed is commutative. + if (root->operand(0) == param1) { + CHECK_EQ(root->operand(1), param0); switch (root->opcode()) { case HloOpcode::kLe: case HloOpcode::kGe: @@ -279,13 +270,6 @@ optional MatchTrivialComputation(const HloComputation* computation) { } } - // Check that the root and params are all effective scalars. - if (!ShapeUtil::IsEffectiveScalar(root->shape()) || - !ShapeUtil::IsEffectiveScalar(operand0->shape()) || - !ShapeUtil::IsEffectiveScalar(operand1->shape())) { - return nullopt; - } - // If we recognize the root's opcode, we've successfully pattern-matched! switch (root->opcode()) { case HloOpcode::kAdd: @@ -578,7 +562,7 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { // Show the subcomputation if we're showing any of its members. 
return std::any_of( - computation_->instructions().begin(), computation_->instructions().end(), + subcomp->instructions().begin(), subcomp->instructions().end(), [&](const HloInstruction* instr) { return filter_.Show(instr); }); } @@ -987,6 +971,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kGetTupleElement: case HloOpcode::kTrace: case HloOpcode::kAfterAll: + case HloOpcode::kAddDependency: case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: @@ -1043,6 +1028,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kMap: + case HloOpcode::kGetDimensionSize: return kGray; case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: @@ -1266,12 +1252,12 @@ const HloInstruction* HloDotDumper::GetNodeForEdge( class GraphRendererRegistry { public: - void AddRenderer(GraphRendererInterface* graph_renderer) { + void SetRenderer(std::shared_ptr graph_renderer) { tensorflow::mutex_lock lock(mu_); graph_renderer_ = graph_renderer; } - GraphRendererInterface* GetDefaultRenderer() { + std::shared_ptr GetDefaultRenderer() { tensorflow::mutex_lock lock(mu_); return graph_renderer_; } @@ -1283,20 +1269,21 @@ class GraphRendererRegistry { private: tensorflow::mutex mu_; - GraphRendererInterface* graph_renderer_ = nullptr; + std::shared_ptr graph_renderer_ GUARDED_BY(mu_); }; } // namespace -Registrar::Registrar(GraphRendererInterface* dumper) { - GraphRendererRegistry::Default()->AddRenderer(dumper); +Registrar::Registrar(std::shared_ptr dumper) { + GraphRendererRegistry::Default()->SetRenderer(dumper); } namespace { // Gets a NodeFilter that includes roughly all instructions whose distance from // root is <= radius. 
-NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { +NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root, + int64 radius) { // First, find the neighborhood of nodes with distance from root <= radius. // These nodes are our initial set of "normal" nodes. std::unordered_map nodes; @@ -1403,6 +1390,56 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { }); } +// Gets a node filter that includes nodes on all paths from `from` to `to`. If +// the all-paths set contains more than max_nodes elements, includes the nodes +// on the shortest paths and sets hit_limit to true. +NodeFilter MakeNodeFromToFilter(const HloInstruction* from, + const HloInstruction* to, int64 max_nodes, + bool* hit_limit) { + *hit_limit = false; + + // Elements in the queue are paths through the graph. + std::deque> queue; + queue.push_front({from}); + + // Compute the set of nodes we want to show using a slightly-modified + // Djikstra's algorithm. The only real difference is, rather than stopping + // when we find a (shortest) path, we continue until we've found max_nodes + // nodes on some path. + std::unordered_set visited; + std::unordered_set to_display = {from, to}; + while (!queue.empty() && to_display.size() < max_nodes) { + std::vector path = std::move(queue.front()); + queue.pop_front(); + if (!visited.insert(path.back()).second) { + continue; + } + + for (const auto* user : path.back()->users()) { + if (user == to) { + auto it = path.begin(); + for (; it != path.end() && to_display.size() < max_nodes; ++it) { + to_display.insert(*it); + } + if (it != path.end()) { + *hit_limit = true; + } + } else if (!visited.count(user)) { + auto new_path = path; + new_path.push_back(user); + queue.push_back(std::move(new_path)); + } + } + } + + return NodeFilter([=](const HloInstruction* instr) { + if (instr == from || instr == to) { + return kHighlightNode; + } + return to_display.count(instr) ? 
kNormalNode : kHideNode; + }); +} + string SaveGraph(const string& graph, GraphRendererInterface::GraphKind graph_kind, const string& dest_path) { @@ -1482,7 +1519,7 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius, auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); - NodeFilter filter = MakeNodeFilter(&node, radius); + NodeFilter filter = MakeNodeRadiusAroundFilter(&node, radius); string graph = HloDotDumper(node.parent(), label, debug_options, show_backend_config, /*profile=*/nullptr, filter) @@ -1490,6 +1527,29 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius, return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } +string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, + int64 max_nodes, bool show_backend_config) { + CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!"; + auto debug_options = from.GetModule()->config().debug_options(); + + bool hit_limit = false; + NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit); + string label; + if (!hit_limit) { + label = StrCat("All paths from ", from.name(), " to ", to.name()); + } else { + label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(), + " to ", to.name(), + "

***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN " + "NODES***

"); + } + string graph = + HloDotDumper(from.parent(), label, debug_options, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); + return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); +} + void DumpText(const HloModule& module, const string& label, const string& directory_path, bool do_prefix) { Env* env = Env::Default(); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 0b11f34abb7f0d937a24d11f4dc5d2d6a0aae6e7..de1eefab776f9c3d2c73959a5cd267e938a78a32 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -66,6 +66,12 @@ string DumpGraph(const HloComputation& computation, const string& label, string DumpNeighborhoodAround(const HloInstruction& node, int radius, bool show_backend_config = false); +// Dumps nodes on any of the paths from `from` to `to`. If there are more than +// max_nodes on all paths, restricts to the max_nodes nodes on the shortest +// paths. +string DumpAllPathsFromTo(const HloInstruction& from, const HloInstruction& to, + int64 max_nodes, bool show_backend_config = false); + // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. // @@ -87,13 +93,13 @@ void DumpText(const HloModule& module, const string& label, // Class that registers a graph renderer. class Registrar { public: - Registrar(GraphRendererInterface* dumper); + Registrar(std::shared_ptr dumper); }; -#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, ctr, ...) \ - static ::xla::hlo_graph_dumper::Registrar \ - XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr)(new factory, \ - ##__VA_ARGS__) +#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, ctr, ...) 
\ + static ::xla::hlo_graph_dumper::Registrar \ + XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr)( \ + std::make_shared(), ##__VA_ARGS__) // __COUNTER__ must go through another macro to be properly expanded #define XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr) ___##ctr##__object_ diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc new file mode 100644 index 0000000000000000000000000000000000000000..6e1597fd03db0a78aa560340b7b9b64fe500df0c --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc @@ -0,0 +1,207 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { +Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, + int64 param_number, + const ShapeIndex& param_index) { + TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)) + << absl::StrCat("Tring to set up alias at ", output_index.ToString(), + " which is an invalid index for shape ", + ShapeUtil::HumanString(alias_.shape())); + // Output can't be aliased with multiple parameters. 
+ TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat( + "Trying to set up output alias for param %lld at %s but failed: output " + "index %s is already aliased with param %lld at %s", + param_number, param_index.ToString(), output_index.ToString(), + alias_.element(output_index)->first, + alias_.element(output_index)->second.ToString()); + (*alias_.mutable_element(output_index)) = + std::make_pair(param_number, param_index); + VLOG(4) << "Set up alias between output index " << output_index.ToString() + << " and parameter " << param_index << " at index " + << param_index.ToString(); + return Status::OK(); +} + +HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const { + HloInputOutputAliasProto result; + alias_.ForEachElement( + [&](const ShapeIndex& index, + const absl::optional>& data) { + if (data) { + HloInputOutputAliasProto::AliasEntryProto entry; + for (int64 i : index) { + entry.add_output_shape_index(i); + } + entry.set_parameter_number(data->first); + for (int64 i : data->second) { + entry.add_parameter_shape_index(i); + } + result.add_entries()->Swap(&entry); + } + }); + return result; +} + +StatusOr HloInputOutputAliasConfig::CreateFromProto( + const Shape& output_shape, const HloInputOutputAliasProto& proto) { + HloInputOutputAliasConfig result(output_shape); + for (const HloInputOutputAliasProto::AliasEntryProto& entry : + proto.entries()) { + ShapeIndex output_index(entry.output_shape_index().begin(), + entry.output_shape_index().end()); + + int64 param_number = entry.parameter_number(); + ShapeIndex param_index(entry.parameter_shape_index().begin(), + entry.parameter_shape_index().end()); + TF_RETURN_IF_ERROR( + result.SetUpAlias(output_index, param_number, param_index)); + } + + return result; +} + +string HloInputOutputAliasConfig::ToString() const { + std::vector pieces; + pieces.push_back("HloInputOutputAliasConfig"); + + ForEachAlias([&](const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index) 
{ + pieces.push_back(absl::StrFormat( + " OutputIndex %s is aliased with parameter %lld at %s:", + output_index.ToString(), param_number, param_index.ToString())); + }); + + return absl::StrJoin(pieces, "\n"); +} + +bool HloInputOutputAliasConfig::ParameterHasAlias( + int64 param_number, const ShapeIndex& param_index) const { + bool output = false; + alias_.ForEachElement( + [&](const xla::ShapeIndex&, + absl::optional> alias) { + if (alias && alias->first == param_number && + alias->second == param_index) { + output = true; + } + }); + return output; +} + +absl::optional HloInputOutputAliasConfig::GetAliasedOutput( + int64 param_number, const ShapeIndex& param_index) const { + absl::optional output; + alias_.ForEachElement( + [&](const xla::ShapeIndex& output_index, + absl::optional> alias) { + if (alias && alias->first == param_number && + alias->second == param_index) { + output = output_index; + } + }); + return output; +} + +absl::optional> +HloInputOutputAliasConfig::GetAliasedParameter( + const ShapeIndex& output_index) const { + CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)); + return alias_.element(output_index); +} + +void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const { + alias_.ForEachElement( + [&](const ShapeIndex& output_index, + absl::optional> aliased) { + if (aliased) { + fn(output_index, aliased->first, aliased->second); + } + }); +} + +Status HloInputOutputAliasConfig::ForEachAliasWithStatus( + AliasFnWithStatus fn) const { + return alias_.ForEachElementWithStatus( + [&](const ShapeIndex& output_index, + absl::optional> aliased) { + if (aliased) { + TF_RETURN_IF_ERROR(fn(output_index, aliased->first, aliased->second)); + } + return Status::OK(); + }); +} + +Status HloInputOutputAliasConfig::Verify( + const HloModule& module, + std::function size_func) const { + std::vector> param_has_seen; + const HloComputation* entry = module.entry_computation(); + for (int64 i = 0; i < entry->num_parameters(); ++i) { + HloInstruction* 
param = entry->parameter_instruction(i); + param_has_seen.emplace_back(param->shape()); + } + return ForEachAliasWithStatus([&](const ShapeIndex& output_index, + int64 param_number, + const ShapeIndex& param_index) -> Status { + const HloInstruction* root = entry->root_instruction(); + + TF_RET_CHECK(0 <= param_number); + TF_RET_CHECK(entry->num_parameters() > param_number); + const Shape& param_shape = + entry->parameter_instruction(param_number)->shape(); + const Shape& output_shape = root->shape(); + TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, param_index)); + TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index)); + + const Shape& param_subshape = + ShapeUtil::GetSubshape(param_shape, param_index); + const Shape& output_subshape = + ShapeUtil::GetSubshape(output_shape, output_index); + TF_RET_CHECK(LayoutUtil::IsDenseArray(param_subshape)); + TF_RET_CHECK(LayoutUtil::IsDenseArray(output_subshape)); + + if (size_func(param_subshape) != size_func(output_subshape)) { + return InternalError( + "Expected aliased input %lld at index %s and output at index %s to " + "have the same size. Input sub-shape is %s with size %lld, output " + "sub-shape is %s with size %lld", + param_number, param_index.ToString(), output_index.ToString(), + ShapeUtil::HumanStringWithLayout(param_subshape), + size_func(param_subshape), + ShapeUtil::HumanStringWithLayout(output_subshape), + size_func(output_subshape)); + } + + // Check each param_number and param_index pair only show up once. No + // input can be aliased with output buffers. 
+ TF_RET_CHECK(param_has_seen[param_number].element(param_index) == false); + + *(param_has_seen[param_number].mutable_element(param_index)) = true; + + return Status::OK(); + }); +} + +std::ostream& operator<<(std::ostream& out, + const HloInputOutputAliasConfig& config) { + out << config.ToString(); + return out; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h new file mode 100644 index 0000000000000000000000000000000000000000..439676b1546c4af7f781fb80bccffd5248309b0f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ + +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +class HloModule; + +// This class specifies the alias map from output index to parameter number and +// parameter index in the entry computation. 
+class HloInputOutputAliasConfig { + public: + HloInputOutputAliasConfig() = default; + + explicit HloInputOutputAliasConfig(Shape shape) : alias_(shape) {} + + virtual ~HloInputOutputAliasConfig() = default; + + // Sets up alias config from `output_index` to `param_index` at + // `param_number`. + Status SetUpAlias(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index); + + // Returns true if the given parameter is aliased with one of the output + // buffers. + bool ParameterHasAlias(int64 param_number, + const ShapeIndex& param_index) const; + + // (De)Serializes an HloInputOutoutAliasConfig to/from an + // HloInputOutoutAliasProto. + HloInputOutputAliasProto ToProto() const; + + static StatusOr CreateFromProto( + const Shape& output_shape, const HloInputOutputAliasProto& proto); + + // Returns the output index that the given parameter and parameter index is + // aliased with. A nullopt is returned if there is no output that is aliased + // with the parameter number and index. + absl::optional GetAliasedOutput( + int64 param_number, const ShapeIndex& param_index) const; + + // Returns the number of parameter and index of the parameter buffer that the + // given output buffer index is aliased with. A nullopt is returned if there + // is no parameter is aliased with the specific output. + absl::optional> GetAliasedParameter( + const ShapeIndex& output_index) const; + + using AliasFn = + std::function; + + // Iterates through each aliased output and input. + void ForEachAlias(AliasFn fn) const; + + using AliasFnWithStatus = + std::function; + + // Verifies that the given config is valid for the given module. + // Specifically, the config's input and output should be in-bound and size of + // the aliased buffers should match. 
+ Status Verify(const HloModule& module, + std::function size_func_) const; + + Status ForEachAliasWithStatus(AliasFnWithStatus fn) const; + + string ToString() const; + + private: + // A ShapeTree which indicates the list of buffers that's expected to be + // aliased. The key on this shape tree represents the output index. The value + // is a pair of parameter number and index into the buffer. If the value is + // nullopt, it means there is no parameter aliasing for this output. + ShapeTree>> alias_; +}; + +std::ostream& operator<<(std::ostream& out, + const HloInputOutputAliasConfig& config); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_ diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..aeb9b0fdc8b6cca87731a2d4aae25120af6c3215 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc @@ -0,0 +1,210 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { +class HloInputOutputAliasConfigTest : public HloTestBase { + protected: + void expect_aliased(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index, + const HloInputOutputAliasConfig& config) { + absl::optional aliased_output = + config.GetAliasedOutput(param_number, param_index); + + EXPECT_TRUE(aliased_output); + EXPECT_EQ(aliased_output.value(), output_index); + + absl::optional> aliased_param = + config.GetAliasedParameter(output_index); + + EXPECT_TRUE(aliased_param); + EXPECT_EQ(aliased_param.value(), std::make_pair(param_number, param_index)); + } + + void expect_not_aliased(const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index, + const HloInputOutputAliasConfig& config) { + absl::optional aliased_output = + config.GetAliasedOutput(param_number, param_index); + + EXPECT_FALSE(aliased_output && aliased_output == output_index); + + absl::optional> aliased_param = + config.GetAliasedParameter(output_index); + + EXPECT_FALSE(aliased_param && aliased_param->first == param_number && + 
aliased_param->second == param_index); + } +}; + +TEST_F(HloInputOutputAliasConfigTest, SimpleAliasing) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT root = (f32[], f32[]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{})); + + expect_aliased(/*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{}, config); + + expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1, + /*param_index=*/{}, config); + + expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, config); +} + +TEST_F(HloInputOutputAliasConfigTest, SimpleAliasingWithTupleInput) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + param = (f32[], f32[]) parameter(0) + gte1 = f32[] get-tuple-element(%param), index=0 + gte2 = f32[] get-tuple-element(%param), index=1 + ROOT root = (f32[], f32[]) tuple(%gte1, %gte2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{0})); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{1})); + + expect_aliased(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{0}, config); + + expect_aliased(/*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{1}, config); + + expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1, + /*param_index=*/{}, config); + + expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{}, config); +} + +TEST_F(HloInputOutputAliasConfigTest, InputDoNotAliasTwice) 
{ + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT root = (f32[], f32[]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{})); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{})); + + ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape); + })); +} + +TEST_F(HloInputOutputAliasConfigTest, SizesMustMatch) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[4096] parameter(1) + ROOT root = (f32[], f32[4096]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0, + /*param_index=*/{})); + + ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape); + })); +} + +TEST_F(HloInputOutputAliasConfigTest, OutputDoNotAliasTwice) { + const string module_str = R"( +HloModule TEST + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT root = (f32[], f32[]) tuple(%a, %b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + + HloInputOutputAliasConfig config( + module->entry_computation()->root_instruction()->shape()); + + TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0, + /*param_index=*/{})); + + ASSERT_IS_NOT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1, + /*param_index=*/{})); +} +} // namespace +} // namespace xla diff --git 
a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 2f6db7cd7c0ada166dc81f75c4a9989eb9d70638..21b1dbc1676cccd2fe5b331a1f9d6ff5e3a73fcd 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -93,7 +93,8 @@ StatusOr> HloInstruction::CreateFromProto( [&computation_map](int64 id) { return computation_map.contains(id); })) << proto.name() << " instruction references invalid computation id(s)"; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + Shape shape(proto.shape()); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); switch (opcode) { // Ops migrated to subclasses. @@ -101,23 +102,23 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.operand_ids_size() == 3) << "BatchNormTraining instruction should have 3 operands but sees " << proto.operand_ids_size(); - instruction = CreateBatchNormTraining( - proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(), - proto.feature_index()); + instruction = + CreateBatchNormTraining(shape, operands(0), operands(1), operands(2), + proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormInference: TF_RET_CHECK(proto.operand_ids_size() == 5) << "BatchNormInference instruction should have 5 operands but sees " << proto.operand_ids_size(); instruction = CreateBatchNormInference( - proto.shape(), operands(0), operands(1), operands(2), operands(3), + shape, operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormGrad: TF_RET_CHECK(proto.operand_ids_size() == 5) << "BatchNormGrad instruction should have 5 operands but sees " << proto.operand_ids_size(); - instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1), + instruction = CreateBatchNormGrad(shape, operands(0), operands(1), operands(2), operands(3), 
operands(4), proto.epsilon(), proto.feature_index()); break; @@ -127,7 +128,7 @@ StatusOr> HloInstruction::CreateFromProto( << proto.operand_ids_size(); std::vector fft_length(proto.fft_length().begin(), proto.fft_length().end()); - instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), + instruction = CreateFft(shape, operands(0), proto.fft_type(), absl::Span(fft_length)); break; } @@ -148,7 +149,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.operand_ids_size() == 1) << "Recv instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = CreateRecv(proto.shape().tuple_shapes(0), operands(0), + instruction = CreateRecv(shape.tuple_shapes(0), operands(0), proto.channel_id(), proto.is_host_transfer()); break; case HloOpcode::kRecvDone: @@ -161,7 +162,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.operand_ids_size() == 1) << "Reverse instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = CreateReverse(proto.shape(), operands(0), + instruction = CreateReverse(shape, operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; @@ -170,7 +171,7 @@ StatusOr> HloInstruction::CreateFromProto( << "Concatenate instruction should have 1 dimension but sees " << proto.dimensions_size(); instruction = - CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0)); + CreateConcatenate(shape, all_operands(), proto.dimensions(0)); break; case HloOpcode::kReduce: TF_RET_CHECK(proto.operand_ids_size() % 2 == 0) @@ -188,24 +189,23 @@ StatusOr> HloInstruction::CreateFromProto( absl::MakeSpan(reduce_operands) .subspan(reduce_operands.size() / 2, reduce_operands.size()); instruction = - CreateReduce(proto.shape(), inputs, init_values, + CreateReduce(shape, inputs, init_values, std::vector(proto.dimensions().begin(), proto.dimensions().end()), computations(0)); } break; case HloOpcode::kSort: { - TF_RET_CHECK(proto.operand_ids_size() == 1 
|| - proto.operand_ids_size() == 2) - << "Sort instruction should have 1 or 2 operands but has " + TF_RET_CHECK(proto.operand_ids_size() >= 1) + << "Sort instruction should have at least 1 operand but has " << proto.operand_ids_size(); TF_RET_CHECK(proto.dimensions().size() == 1) << "Sort instruction should have 1 dimension"; - HloInstruction* keys = operands(0); - HloInstruction* values = - proto.operand_ids_size() == 2 ? operands(1) : nullptr; - instruction = - CreateSort(proto.shape(), proto.dimensions(0), keys, values); + auto sort_operands = all_operands(); + HloInstruction* keys = sort_operands[0]; + instruction = CreateSort( + shape, proto.dimensions(0), keys, + absl::Span(sort_operands).subspan(1)); break; } case HloOpcode::kTranspose: @@ -213,7 +213,7 @@ StatusOr> HloInstruction::CreateFromProto( << "Transpose instruction should have 1 operand but sees " << proto.operand_ids_size(); instruction = - CreateTranspose(proto.shape(), operands(0), + CreateTranspose(shape, operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; @@ -222,7 +222,7 @@ StatusOr> HloInstruction::CreateFromProto( << "Broadcast instruction should have 1 operand but sees " << proto.operand_ids_size(); instruction = - CreateBroadcast(proto.shape(), operands(0), + CreateBroadcast(shape, operands(0), std::vector(proto.dimensions().begin(), proto.dimensions().end())); break; @@ -230,7 +230,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "Map instruction should have 1 called computation but sees " << proto.called_computation_ids_size(); - instruction = CreateMap(proto.shape(), all_operands(), computations(0)); + instruction = CreateMap(shape, all_operands(), computations(0)); break; case HloOpcode::kSlice: { TF_RET_CHECK(proto.operand_ids_size() == 1) @@ -243,8 +243,8 @@ StatusOr> HloInstruction::CreateFromProto( slice_limits.push_back(slice_dimensions.limit()); 
slice_strides.push_back(slice_dimensions.stride()); } - instruction = CreateSlice(proto.shape(), operands(0), slice_starts, - slice_limits, slice_strides); + instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits, + slice_strides); break; } case HloOpcode::kConstant: { @@ -254,7 +254,7 @@ StatusOr> HloInstruction::CreateFromProto( Literal::CreateFromProto(proto.literal())); instruction = CreateConstant(std::move(literal)); } else { - instruction = absl::make_unique(proto.shape()); + instruction = absl::make_unique(shape); } break; } @@ -285,44 +285,54 @@ StatusOr> HloInstruction::CreateFromProto( tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id); TF_RET_CHECK(fused_computation != nullptr) << "No fusion computation with id " << fusion_id; - instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(), - fused_computation); + instruction = + CreateFusion(shape, fusion_kind, all_operands(), fused_computation); break; } case HloOpcode::kRng: - instruction = - CreateRng(proto.shape(), proto.distribution(), all_operands()); + instruction = CreateRng(shape, proto.distribution(), all_operands()); break; case HloOpcode::kParameter: - instruction = CreateParameter(proto.parameter_number(), proto.shape(), - proto.name()); + instruction = + CreateParameter(proto.parameter_number(), shape, proto.name()); break; case HloOpcode::kGetTupleElement: TF_RET_CHECK(proto.operand_ids_size() == 1) << "GetTupleElement instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = CreateGetTupleElement(proto.shape(), operands(0), - proto.tuple_index()); + instruction = + CreateGetTupleElement(shape, operands(0), proto.tuple_index()); break; case HloOpcode::kReducePrecision: - instruction = - CreateReducePrecision(proto.shape(), operands(0), - proto.exponent_bits(), proto.mantissa_bits()); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "ReducePrecision instruction should have 1 operand but sees " + << 
proto.operand_ids_size(); + instruction = CreateReducePrecision( + shape, operands(0), proto.exponent_bits(), proto.mantissa_bits()); break; case HloOpcode::kInfeed: { - const Shape& data_shape = - ShapeUtil::GetTupleElementShape(proto.shape(), 0); - TF_RET_CHECK(proto.operand_ids_size() == 1); + TF_RET_CHECK(ShapeUtil::IsTuple(shape) && + (ShapeUtil::TupleElementCount(shape) == 2)) + << "Infeed should have a tuple shape with 2 operands, but has: " + << shape; + const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Infeed instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateInfeed(data_shape, operands(0), proto.infeed_config()); } break; - case HloOpcode::kOutfeed: - TF_RET_CHECK(proto.operand_ids_size() == 2); + case HloOpcode::kOutfeed: { + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Outfeed instruction should have 2 operands but sees " + << proto.operand_ids_size(); + Shape outfeed_shape(proto.outfeed_shape()); TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape())); - instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), - operands(1), proto.outfeed_config()); + ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape)); + instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1), + proto.outfeed_config()); break; + } case HloOpcode::kCrossReplicaSum: { TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "CrossReplicaSum should have 1 called computation but sees " @@ -332,7 +342,7 @@ StatusOr> HloInstruction::CreateFromProto( all_reduce_id = proto.all_reduce_id(); } instruction = CreateCrossReplicaSum( - proto.shape(), all_operands(), computations(0), + shape, all_operands(), computations(0), /*replica_groups=*/ std::vector(proto.replica_groups().begin(), proto.replica_groups().end()), @@ -342,21 +352,24 @@ StatusOr> HloInstruction::CreateFromProto( } case HloOpcode::kAllToAll: { 
instruction = CreateAllToAll( - proto.shape(), all_operands(), + shape, all_operands(), /*replica_groups=*/ std::vector(proto.replica_groups().begin(), proto.replica_groups().end())); break; } case HloOpcode::kCollectivePermute: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "CollectivePermute instruction should have 1 operand but sees " + << proto.operand_ids_size(); std::vector> source_target_pairs( proto.source_target_pairs_size()); for (int i = 0; i < source_target_pairs.size(); i++) { source_target_pairs[i].first = proto.source_target_pairs(i).source(); source_target_pairs[i].second = proto.source_target_pairs(i).target(); } - instruction = CreateCollectivePermute(proto.shape(), operands(0), - source_target_pairs); + instruction = + CreateCollectivePermute(shape, operands(0), source_target_pairs); break; } case HloOpcode::kConvolution: { @@ -369,7 +382,7 @@ StatusOr> HloInstruction::CreateFromProto( precision_config.mutable_operand_precision()->Resize( proto.operand_ids_size(), PrecisionConfig::DEFAULT); instruction = CreateConvolve( - proto.shape(), operands(0), operands(1), + shape, operands(0), operands(1), std::max(proto.feature_group_count(), 1), proto.window(), proto.convolution_dimension_numbers(), precision_config); break; @@ -381,7 +394,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "ReduceWindow should have 1 called computation but sees " << proto.called_computation_ids_size(); - instruction = CreateReduceWindow(proto.shape(), operands(0), operands(1), + instruction = CreateReduceWindow(shape, operands(0), operands(1), proto.window(), computations(0)); break; case HloOpcode::kSelectAndScatter: @@ -391,14 +404,28 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 2) << "SelectAndScatter should have 2 called computations but sees " << proto.called_computation_ids_size(); - instruction = CreateSelectAndScatter( - proto.shape(), operands(0), 
computations(0), proto.window(), - operands(1), operands(2), computations(1)); + instruction = CreateSelectAndScatter(shape, operands(0), computations(0), + proto.window(), operands(1), + operands(2), computations(1)); break; case HloOpcode::kCustomCall: - instruction = CreateCustomCall(proto.shape(), all_operands(), - proto.custom_call_target(), - proto.custom_call_opaque()); + if (proto.constrain_layout()) { + // A proto RepeatedPtrField cannot be converted to a Span (it is a + // vector of pointers essentially) so create a vector of shapes to pass + // in. + std::vector operand_shapes; + for (const ShapeProto& shape_proto : + proto.operand_shapes_with_layout()) { + operand_shapes.emplace_back(shape_proto); + } + instruction = + CreateCustomCall(shape, all_operands(), proto.custom_call_target(), + operand_shapes, proto.custom_call_opaque()); + } else { + instruction = + CreateCustomCall(shape, all_operands(), proto.custom_call_target(), + proto.custom_call_opaque()); + } if (proto.has_window()) { static_cast(instruction.get()) ->set_window(proto.window()); @@ -417,8 +444,8 @@ StatusOr> HloInstruction::CreateFromProto( << "Pad instruction should have 2 operands but sees " << proto.operand_ids_size(); TF_RET_CHECK(proto.has_padding_config()); - instruction = CreatePad(proto.shape(), operands(0), operands(1), - proto.padding_config()); + instruction = + CreatePad(shape, operands(0), operands(1), proto.padding_config()); break; case HloOpcode::kDynamicSlice: { TF_RET_CHECK(proto.operand_ids_size() == 2) @@ -426,8 +453,8 @@ StatusOr> HloInstruction::CreateFromProto( << proto.operand_ids_size(); std::vector slice_sizes(proto.dynamic_slice_sizes_size()); absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); - instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1), - slice_sizes); + instruction = + CreateDynamicSlice(shape, operands(0), operands(1), slice_sizes); break; } case HloOpcode::kGather: { @@ -443,7 +470,7 @@ StatusOr> 
HloInstruction::CreateFromProto( for (int64 bound : proto.gather_slice_sizes()) { gather_slice_sizes.push_back(bound); } - instruction = CreateGather(proto.shape(), operands(0), operands(1), + instruction = CreateGather(shape, operands(0), operands(1), *gather_dimension_numbers, gather_slice_sizes); break; } @@ -459,16 +486,15 @@ StatusOr> HloInstruction::CreateFromProto( auto scatter_dimension_numbers = absl::make_unique( proto.scatter_dimension_numbers()); - instruction = - CreateScatter(proto.shape(), operands(0), operands(1), operands(2), - computations(0), *scatter_dimension_numbers); + instruction = CreateScatter(shape, operands(0), operands(1), operands(2), + computations(0), *scatter_dimension_numbers); break; } case HloOpcode::kIota: TF_RET_CHECK(proto.dimensions_size() == 1) << "Iota instruction should have 1 dimension but sees " << proto.dimensions_size(); - instruction = CreateIota(proto.shape(), proto.dimensions(0)); + instruction = CreateIota(shape, proto.dimensions(0)); break; case HloOpcode::kDot: { TF_RET_CHECK(proto.has_dot_dimension_numbers()) @@ -480,33 +506,42 @@ StatusOr> HloInstruction::CreateFromProto( precision_config.mutable_operand_precision()->Resize( proto.operand_ids_size(), PrecisionConfig::DEFAULT); instruction = absl::make_unique( - proto.shape(), operands(0), operands(1), - proto.dot_dimension_numbers(), precision_config); + shape, operands(0), operands(1), proto.dot_dimension_numbers(), + precision_config); break; } case HloOpcode::kDomain: { TF_RET_CHECK(proto.operand_ids_size() == 1) << "Domain instruction should have 1 operands but sees " << proto.operand_ids_size(); - TF_RET_CHECK(proto.has_domain_entry_sharding()) - << "Domain instruction must domain_entry_sharding"; - TF_RET_CHECK(proto.has_domain_exit_sharding()) - << "Domain instruction must domain_exit_sharding"; - TF_ASSIGN_OR_RETURN( - HloSharding entry_hlo_sharding, - HloSharding::FromProto(proto.domain_entry_sharding())); - TF_ASSIGN_OR_RETURN(HloSharding 
exit_hlo_sharding, - HloSharding::FromProto(proto.domain_exit_sharding())); + std::shared_ptr entry_hlo_sharding; + std::shared_ptr exit_hlo_sharding; + if (proto.has_domain_entry_sharding()) { + TF_ASSIGN_OR_RETURN( + HloSharding sharding, + HloSharding::FromProto(proto.domain_entry_sharding())); + entry_hlo_sharding = std::make_shared(sharding); + } + if (proto.has_domain_exit_sharding()) { + TF_ASSIGN_OR_RETURN( + HloSharding sharding, + HloSharding::FromProto(proto.domain_exit_sharding())); + exit_hlo_sharding = std::make_shared(sharding); + } instruction = absl::make_unique( - proto.shape(), operands(0), - absl::make_unique( - std::make_shared(entry_hlo_sharding)), - absl::make_unique( - std::make_shared(exit_hlo_sharding))); + shape, operands(0), + absl::make_unique(entry_hlo_sharding), + absl::make_unique(exit_hlo_sharding)); break; } + case HloOpcode::kGetDimensionSize: + TF_RET_CHECK(proto.operand_ids_size() == 1); + TF_RET_CHECK(proto.dimensions_size() == 1); + instruction = + CreateGetDimensionSize(shape, operands(0), proto.dimensions(0)); + break; default: { - instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); + instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); for (const int64 operand_id : proto.operand_ids()) { instruction->AppendOperand(instruction_map.at(operand_id)); } @@ -820,6 +855,16 @@ HloInstruction::CreateCollectivePermute( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); } +/* static */ std::unique_ptr +HloInstruction::CreateAddDependency(HloInstruction* data_operand, + HloInstruction* token_operand) { + auto instruction = absl::WrapUnique( + new HloInstruction(HloOpcode::kAddDependency, data_operand->shape())); + instruction->AppendOperand(data_operand); + instruction->AppendOperand(token_operand); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateWhile( const Shape& shape, HloComputation* condition, HloComputation* body, HloInstruction* init) { @@ 
-976,6 +1021,14 @@ HloInstruction::CreateSelectAndScatter( broadcast_dimensions); } +/* static */ std::unique_ptr +HloInstruction::CreateGetDimensionSize(const Shape& shape, + HloInstruction* operand, + int64 dimension) { + return absl::make_unique(shape, operand, + dimension); +} + /* static */ std::unique_ptr HloInstruction::CreateBroadcastSequence( const Shape& output_shape, HloInstruction* operand, @@ -1055,7 +1108,7 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateSort( const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values) { + absl::Span values) { return absl::make_unique(shape, dimension, keys, values); } @@ -1084,7 +1137,11 @@ void HloInstruction::set_single_sharding(const HloSharding& sharding) { void HloInstruction::SetupDerivedInstruction( HloInstruction* derived_instruction) const { - if (sharding_ != nullptr) { + if (sharding_ != nullptr && ShapeUtil::CompatibleIgnoringElementType( + shape_, derived_instruction->shape())) { + // Only copy sharding if the shape of the two instruction is compatible + // because copying it between differently shaped instructions can produce + // invalid shardings. 
derived_instruction->set_sharding(*sharding_); } else { derived_instruction->clear_sharding(); @@ -1142,6 +1199,15 @@ bool HloInstruction::HasSideEffect() const { shape, operands, custom_call_target, opaque); } +/* static */ std::unique_ptr HloInstruction::CreateCustomCall( + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, + absl::Span operand_shapes_with_layout, + absl::string_view opaque) { + return absl::make_unique( + shape, operands, custom_call_target, opaque, operand_shapes_with_layout); +} + /* static */ std::unique_ptr HloInstruction::CreateTuple( absl::Span elements) { std::vector element_shapes; @@ -1234,6 +1300,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kIota: case HloOpcode::kDot: case HloOpcode::kDomain: + case HloOpcode::kGetDimensionSize: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1337,6 +1404,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateAfterAll(new_operands); } break; + case HloOpcode::kAddDependency: + CHECK_EQ(new_operands.size(), 2); + clone = CreateAddDependency(new_operands[0], new_operands[1]); + break; } // SetupDerivedInstruction will setup the precision_config_ field. SetupDerivedInstruction(clone.get()); @@ -1623,6 +1694,7 @@ bool HloInstruction::IdenticalSlowPath( // This opcode has complex or special behavior so just return false. case HloOpcode::kAfterAll: + case HloOpcode::kAddDependency: return false; // Remaining instructions with special values. 
@@ -1681,12 +1753,33 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kScatter: case HloOpcode::kDot: case HloOpcode::kDomain: + case HloOpcode::kGetDimensionSize: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } return false; } +uint64 HloInstruction::Hash() const { + using tensorflow::Hash64Combine; + + uint64 hash_value = Hash64Combine(0, static_cast(opcode())); + hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(shape())); + + if (!IsCrossModuleAllReduce()) { + if (!operands().empty()) { + for (size_t i = 0; i < operands().size(); ++i) { + hash_value = Hash64Combine(hash_value, operand(i)->Hash()); + } + } + } + + hash_value = Hash64Combine(hash_value, InnerHash()); + return hash_value; +} + +uint64 HloInstruction::InnerHash() const { return 13; } + void HloInstruction::RemoveUser(HloInstruction* user) { auto set_it = user_set_.find(user); CHECK(set_it != user_set_.end()); @@ -1842,6 +1935,11 @@ void HloInstruction::set_while_body(HloComputation* computation) { called_computations_[kBodyComputationIndex] = computation; } +HloInstruction* HloInstruction::while_init() const { + CHECK_EQ(HloOpcode::kWhile, opcode_); + return operands_[0]; +} + HloComputation* HloInstruction::true_computation() const { CHECK_EQ(HloOpcode::kConditional, opcode_); return called_computations_[kTrueComputationIndex]; @@ -2156,7 +2254,7 @@ HloInstructionProto HloInstruction::ToProto() const { proto.set_id(unique_id_); proto.set_name(name_); proto.set_opcode(HloOpcodeString(opcode_)); - *proto.mutable_shape() = shape_; + *proto.mutable_shape() = shape_.ToProto(); for (const HloInstruction* operand : operands_) { proto.add_operand_ids(operand->unique_id()); } @@ -2404,8 +2502,12 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleDomain(this); case HloOpcode::kAfterAll: return visitor->HandleAfterAll(this); + case HloOpcode::kAddDependency: + return visitor->HandleAddDependency(this); case HloOpcode::kIota: 
return visitor->HandleIota(this); + case HloOpcode::kGetDimensionSize: + return visitor->HandleGetDimensionSize(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -2563,36 +2665,6 @@ Status HloInstruction::AcceptWithOperandOrder( return Status::OK(); } -namespace { - -// Returns true if the given order is a topological sort of the instructions -// it contains. -bool OrderIsTopologicalSort(const std::vector& order) { - // Create a map from instruction to its position in 'order'. - std::unordered_map order_position; - for (int i = 0; i < order.size(); i++) { - if (!order_position.insert({order[i], i}).second) { - // Instruction order[i] is duplicated in the order. - return false; - } - } - // Verify that the operand of each instruction in the order is also in the - // order *and* the operand's position is earlier (defs are before uses for - // all ops). - for (auto* instruction : order) { - for (auto* operand : instruction->operands()) { - if (!ContainsKey(order_position, operand) || - order_position.at(operand) >= order_position.at(instruction)) { - return false; - } - } - } - - return true; -} - -} // namespace - Status HloInstruction::Accept( const std::function& visitor_func) { FunctionVisitor visitor(visitor_func); @@ -2605,50 +2677,7 @@ Status HloInstruction::Accept( return this->Accept(&visitor); } -Status HloInstruction::AcceptOrdered( - DfsHloVisitor* visitor, const std::vector& order) { - VLOG(2) << "HloInstruction::AcceptOrdered(%" << name() << ")"; - TF_RET_CHECK(OrderIsTopologicalSort(order)); - - // Compute the predecessors of this instruction. - std::unordered_set predecessors; - TF_RETURN_IF_ERROR(this->Accept([&predecessors](HloInstruction* instruction) { - predecessors.insert(instruction); - return Status::OK(); - })); - - for (auto* const_instruction : order) { - if (!ContainsKey(predecessors, const_instruction)) { - // Instruction is not a predecessors of 'this'. 
- continue; - } - - // The visitor can mark instructions as visited to skip particular - // instructions. - if (visitor->DidVisit(*const_instruction)) { - VLOG(3) << "Not visiting HLO %" << const_instruction->name() - << " as it was already visited."; - continue; - } - - // TODO(b/78350259): Eliminate const laundering. - HloInstruction* instruction = - const_cast(const_instruction); - - TF_RETURN_IF_ERROR(visitor->Preprocess(instruction)); - VLOG(2) << "Visiting HLO %" << instruction->name(); - TF_RETURN_IF_ERROR(instruction->Visit(visitor)); - visitor->SetVisited(*instruction); - TF_RETURN_IF_ERROR(visitor->Postprocess(instruction)); - } - - return visitor->FinishVisit(this); -} - -const Shape& HloInstruction::shape() const { - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); - return shape_; -} +const Shape& HloInstruction::shape() const { return shape_; } std::vector HloInstruction::OperandIndices( const HloInstruction* operand) const { @@ -3005,6 +3034,16 @@ const PrecisionConfig& HloInstruction::precision_config() const { LOG(FATAL) << "Unimplemented method."; } +PrecisionConfig* HloInstruction::mutable_precision_config() { + if (auto* convolution = DynCast(this)) { + return convolution->mutable_precision_config(); + } + if (auto* dot = DynCast(this)) { + return dot->mutable_precision_config(); + } + LOG(FATAL) << "Unimplemented method."; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); @@ -3047,6 +3086,10 @@ int64 HloInstruction::concatenate_dimension() const { return Cast(this)->concatenate_dimension(); } +int64 HloInstruction::dimension() const { + return Cast(this)->dimension(); +} + bool HloInstruction::IsRank2Transpose() const { auto transpose = DynCast(this); return transpose != nullptr && transpose->IsRank2Transpose(); @@ -3226,6 +3269,11 @@ absl::optional HloInstruction::all_reduce_id() const { return Cast(this)->all_reduce_id(); } +void HloInstruction::set_all_reduce_id( + const 
absl::optional& all_reduce_id) { + return Cast(this)->set_all_reduce_id(all_reduce_id); +} + const ConvolutionDimensionNumbers& HloInstruction::convolution_dimension_numbers() const { if (auto convolution = DynCast(this)) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 374862c4b672bf4cb7c6e3dbc60392a1018520b7..a54716217d6bbc5c0601f5d9ff7bf4072a6b30f5 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -28,11 +28,10 @@ limitations under the License. #include #include #include -#include -#include #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -464,7 +463,7 @@ class HloInstruction { // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will // not be applied cross modules. // - // TODO(b/79737069): Rename this to AllReduce. + // TODO(b/117564385): Rename this to AllReduce. static std::unique_ptr CreateCrossReplicaSum( const Shape& shape, absl::Span operands, HloComputation* reduce_computation, @@ -670,10 +669,10 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, absl::Span dimensions); - // Creates a sort op, with a keys operand, and an optional values operand. + // Creates a sort op, with a keys operand, and optional values operands. static std::unique_ptr CreateSort( const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values = nullptr); + absl::Span values = {}); // Creates a while instruction, given a condition computation, a body // computation, and the initial value for the input of the computations. 
For @@ -734,6 +733,16 @@ class HloInstruction { const Shape& shape, absl::Span operands, absl::string_view custom_call_target, absl::string_view opaque = ""); + // Overload which constrains the layouts of the operand and result. 'shape' + // and 'operand_shapes_with_layout' must have layouts. + // 'operand_shapes_with_layout' must have a compatible element for each + // operand. + static std::unique_ptr CreateCustomCall( + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, + absl::Span operand_shapes_with_layout, + absl::string_view opaque = ""); + // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. static std::unique_ptr CreateTuple( @@ -758,6 +767,12 @@ class HloInstruction { // when we plumb a primordial token from the entry computation. static std::unique_ptr CreateToken(); + static std::unique_ptr CreateGetDimensionSize( + const Shape& shape, HloInstruction* operand, int64 dimension); + + static std::unique_ptr CreateAddDependency( + HloInstruction* data_operand, HloInstruction* token_operand); + // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } @@ -871,11 +886,15 @@ class HloInstruction { return false; } - // Use an explicit loop rather than ContainerEquals, because copying around - // std::functions may be too expensive in some cases. - for (size_t i = 0; i < operands().size(); ++i) { - if (!eq_operands(operand(i), other.operand(i))) { - return false; + // Two AllReduces are Identical if they have the same all_reduce_id. + // Their operands don't have to be Identical. + if (!IsCrossModuleAllReduce()) { + // Use an explicit loop rather than ContainerEquals, because copying + // around std::functions may be too expensive in some cases. 
+ for (size_t i = 0; i < operands().size(); ++i) { + if (!eq_operands(operand(i), other.operand(i))) { + return false; + } } } @@ -886,6 +905,12 @@ class HloInstruction { return IdenticalSlowPath(other, eq_computations); } + // Generates a hash value of an HLO instruction. Hash considers + // information on opcode, shape, operands, and typically a root instruction. + // This function returns the same hash value for equivalent HLO instructions, + // with respect to HloInstruction::Identical() method. + uint64 Hash() const; + // Returns whether the instruction has a constant operand. bool HasConstantOperand() const; @@ -945,16 +970,6 @@ class HloInstruction { Status Accept( const std::function& visitor_func) const; - // Visits all instructions rooted at this instruction using the given visitor - // in the given order. 'order' must contain at least the set of instructions - // rooted at this node (ie, those accessible from a DFS traversal from this - // instruction). Instructions contained in 'order' which are not in the set of - // instructions rooted at this node are ignored. 'order' must also be a valid - // topological sort of these instructions (defs appear before uses) though - // need not be a DFS post-order. - Status AcceptOrdered(DfsHloVisitor* visitor, - const std::vector& order); - // Visit this instruction and only this instruction with the given visitor. template Status Visit(DfsHloVisitorBase* visitor); @@ -995,6 +1010,8 @@ class HloInstruction { void set_while_condition(HloComputation* while_condition); void set_while_body(HloComputation* while_body); + HloInstruction* while_init() const; + // Gets/sets the true and false HloComputation for Conditional. The setters // should only be called by HloModule or HloComputation methods. // @@ -1255,6 +1272,7 @@ class HloInstruction { // superior. // Precondition: opcode must be kConvolution or kDot. 
const PrecisionConfig& precision_config() const; + PrecisionConfig* mutable_precision_config(); // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } @@ -1315,6 +1333,9 @@ class HloInstruction { // Delegates to HloConcatenateInstruction::concatenate_dimension. int64 concatenate_dimension() const; + // Delegates to HloGetDimensionSizeInstruction::dimension. + int64 dimension() const; + // Returns whether this instruction does a rank-2 transposition. bool IsRank2Transpose() const; @@ -1433,6 +1454,7 @@ class HloInstruction { // Delegates to HloAllReduceInstruction::all_reduce_id. absl::optional all_reduce_id() const; + void set_all_reduce_id(const absl::optional& all_reduce_id); // Returns data on the window in a windowed operation such as // convolution. @@ -1597,6 +1619,10 @@ class HloInstruction { const std::function& eq_computations) const; + // Generates a hash value specific to a particular type of an instruction. + // This function typically considers the inner root instruction. + virtual uint64 InnerHash() const; + // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( const Shape& shape, HloOpcode opcode, @@ -1635,7 +1661,7 @@ class HloInstruction { // members. The set enables fast membership testing and the vector enables // fast, stable iteration. std::vector users_; - std::unordered_set user_set_; + absl::flat_hash_set user_set_; // The set of control successors of this instruction. std::vector control_successors_; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index c1b7c3832b44b5d65b715dffa5211a5c92e17953..8048e332cb57747286758b75773b29ba154aa888 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -29,7 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -39,7 +39,7 @@ namespace { using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; -class HloInstructionTest : public HloVerifiedTestBase { +class HloInstructionTest : public HloTestBase { protected: Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); }; @@ -135,7 +135,8 @@ TEST_F(HloInstructionTest, BasicProperties) { auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo"); EXPECT_EQ(HloOpcode::kParameter, parameter->opcode()); - EXPECT_TRUE(ShapeUtil::IsScalarF32(parameter->shape())); + EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32)); + EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32)); EXPECT_EQ(0, parameter->operand_count()); } @@ -150,7 +151,7 @@ TEST_F(HloInstructionTest, UserWithTwoOperands) { builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(add->operands(), UnorderedElementsAre(foo, bar)); @@ -187,7 +188,7 @@ TEST_F(HloInstructionTest, MultipleUsers) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, foo->user_count()); @@ -220,7 +221,7 @@ TEST_F(HloInstructionTest, RepeatedUser) { 
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(1, foo->user_count()); @@ -255,7 +256,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c0, param1)); auto addtotal = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; @@ -304,7 +305,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); auto neg2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; @@ -326,7 +327,7 @@ TEST_F(HloInstructionTest, TrivialMap) { // Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); // Builds an x+1.0 computation to use in a Map. auto embedded_builder = HloComputation::Builder("f32+1"); @@ -374,7 +375,7 @@ TEST_F(HloInstructionTest, TrivialReduce) { HloInstruction::CreateParameter(1, r0f32, "y")); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); // Builds a parameter and an initial value and feeds them to the reduce. 
@@ -415,7 +416,7 @@ TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -450,7 +451,7 @@ TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo})); auto add_foobar = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -478,7 +479,7 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto log = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -515,7 +516,7 @@ TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); @@ -545,7 +546,7 @@ TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({foo, bar})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, 
foo->user_count()); @@ -610,7 +611,7 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); NodeCollectorAndPostProcessor visitor; @@ -628,7 +629,7 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f))); auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp}, HloInstruction::FusionKind::kLoop); @@ -646,7 +647,7 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add}, HloInstruction::FusionKind::kLoop); @@ -668,7 +669,7 @@ TEST_F(HloInstructionTest, ChainFusionOp) { auto exp3 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -691,7 +692,7 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { exp1->set_metadata(metadata); exp2->set_metadata(metadata); - auto module = CreateNewModule(); + auto module = 
CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -748,7 +749,7 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { // Create a fusion instruction containing a single unary operation. const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto make_map_computation = [&]() { auto builder = HloComputation::Builder("FusionMap"); @@ -816,7 +817,7 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -976,7 +977,7 @@ TEST_F(HloInstructionTest, FunctionVisitor) { HloInstruction::CreateUnary(f32, HloOpcode::kExp, param)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); int visit_num = 0; @@ -1005,7 +1006,7 @@ TEST_F(HloInstructionTest, FullyElementwise) { builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_TRUE(add->IsElementwise()); @@ -1015,7 +1016,7 @@ TEST_F(HloInstructionTest, FullyElementwise) { } TEST_F(HloInstructionTest, MapIsElementwise) { - auto module = CreateNewModule(); + auto module = 
CreateNewVerifiedModule(); const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0}); HloComputation::Builder builder(TestName()); HloComputation::Builder map_builder("id"); @@ -1066,7 +1067,7 @@ TEST_F(HloInstructionTest, PartiallyElementwise) { HloInstruction* max = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop); @@ -1107,7 +1108,7 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, min, broadcast)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {sub, broadcast, min}, HloInstruction::FusionKind::kLoop); @@ -1150,7 +1151,7 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1191,7 +1192,7 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( s, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = 
computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1203,7 +1204,7 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { } TEST_F(HloInstructionTest, FusionEquality) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Create two fusion instructions containing a single unary operation. @@ -1225,7 +1226,7 @@ TEST_F(HloInstructionTest, FusionEquality) { } TEST_F(HloInstructionTest, NestedFusionEquality) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); // Build a nested fusion computation. @@ -1329,7 +1330,7 @@ TEST_F(HloInstructionTest, Stringification) { "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}"); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* loop = builder.AddInstruction( @@ -1372,7 +1373,7 @@ TEST_F(HloInstructionTest, StringifyGather_0) { /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), @@ -1407,7 +1408,7 @@ TEST_F(HloInstructionTest, StringifyGather_1) { /*index_vector_dim=*/2), /*slice_sizes=*/{30, 29, 28, 27, 26})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), @@ -1442,7 +1443,7 @@ TEST_F(HloInstructionTest, StringifyScatter) { update_builder.AddInstruction( HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2")); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* update_computation = 
module->AddEmbeddedComputation(update_builder.Build()); @@ -1494,7 +1495,7 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), " "lhs_contracting_dims={1}, rhs_contracting_dims={0}"); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1530,7 +1531,7 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({dot, reshape}, HloInstruction::FusionKind::kLoop); @@ -1586,7 +1587,7 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({dot, reshape}, HloInstruction::FusionKind::kLoop); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 152d8eacdb591a31afcbbf7f9f01d51864c929f0..1ea02cf9c03866a598bec0e5356f0eb31ad27755 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -370,6 +370,11 @@ HloAllReduceInstruction::HloAllReduceInstruction( AppendComputation(reduce_computation); } +void HloAllReduceInstruction::set_all_reduce_id( + const absl::optional& all_reduce_id) { + all_reduce_id_ = all_reduce_id; +} + 
HloInstructionProto HloAllReduceInstruction::ToProto() const { HloInstructionProto proto = HloCollectiveInstruction::ToProto(); // Proto3 is so sad. @@ -600,11 +605,11 @@ std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values) + absl::Span values) : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) { AppendOperand(keys); - if (values) { - AppendOperand(values); + for (auto* value : values) { + AppendOperand(value); } } @@ -633,9 +638,8 @@ std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { HloInstruction* keys = new_operands[0]; - HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr; return absl::make_unique(shape, dimensions(0), keys, - values); + new_operands.subspan(1)); } HloTransposeInstruction::HloTransposeInstruction( @@ -1368,6 +1372,10 @@ bool HloFusionInstruction::IdenticalSlowPath( other.fused_instructions_computation()); } +uint64 HloFusionInstruction::InnerHash() const { + return fused_instructions_computation()->Hash(); +} + std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { @@ -1611,7 +1619,7 @@ HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape, HloInstructionProto HloOutfeedInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_outfeed_config(outfeed_config()); - *proto.mutable_outfeed_shape() = outfeed_shape(); + *proto.mutable_outfeed_shape() = outfeed_shape().ToProto(); return proto; } @@ -1825,7 +1833,24 @@ HloCustomCallInstruction::HloCustomCallInstruction( : HloInstruction(HloOpcode::kCustomCall, shape), custom_call_target_(custom_call_target.begin(), custom_call_target.end()), opaque_(opaque.begin(), opaque.end()), - 
feature_group_count_(1) { + feature_group_count_(1), + layout_constrained_(false) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloCustomCallInstruction::HloCustomCallInstruction( + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, absl::string_view opaque, + absl::Span operand_shapes_with_layout) + : HloInstruction(HloOpcode::kCustomCall, shape), + custom_call_target_(custom_call_target.begin(), custom_call_target.end()), + opaque_(opaque.begin(), opaque.end()), + feature_group_count_(1), + layout_constrained_(true), + operand_shapes_with_layout_(operand_shapes_with_layout.begin(), + operand_shapes_with_layout.end()) { for (auto operand : operands) { AppendOperand(operand); } @@ -1843,6 +1868,12 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { proto.set_custom_call_target(custom_call_target_); proto.set_custom_call_opaque(opaque_); proto.set_feature_group_count(feature_group_count_); + if (layout_constrained()) { + proto.set_constrain_layout(true); + for (const Shape& shape : operand_shapes_with_layout_) { + *proto.add_operand_shapes_with_layout() = shape.ToProto(); + } + } return proto; } @@ -1870,6 +1901,14 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( if (!opaque_.empty()) { extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\"")); } + if (layout_constrained()) { + std::vector shape_strings; + for (const Shape& shape : operand_shapes_with_layout_) { + shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape)); + } + extra.push_back(StrCat("operand_layout_constraints={", + StrJoin(shape_strings, ", "), "}")); + } return extra; } @@ -2305,18 +2344,57 @@ HloInstructionProto HloDomainInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); auto operand_side_sharding = dynamic_cast(operand_side_metadata_.get()); - if (operand_side_sharding) { + if (operand_side_sharding && operand_side_sharding->sharding() != nullptr) { 
*proto.mutable_domain_entry_sharding() = operand_side_sharding->sharding()->ToProto(); } auto user_side_sharding = dynamic_cast(user_side_metadata_.get()); - if (user_side_sharding) { + if (user_side_sharding && user_side_sharding->sharding() != nullptr) { *proto.mutable_domain_exit_sharding() = user_side_sharding->sharding()->ToProto(); } return proto; } + +HloGetDimensionSizeInstruction::HloGetDimensionSizeInstruction( + const Shape& shape, HloInstruction* operand, int64 dimension) + : HloInstruction(HloOpcode::kGetDimensionSize, shape), + dimension_(dimension) { + AppendOperand(operand); +} + +HloInstructionProto HloGetDimensionSizeInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.add_dimensions(dimension()); + return proto; +} + +std::vector HloGetDimensionSizeInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + return {StrCat("dimensions={", dimension(), "}")}; +} + +bool HloGetDimensionSizeInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + /*eq_computations*/) const { + const auto& casted_other = + static_cast(other); + return dimension() == casted_other.dimension(); +} + +std::unique_ptr +HloGetDimensionSizeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* /*context*/) const { + if (new_operands.size() != 1) { + LOG(FATAL) << "expects 1 operand"; + } + return absl::make_unique( + shape, new_operands[0], dimension()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index e169604072a6d23c5e601fcbe00b7a7bf37a933d..b5c28137a145667a977d39c9d3c40c6d36a8436e 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -252,6 +252,7 @@ class HloAllReduceInstruction : public HloCollectiveInstruction { } absl::optional all_reduce_id() const { 
return all_reduce_id_; } + void set_all_reduce_id(const absl::optional& all_reduce_id); // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -418,14 +419,19 @@ class HloSortInstruction : public HloInstruction { public: explicit HloSortInstruction(const Shape& shape, int64 dimension, HloInstruction* keys, - HloInstruction* values = nullptr); + absl::Span values = {}); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } // Returns the sort dimension for this instruction - int64 sort_dimension() { return dimensions(0); } + int64 sort_dimension() const { return dimensions(0); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns the key operand to this instruction. + const HloInstruction* keys() const { return operand(0); } + HloInstruction* mutable_keys() { return mutable_operand(0); } + // Returns the number of value operands. + int64 values_count() const { return operand_count() - 1; } private: std::vector ExtraAttributesToStringImpl( @@ -737,6 +743,8 @@ class HloFusionInstruction : public HloInstruction { const HloInstruction& other, const std::function& eq_computations) const override; + uint64 InnerHash() const override; + // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, @@ -949,6 +957,7 @@ class HloConvolutionInstruction : public HloInstruction { // information but it is presumed that the alternate lowering is strictly // superior. 
const PrecisionConfig& precision_config() const { return precision_config_; } + PrecisionConfig* mutable_precision_config() { return &precision_config_; } string ToCategory() const override; // Returns a serialized representation of this instruction. @@ -1053,10 +1062,19 @@ class HloSelectAndScatterInstruction : public HloInstruction { class HloCustomCallInstruction : public HloInstruction { public: - explicit HloCustomCallInstruction(const Shape& shape, - absl::Span operands, - absl::string_view custom_call_target, - absl::string_view opaque); + HloCustomCallInstruction(const Shape& shape, + absl::Span operands, + absl::string_view custom_call_target, + absl::string_view opaque); + + // Constructor for a custom call with constrained layout. 'shape' and + // 'operands_with_layout' must all have layouts. + HloCustomCallInstruction(const Shape& shape, + absl::Span operands, + absl::string_view custom_call_target, + absl::string_view opaque, + absl::Span operand_shapes_with_layout); + const Window& window() const override { CHECK(window_ != nullptr); return *window_; @@ -1085,6 +1103,16 @@ class HloCustomCallInstruction : public HloInstruction { // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; + // Returns whether the result and operand layouts are constrained. + bool layout_constrained() const { return layout_constrained_; } + + // Returns the shapes (with layout) of the operands. CHECKs if this custom + // call does not have constrained layouts. + const std::vector& operand_shapes_with_layout() const { + CHECK(layout_constrained()); + return operand_shapes_with_layout_; + } + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -1106,6 +1134,11 @@ class HloCustomCallInstruction : public HloInstruction { std::unique_ptr convolution_dimension_numbers_; // The number of feature groups. This is used for grouped convolutions. 
int64 feature_group_count_; + // Whether the result and operand layouts are constrained. + bool layout_constrained_; + // For layout-constrained custom calls, this vector holds the shape with + // layout for each operand. + std::vector operand_shapes_with_layout_; }; class HloPadInstruction : public HloInstruction { @@ -1115,6 +1148,9 @@ class HloPadInstruction : public HloInstruction { const PaddingConfig& padding_config); // Returns the padding configuration for a pad node. const PaddingConfig& padding_config() const { return padding_config_; } + // Returns the padding value. + const HloInstruction* padding_value() const { return operand(1); } + HloInstruction* mutable_padding_value() { return mutable_operand(1); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1293,6 +1329,7 @@ class HloDotInstruction : public HloInstruction { // information but it is presumed that the alternate lowering is strictly // superior. const PrecisionConfig& precision_config() const { return precision_config_; } + PrecisionConfig* mutable_precision_config() { return &precision_config_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1353,6 +1390,33 @@ class HloDomainInstruction : public HloInstruction { std::unique_ptr operand_side_metadata_; std::unique_ptr user_side_metadata_; }; + +class HloGetDimensionSizeInstruction : public HloInstruction { + public: + explicit HloGetDimensionSizeInstruction(const Shape& shape, + HloInstruction* operand, + int64 dimension); + + // Returns the dimension sizes or numbers associated with this instruction. + int64 dimension() const { return dimension_; } + // Returns a serialized representation of this instruction. 
+ HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + int64 dimension_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index d9be841dd751651ba029998fd062fcaec3691945..1390537101e95a08e4ba4eef7ae8d6059a40e916 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -163,6 +163,9 @@ TokKind HloLexer::LexToken() { current_ptr_ = comment_start; return TokKind::kError; } + if (current == kError) { + return TokKind::kError; + } } // Return no token for the comment. Keep lexing. continue; @@ -177,6 +180,9 @@ TokKind HloLexer::LexToken() { if (current == kEOF || current == '\n' || current == '\r') { break; } + if (current == kError) { + return TokKind::kError; + } current_ptr_++; } continue; @@ -204,7 +210,7 @@ TokKind HloLexer::LexIdentifier() { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); // 'consumable' will be advanced iff its prefix matches the pattern. 
static LazyRE2 shape_pattern = { - R"(^(\w*\d*)\[([\d,]*)\](?:(dense|sparse)?{([\d,]+)})?)"}; + R"(^(\w*\d*)\[([\d,\s]*)\](?:(dense|sparse)?{([\d,\s]+)})?)"}; if (RE2::Consume(&consumable, *shape_pattern)) { auto status_or_shape = ShapeUtil::ParseShapeString( StringPieceFromPointers(token_start_, consumable.begin())); diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index 3e2f8bcd52f9043f161197756a2060b28dded1d9..d6a2b292a3916b2ff85f278cf5cb9f1567df88fa 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_token.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 5269cad94d35be3dd1c009588bbe422ff1533364..d28e79d41ad5d58a8881cfb80d488684af26564f 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -237,8 +237,4 @@ void PrintTo(const HloInstruction* inst, ::std::ostream* os) { *os << (inst ? 
inst->ToString() : "nullptr"); } -void PrintTo(HloInstruction* inst, ::std::ostream* os) { - PrintTo(const_cast(inst), os); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 5502e565b6dfbaca6cfa2101950fb0a68c89771f..235efb19ce4ed28a5cd9fe5ca52ae5d8e9e5ba3d 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -165,6 +165,7 @@ namespace opcode_matchers { } HLO_MATCHER(Abs); HLO_MATCHER(Add); +HLO_MATCHER(AllToAll); HLO_MATCHER(Bitcast); HLO_MATCHER(Broadcast); HLO_MATCHER(BatchNormGrad); @@ -178,7 +179,9 @@ HLO_MATCHER(Convert); HLO_MATCHER(Convolution); HLO_MATCHER(Copy); HLO_MATCHER(CrossReplicaSum); +HLO_MATCHER(CollectivePermute); HLO_MATCHER(Divide); +HLO_MATCHER(Domain); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); HLO_MATCHER(Eq); @@ -216,6 +219,7 @@ HLO_MATCHER(Remainder); HLO_MATCHER(Reshape); HLO_MATCHER(Reverse); HLO_MATCHER(Rng); +HLO_MATCHER(Scatter); HLO_MATCHER(Select); HLO_MATCHER(SelectAndScatter); HLO_MATCHER(Send); @@ -381,7 +385,6 @@ std::vector Pointers(const Container& container) { // Tell GMock to print HloInstruction* by value, so error messages are nice. // Has to be in the same namespace as 'HloInstruction'. void PrintTo(const HloInstruction* inst, ::std::ostream* os); -void PrintTo(HloInstruction* inst, ::std::ostream* os); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index 5cee865b7ad34eded1743d9d5455bb40febf6182..d2740bcce26f04c5d7c8b64cfdaea53e3c697855 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -73,7 +73,7 @@ class ListScheduler { // Construct and return a memory-minimizing sequence of HLO instructions // containing the given HLO computation. 
static StatusOr Run( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -98,7 +98,7 @@ class ListScheduler { // comparison operators. using Priority = std::pair; - ListScheduler(const HloComputation& computation, + ListScheduler(HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -111,7 +111,7 @@ class ListScheduler { // instruction. An HLO instruction "uses" a LogicalBuffer if the // LogicalBuffer is in an operand of the instruction as indicated by // points-to analysis. - for (auto* instruction : computation.instructions()) { + for (auto* instruction : computation->instructions()) { absl::flat_hash_set instr_uses; for (auto* operand : instruction->operands()) { points_to_analysis.GetPointsToSet(operand).ForEachElement( @@ -126,13 +126,13 @@ class ListScheduler { // Create map containing the number of unscheduled uses (hlo instructions) // of each logical buffer. - for (auto* instruction : computation.instructions()) { + for (auto* instruction : computation->instructions()) { for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { unscheduled_use_count_[buffer] = 0; } } - for (auto* instruction : computation.instructions()) { + for (auto* instruction : computation->instructions()) { for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) { ++unscheduled_use_count_[buffer]; } @@ -141,7 +141,7 @@ class ListScheduler { // Buffers live out of the computation have an implicit use at the end of // the computation. 
for (const LogicalBuffer* live_out_buffer : - points_to_analysis.GetPointsToSet(computation.root_instruction()) + points_to_analysis.GetPointsToSet(computation->root_instruction()) .CreateFlattenedSet()) { ++unscheduled_use_count_[live_out_buffer]; } @@ -157,7 +157,7 @@ class ListScheduler { // HloInstruction, plus some cached metadata, saved for the purposes of making // BytesFreedIfScheduled fast. struct ReadyListEntry { - const HloInstruction* instruction; + HloInstruction* instruction; // The total size of all buffers defined by this instruction. int64 bytes_defined; @@ -171,7 +171,7 @@ class ListScheduler { }; // Creates a ReadyListEntry for the given instruction. - ReadyListEntry MakeReadyListEntry(const HloInstruction* instruction) { + ReadyListEntry MakeReadyListEntry(HloInstruction* instruction) { ReadyListEntry entry; entry.instruction = instruction; @@ -250,13 +250,13 @@ class ListScheduler { // Populate the ready list with instructions which have no operands or // control predecessors. absl::flat_hash_map unscheduled_pred_count; - for (auto* instruction : computation_.instructions()) { + for (auto* instruction : computation_->instructions()) { // TODO(b/34466113): Replace this and above with successors() or // predecessors() when these methods are added to HloInstruction. - for (const HloInstruction* user : instruction->users()) { + for (HloInstruction* user : instruction->users()) { unscheduled_pred_count[user]++; } - for (const HloInstruction* succ : instruction->control_successors()) { + for (HloInstruction* succ : instruction->control_successors()) { unscheduled_pred_count[succ]++; } } @@ -275,7 +275,7 @@ class ListScheduler { ready_instructions[inst] = it; }; - for (auto* instruction : computation_.instructions()) { + for (auto* instruction : computation_->instructions()) { if (instruction->operands().empty() && instruction->control_predecessors().empty()) { add_to_ready_queue(instruction); @@ -287,7 +287,7 @@ class ListScheduler { // schedule. 
auto best_it = ready_queue.end(); --best_it; - const HloInstruction* best = best_it->second.instruction; + HloInstruction* best = best_it->second.instruction; VLOG(2) << "Schedule instruction: " << best->ToShortString() << " Bytes freed: " << best_it->first.first; ready_queue.erase(best_it); @@ -348,13 +348,13 @@ class ListScheduler { } } } - CHECK_EQ(schedule.size(), computation_.instruction_count()); - CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count()); + CHECK_EQ(schedule.size(), computation_->instruction_count()); + CHECK_EQ(scheduled_instructions_.size(), computation_->instruction_count()); return schedule; } - const HloComputation& computation_; + HloComputation* computation_; const TuplePointsToAnalysis& points_to_analysis_; const LogicalBuffer::SizeFunction& size_function_; // Computations are analyzed in post-order. When scheduling an instruction @@ -386,13 +386,13 @@ int64 SumLogicalBufferSizes( } StatusOr ScheduleComputationHelper( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm, const absl::flat_hash_map& memory_by_computation) { - VLOG(2) << "Computation: " << computation.name(); + VLOG(2) << "Computation: " << computation->name(); if (algorithm) { return algorithm(computation, points_to_analysis, size_function, memory_by_computation); @@ -404,17 +404,17 @@ StatusOr ScheduleComputationHelper( } // namespace StatusOr DFSMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& memory_by_computation) { // These variables are a hack to prevent overflows. 
int64 cumulative_total_size = 0; - int64 total_hlos = computation.parent()->instruction_count(); + int64 total_hlos = computation->parent()->instruction_count(); absl::flat_hash_map extra_users; absl::flat_hash_map total_sizes; - for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { + for (const HloInstruction* hlo : computation->MakeInstructionPostOrder()) { if (ListScheduler::IgnoreInstruction(*hlo)) { extra_users[hlo] = 0; total_sizes[hlo] = 0; @@ -448,8 +448,8 @@ StatusOr DFSMemoryScheduler( total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); extra_users[hlo] = std::min(extra_users[hlo], total_hlos); } - CHECK_EQ(extra_users.size(), computation.instruction_count()); - CHECK_EQ(total_sizes.size(), computation.instruction_count()); + CHECK_EQ(extra_users.size(), computation->instruction_count()); + CHECK_EQ(total_sizes.size(), computation->instruction_count()); // Construct a total order based on DFS post-order, visiting operands in // decreasing cumulative extra user order, and next by cumulative size, with a @@ -459,7 +459,7 @@ StatusOr DFSMemoryScheduler( sequence.push_back(hlo); return Status::OK(); }); - TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( + TF_RETURN_IF_ERROR(computation->AcceptWithOperandOrder( &visitor, [&extra_users, &total_sizes](const HloInstruction* a, const HloInstruction* b) { if (extra_users[a] != extra_users[b]) { @@ -470,12 +470,12 @@ StatusOr DFSMemoryScheduler( } return a->name() < b->name(); })); - CHECK_EQ(sequence.size(), computation.instruction_count()); + CHECK_EQ(sequence.size(), computation->instruction_count()); return sequence; } // namespace xla StatusOr ListMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -485,16 +485,16 @@ StatusOr ListMemoryScheduler( } StatusOr PostOrderMemoryScheduler( - const HloComputation& 
computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& memory_by_computation) { - return HloInstructionSequence(computation.MakeInstructionPostOrder()); + return HloInstructionSequence(computation->MakeInstructionPostOrder()); } StatusOr DefaultMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -513,7 +513,7 @@ StatusOr DefaultMemoryScheduler( memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 list_memory, HeapSimulator::MinimumMemoryForComputation( - computation, list_sequence, points_to_analysis, + *computation, list_sequence, points_to_analysis, size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); @@ -522,7 +522,7 @@ StatusOr DefaultMemoryScheduler( size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 dfs_memory, HeapSimulator::MinimumMemoryForComputation( - computation, dfs_sequence, points_to_analysis, + *computation, dfs_sequence, points_to_analysis, size_function, &memory_by_computation)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); @@ -532,7 +532,7 @@ StatusOr DefaultMemoryScheduler( memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 post_order_memory, HeapSimulator::MinimumMemoryForComputation( - computation, post_order_sequence, points_to_analysis, + *computation, post_order_sequence, points_to_analysis, size_function, &memory_by_computation)); VLOG(2) << "Min-memory post order sequence: " << HumanReadableNumBytes(post_order_memory); @@ -555,17 +555,17 @@ StatusOr DefaultMemoryScheduler( } StatusOr ScheduleModule( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + HloModule* module, const LogicalBuffer::SizeFunction& 
size_function, const MemorySchedulerAlgorithm& algorithm) { - HloSchedule schedule(&module); + HloSchedule schedule(module); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(&module)); + TuplePointsToAnalysis::Run(module)); absl::flat_hash_map memory_by_computation; - for (const auto* computation : module.MakeComputationPostOrder()) { + for (auto* computation : module->MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, ScheduleComputationHelper( - *computation, *points_to_analysis, size_function, + computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = HeapSimulator::MinimumMemoryForComputation( @@ -583,11 +583,11 @@ StatusOr ScheduleModule( } StatusOr ScheduleComputation( - const HloComputation& computation, + HloComputation* computation, const LogicalBuffer::SizeFunction& size_function) { - CHECK(!computation.IsFusionComputation()); + CHECK(!computation->IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(computation.parent())); + TuplePointsToAnalysis::Run(computation->parent())); absl::flat_hash_map empty_map; return ScheduleComputationHelper(computation, *points_to_analysis, size_function, nullptr, empty_map); @@ -600,7 +600,24 @@ HloMemoryScheduler::HloMemoryScheduler( StatusOr HloMemoryScheduler::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(HloSchedule schedule, - ScheduleModule(*module, size_function_, algorithm_)); + ScheduleModule(module, size_function_, algorithm_)); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + return true; +} + +StatusOr HloTrivialScheduler::Run(HloModule* module) { + HloSchedule schedule(module); + for (HloComputation* computation : module->MakeComputationPostOrder()) { + if (!computation->IsFusionComputation()) { + HloInstructionSequence& computation_sequence = + 
schedule.GetOrCreateSequence(computation); + TF_RETURN_IF_ERROR(computation->Accept( + [&computation_sequence](HloInstruction* instruction) { + computation_sequence.push_back(instruction); + return Status::OK(); + })); + } + } TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); return true; } diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index a4c1d3db8170a1725043def576f913e09b352e5d..7227bfb27c74758d2b79e404afc9eb97a1ca894d 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -36,14 +36,14 @@ namespace xla { // that describes buffer aliasing, together with a target-specific size function // that maps a tensor's logical size to its padded size. typedef std::function( - const HloComputation&, const TuplePointsToAnalysis&, + HloComputation*, const TuplePointsToAnalysis&, const LogicalBuffer::SizeFunction&, const absl::flat_hash_map&)> MemorySchedulerAlgorithm; // List scheduler StatusOr ListMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -51,7 +51,7 @@ StatusOr ListMemoryScheduler( // DFS-order scheduler StatusOr DFSMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -59,7 +59,7 @@ StatusOr DFSMemoryScheduler( // Naive Post Order scheduler StatusOr PostOrderMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -69,7 +69,7 @@ StatusOr PostOrderMemoryScheduler( // and the DFS scheduler, and chooses whichever returns a lower 
min-memory, // not accounting for fragmentation. StatusOr DefaultMemoryScheduler( - const HloComputation& computation, + HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const absl::flat_hash_map& @@ -79,13 +79,13 @@ StatusOr DefaultMemoryScheduler( // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. StatusOr ScheduleModule( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + HloModule* module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm = {}); // Computes the schedule for a single computation. // Currently only used by the GPU backend. StatusOr ScheduleComputation( - const HloComputation& computation, + HloComputation* computation, const LogicalBuffer::SizeFunction& size_function); // A pass which schedules the HLO instructions in a module. The HloModule's @@ -108,6 +108,15 @@ class HloMemoryScheduler : public HloModulePass { MemorySchedulerAlgorithm algorithm_; }; +// A pass which produces a naive, but correct schedule. The schedule is produced +// using a DFS traversal of the graph with no attempt to minimize memory use. +class HloTrivialScheduler : public HloModulePass { + public: + absl::string_view name() const override { return "hlo-trivial-scheduler"; } + + StatusOr Run(HloModule* module) override; +}; + // A trivial pass which clears the schedule currently set on the // HloModule. After this pass runs HloModudle::has_schedule will return false. 
class HloDescheduler : public HloModulePass { diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 214119fba881c4411a262cd4227b5cc49cef0d14..bc0d7e2bc00eab014f2660c95a51b966642eaee9 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -65,7 +65,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { auto sub = builder.AddInstruction( HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); HloMemoryScheduler scheduler([](const BufferValue& buffer) { @@ -78,7 +78,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { TF_ASSERT_OK(module->schedule().Verify()); // Verify that all instructions are in the sequence. - const std::vector& sequence = + const std::vector& sequence = module->schedule().sequence(module->entry_computation()).instructions(); EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); @@ -124,9 +124,9 @@ ENTRY root { }; TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); + ScheduleModule(module.get(), size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. 
- const std::vector& sequence = + const std::vector& sequence = schedule.sequence(module->entry_computation()).instructions(); EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); @@ -172,15 +172,16 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, abs_abs2)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, - ScheduleModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), TUPLE_SIZE); - }, - ListMemoryScheduler)); + ScheduleModule( + module.get(), + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + TUPLE_SIZE); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), @@ -218,19 +219,19 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto* computation = module->AddEntryComputation(builder.Build()); auto fusion = computation->CreateFusionInstruction( {tuple, mul, add}, HloInstruction::FusionKind::kLoop); TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, - ScheduleModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), 2); - }, - ListMemoryScheduler)); + ScheduleModule( + module.get(), + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), 2); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. 
EXPECT_EQ(module->entry_computation()->instruction_count(), @@ -242,7 +243,7 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { } TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); // param != 0 @@ -252,7 +253,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { HloInstruction::CreateParameter(0, r1f32, "cond_param")); HloInstruction* zero_vector = cond_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{0, 0, 0, 0}}))); + LiteralUtil::CreateR1({0, 0, 0, 0}))); cond_builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); @@ -284,7 +285,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { }; TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); + ScheduleModule(module.get(), size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. 
auto entry_computation = module->entry_computation(); EXPECT_EQ(module->entry_computation()->instruction_count(), @@ -309,5 +310,40 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { .ValueOrDie()); } +TEST_F(HloSchedulingTest, TrivialScheduler) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + param.b = (s32[], s32[]) parameter(0) + gte.0 = s32[] get-tuple-element(param.b), index=0 + gte.1 = s32[] get-tuple-element(param.b), index=1 + add = s32[] add(gte.0, gte.1) + ROOT tuple = (s32[], s32[]) tuple(gte.0, add) +} + +cond { + param.c = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) +} + +ENTRY main { + init = (s32[], s32[]) parameter(0) + ROOT while = (s32[], s32[]) while(init), condition=cond, body=body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + EXPECT_FALSE(module->has_schedule()); + TF_ASSERT_OK(HloTrivialScheduler().Run(module.get()).status()); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + + // Verify that a clone of the module also has a schedule. + std::unique_ptr clone = module->Clone(); + ASSERT_TRUE(clone->has_schedule()); + TF_ASSERT_OK(clone->schedule().Verify()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 93e04eb3db47ba3dadfbd412733997b92c07da92..fe8371384c0fa3900a9022f101ff0b296439cf16 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -41,18 +41,6 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config) config_(config), unique_id_(next_unique_module_id_++) {} -StatusOr HloModule::LaunderConstInstructionFromModule( - const HloInstruction* hlo) { - if (hlo == nullptr) { - return nullptr; - } - - TF_RET_CHECK(hlo->GetModule() == this); - - // TODO(b/78350259): Eliminate const laundering. 
- return const_cast(hlo); -} - Status HloModule::set_schedule(HloSchedule schedule) { TF_RET_CHECK(schedule.module() == this); TF_RETURN_IF_ERROR(schedule.Verify()); @@ -73,6 +61,8 @@ HloComputation* HloModule::AddComputationInternal( config_.SetDefaultComputationLayout( entry_computation_->ComputeProgramShape()); } + input_output_alias_config_ = HloInputOutputAliasConfig( + entry_computation_->root_instruction()->shape()); } if (uniquify_identifiers) { @@ -244,14 +234,16 @@ HloModuleProto HloModule::ToProto() const { proto.set_entry_computation_id(entry_computation_->unique_id()); for (const HloComputation* computation : MakeComputationPostOrder()) { HloComputationProto computation_proto = computation->ToProto(); - if (computation->name() == entry_computation_->name()) { - *proto.mutable_program_shape() = computation_proto.program_shape(); - } proto.add_computations()->Swap(&computation_proto); } if (has_schedule()) { *proto.mutable_schedule() = schedule().ToProto().ValueOrDie(); } + *proto.mutable_host_program_shape() = + entry_computation_layout().ComputeProgramShape().ToProto(); + *proto.mutable_input_output_alias() = input_output_alias_config().ToProto(); + *proto.mutable_dynamic_parameter_binding() = + dynamic_parameter_binding().ToProto(); return proto; } @@ -263,9 +255,9 @@ StatusOr> HloModule::CreateFromProto( // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. 
- TF_RET_CHECK(proto.has_program_shape()) + TF_RET_CHECK(proto.has_host_program_shape()) << "No program shape found in the proto"; - const auto& expected_program_shape = proto.program_shape(); + ProgramShape expected_program_shape(proto.host_program_shape()); TF_RET_CHECK(expected_program_shape.parameters_size() == module_config.entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { @@ -328,8 +320,17 @@ StatusOr> HloModule::CreateFromProto( } TF_RET_CHECK(module->entry_computation_ != nullptr); + TF_ASSIGN_OR_RETURN( + module->input_output_alias_config_, + HloInputOutputAliasConfig::CreateFromProto( + entry->ComputeProgramShape().result(), proto.input_output_alias())); + // Because we didn't uniquify the names or the ids, double-check that the // instruction and computation names and ids are unique from the proto. + TF_ASSIGN_OR_RETURN(module->dynamic_parameter_binding_, + DynamicParameterBinding::CreateFromProto( + proto.dynamic_parameter_binding())); + absl::flat_hash_set computation_names; absl::flat_hash_set instruction_names; absl::flat_hash_set computation_ids; @@ -366,11 +367,11 @@ StatusOr> HloModule::CreateFromProto( /* static */ StatusOr HloModule::CreateModuleConfigFromProto( const HloModuleProto& module, const DebugOptions& debug_options) { - TF_RET_CHECK(module.has_program_shape()) + TF_RET_CHECK(module.has_host_program_shape()) << "No program shape found in the proto"; - const auto& program_shape = module.program_shape(); + ProgramShape program_shape(module.host_program_shape()); - HloModuleConfig module_config(program_shape); + HloModuleConfig module_config(ProgramShape{program_shape}); module_config.set_debug_options(debug_options); // The module config is constructed with default layouts regardless of what is @@ -558,12 +559,34 @@ std::vector HloModule::MakeNonfusionComputations() const { } std::unique_ptr HloModule::Clone(const string& suffix) const { + return Clone(config(), suffix); +} 
+ +std::unique_ptr HloModule::Clone(const HloModuleConfig& config, + const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = absl::make_unique(name_ + "-" + suffix, config_); + auto module = absl::make_unique( + absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config); HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); module->AddEntryComputation(std::move(cloned_computation)); + + if (has_schedule() && schedule().Verify().ok()) { + HloSchedule clone_schedule(module.get()); + for (HloComputation* computation : computations()) { + if (schedule().is_computation_scheduled(computation)) { + HloInstructionSequence& clone_sequence = + clone_schedule.GetOrCreateSequence( + context.GetComputation(computation)); + for (const HloInstruction* instruction : + schedule().sequence(computation).instructions()) { + clone_sequence.push_back(context.GetInstruction(instruction)); + } + } + } + TF_CHECK_OK(module->set_schedule(std::move(clone_schedule))); + } return module; } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 735804e827afd77e2b7f2a4a7d490ee6f5ee7b4f..7b9cbf9a53a2201b1312405bbd7ed2b88f65c9be 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -28,9 +28,11 @@ limitations under the License. 
#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" +#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" @@ -92,6 +94,8 @@ class HloModule { // Returns a deep copy of this module including all computations. std::unique_ptr Clone(const string& suffix = "clone") const; + std::unique_ptr Clone(const HloModuleConfig& config, + const string& suffix = "clone") const; // Performs a deep clone of the computation, by recursively cloning all // the called computations as well. If the clone context is specified, it @@ -99,14 +103,18 @@ class HloModule { HloComputation* DeepCloneComputation(HloComputation* computation, HloCloneContext* context = nullptr); - // Return a pointer to the entry computation of the module.. - const HloComputation* entry_computation() const { + // Return a pointer to the entry computation of the module. + HloComputation* entry_computation() const { CHECK_NE(nullptr, entry_computation_); return entry_computation_; } - HloComputation* entry_computation() { + + // Returns the root instruction shape of entry computation. + // + // Precondition: entry_computation_ is not nullptr. + const Shape& result_shape() const { CHECK_NE(nullptr, entry_computation_); - return entry_computation_; + return entry_computation()->root_instruction()->shape(); } // Creates the ComputationLayout which describes the current status of the HLO @@ -124,6 +132,12 @@ class HloModule { return config_.entry_computation_layout(); } + // Generates a hash value of an HLO module. 
Hash considers + // information on opcode, shape, operands, and typically a root instruction. + // This function returns the same hash value for equivalent HLO modules, + // with respect to HloInstruction::Identical() method. + uint64 Hash() const { return entry_computation()->Hash(); } + // Gets the computations in this module. // // Returns a view of HloComputation*s, so you can iterate over this in the @@ -212,33 +226,29 @@ class HloModule { return result; } - // Returns the number of unique intruction ids given out. All ids up to - // this point are guaranteed to be in the range [0..NumUniqueInstructionIds()) - int NumUniqueInstructionIds() const { return next_unique_id_; } + // input_output_alias_config indicates the list of aliased buffers that are + // expected from the module. + HloInputOutputAliasConfig& input_output_alias_config() { + return input_output_alias_config_; + } + const HloInputOutputAliasConfig& input_output_alias_config() const { + return input_output_alias_config_; + } + + // DynamicParameterBinding holds the list of bindings that indicates which + // parameter dimensions are dynamic and which parameters represent their + // runtime value. + DynamicParameterBinding& dynamic_parameter_binding() { + return dynamic_parameter_binding_; + } + const DynamicParameterBinding& dynamic_parameter_binding() const { + return dynamic_parameter_binding_; + } // Returns an id that is unique to this module across all modules created over // the lifetime of this process. int unique_id() const { return unique_id_; } - // Returns a non-const version of the passed-in const HloInstruction*. This is - // safe on the argument that if you have a non-const module, then you can - // access all instructions in the module as non-const. - // - // Returns an error if the passed-in instruction is not from this module, - // except that it is allowed to pass in a null pointer. - // - // TODO(b/78350259): Eliminate const laundering. 
The argument above is not - // reliable since at any time someone could add or discover a way for a - // non-const module to transitively contain a const HloInstruction. The - // reliable way to do this would be to create a const laundering map from a - // module, mapping each encountered HloInstruction to its non-const version - // and then look up each instruction in need of laundering in that map, but - // this is much more expensive and complicated. This returns a Status instead - // of doing a CHECK-failure in part to make it strongly apparent that this is - // something that can fail. - StatusOr LaunderConstInstructionFromModule( - const HloInstruction* hlo); - // Sets the schedule of the module to the given schedule. Status set_schedule(HloSchedule schedule); @@ -284,6 +294,13 @@ class HloModule { // sequential order of instructions for each non-fusion computation in the // module. absl::optional schedule_; + + // alias_config indicates the alias information of input/output buffers that + // are expected from the module. + HloInputOutputAliasConfig input_output_alias_config_; + + // Bindings for dynamic parameter mapping. + DynamicParameterBinding dynamic_parameter_binding_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc index f9b56ef4643f2ca88e56456ae6c990161adb5085..69d57c3f146f17ebbddef1ed972b92a587d67be7 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group.cc @@ -17,9 +17,8 @@ limitations under the License. 
namespace xla { -HloModuleGroup::HloModuleGroup(absl::string_view name, - std::unique_ptr module) - : name_(name) { +HloModuleGroup::HloModuleGroup(std::unique_ptr module) + : name_(module->name()) { push_back(std::move(module)); } @@ -31,6 +30,14 @@ HloModuleGroup::HloModuleGroup(absl::string_view name, } } +HloModuleGroup::HloModuleGroup( + absl::string_view name, std::vector>&& modules) + : name_(name) { + for (auto& module : modules) { + push_back(std::move(module)); + } +} + std::vector> HloModuleGroup::ConsumeModules() { std::vector> ret_modules = std::move(modules_); @@ -83,6 +90,12 @@ void HloModuleGroup::push_back(std::unique_ptr module) { module_ptrs_.push_back(modules_.back().get()); } +void HloModuleGroup::ReplaceModule(int index, + std::unique_ptr module) { + modules_.at(index) = std::move(module); + module_ptrs_.at(index) = modules_.at(index).get(); +} + std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group) { out << group.ToString(); return out; diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h index 7338be8b9c5ed47f0ba5829cc1d603b21f00b6e0..c4b10f3b22ab2aa0a346cae4e2d0d87496722368 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group.h +++ b/tensorflow/compiler/xla/service/hlo_module_group.h @@ -35,11 +35,13 @@ class HloModuleGroup { explicit HloModuleGroup(absl::string_view name) : name_(name) {} // Construct a module group containing a single module. - HloModuleGroup(absl::string_view name, std::unique_ptr module); + explicit HloModuleGroup(std::unique_ptr module); // Construct a module group containing any number of modules. HloModuleGroup(absl::string_view name, absl::Span> modules); + HloModuleGroup(absl::string_view name, + std::vector>&& modules); // Returns the modules contained in the group. 
const std::vector& modules() const { return module_ptrs_; } @@ -50,11 +52,16 @@ class HloModuleGroup { // Add a module to the back of vector of modules in the group. void push_back(std::unique_ptr module); + // Replaces the existing module at the given index with the given module. The + // existing module is discarded. + void ReplaceModule(int index, std::unique_ptr module); + // Moves all modules from the group into the returned vector. After this // method runs, the module group will be empty. std::vector> ConsumeModules(); string name() const { return name_; } + string ToString() const; // Serialize the module group to/from a proto. @@ -63,6 +70,12 @@ class HloModuleGroup { const HloModuleGroupProto& proto, absl::Span module_configs); + // Returns the number of modules in the module group. + int size() const { return modules_.size(); } + + // Returns true if there are no modules in the module group. + bool empty() const { return modules_.empty(); } + private: string name_; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 83352ef91b35b61ee2560b1488ee2ecdff6bea0a..b4aac4c8076cb69647d42c6243bc969d06d0709e 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -59,7 +59,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { } /* static */ StatusOr> -HloModuleGroupMetadata::Build(const std::vector& modules) { +HloModuleGroupMetadata::Build(absl::Span modules) { auto metadata = absl::make_unique(modules); TF_RETURN_IF_ERROR(metadata->Build()); return std::move(metadata); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 0311b7320721e98ab80ff0a28adb2e8fe53cee9b..928df0f5a7444ad877961a5de970c752e1d024da 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ 
b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -102,14 +102,14 @@ class HloModuleGroupMetadata { HloInstruction* recv_done = nullptr; }; - explicit HloModuleGroupMetadata(const std::vector& modules) - : modules_(modules) {} + explicit HloModuleGroupMetadata(absl::Span modules) + : modules_(modules.begin(), modules.end()) {} ~HloModuleGroupMetadata() = default; // Build and return the metadata for the given modules. static StatusOr> Build( - const std::vector& modules); + absl::Span modules); // Returns true if the instruction is one of the 4 channel instructions (Send, // Recv, SendDone, RecvDone). @@ -274,7 +274,7 @@ class HloModuleGroupMetadata { int64 max_channel_id_ = -1; // The modules that this metadata was built from. - const std::vector& modules_; + const std::vector modules_; absl::flat_hash_map> points_to_analyses_; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc index b7b12cb72b8df4610b964fb842da78e160d22d9f..5a9a86af5649bf240bb5de6d30fc80b0f6a58eba 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc @@ -46,7 +46,7 @@ ENTRY %entry (x: f32[], y: f32[]) -> f32[] { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseHloString(text)); - HloModuleGroup group(TestName(), std::move(module)); + HloModuleGroup group(std::move(module)); EXPECT_EQ(group.modules().size(), 1); EXPECT_THAT( diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 39f38b417ab0e8b54864176d8d1e0ad1a422eca6..620cb7e01ad1a060915f5b73474f6950ab18122a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -63,7 +63,7 @@ class HloModuleTest : public HloTestBase { TEST_F(HloModuleTest, OneComputationPostOrder) { // Create a module with a single computation. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(CreateConstantComputation()); EXPECT_THAT(module->MakeComputationPostOrder(), @@ -72,7 +72,7 @@ TEST_F(HloModuleTest, OneComputationPostOrder) { TEST_F(HloModuleTest, TwoComputationsPostOrder) { // Create a module with two unconnected computations. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation1 = module->AddEntryComputation(CreateConstantComputation()); auto computation2 = module->AddEmbeddedComputation(CreateConstantComputation()); @@ -88,7 +88,7 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) { TEST_F(HloModuleTest, CloneTest) { // Create and copy a module with a diamond call graph of computations. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation1 = module->AddEmbeddedComputation(CreateConstantComputation()); auto computation2 = @@ -111,7 +111,7 @@ TEST_F(HloModuleTest, CloneTest) { } TEST_F(HloModuleTest, CloneHasFusion) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); // Create the fused computation. HloComputation* fused_computation; @@ -154,7 +154,7 @@ TEST_F(HloModuleTest, CloneHasFusion) { TEST_F(HloModuleTest, DiamondComputationsPostOrder) { // Create a module with a diamond call graph of computations. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation1 = module->AddEmbeddedComputation(CreateConstantComputation()); auto computation2 = @@ -174,7 +174,7 @@ TEST_F(HloModuleTest, DiamondComputationsPostOrder) { TEST_F(HloModuleTest, LargeConstantToString) { // Create a module with a single computation. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder("Constant"); std::vector values(16, 42.0); builder.AddInstruction( @@ -194,8 +194,8 @@ TEST_F(HloModuleTest, LargeConstantToString) { } TEST_F(HloModuleTest, UniqueModuleId) { - auto module_a = CreateNewModule(); - auto module_b = CreateNewModule(); + auto module_a = CreateNewVerifiedModule(); + auto module_b = CreateNewVerifiedModule(); EXPECT_NE(module_a->unique_id(), module_b->unique_id()); } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index e6bfb8025d4bfeba1d334d1f946e33841a2da092..127cfd165a5d8229cac3035f56a66f1bcfa734f3 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -47,6 +47,8 @@ namespace xla { #define HLO_OPCODE_LIST(V) \ V(kAbs, "abs") \ V(kAdd, "add") \ + V(kAddDependency, "add-dependency") \ + V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kAllToAll, "all-to-all") \ V(kAtan2, "atan2") \ V(kBatchNormGrad, "batch-norm-grad") \ @@ -83,7 +85,7 @@ namespace xla { V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ - V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ + V(kGetDimensionSize, "get-dimension-size") \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ V(kImag, "imag") \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 23d41d91d6969ddf9062507e926ae39c1e1315d4..ca6a154809be46d6a0305c29e2b89219de408019 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -334,7 +334,7 @@ DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) // ordering based on dependencies. 
ExecutesBefore will return true iff there // exists a path in the HLO computation graph from 'a' to 'b'. for (auto* computation : module->MakeNonfusionComputations()) { - predecessors_.emplace(computation, computation->ComputeReachability()); + predecessors_.emplace(computation, HloReachabilityMap::Build(computation)); } } @@ -356,8 +356,7 @@ void SequentialHloOrdering::Initialize() { // Create a map from instruction to its order position. TF_DCHECK_OK(schedule_.Verify()); for (const auto& computation_sequence : schedule_.sequences()) { - const std::vector& order = - computation_sequence.second.instructions(); + const auto& order = computation_sequence.second.instructions(); for (int i = 0; i < order.size(); ++i) { InsertOrDie(&order_position_, order[i], i); } @@ -374,11 +373,10 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( return order_position_.at(a) < order_position_.at(b); } -const std::vector* -SequentialHloOrdering::SequentialOrder( +const HloInstructionSequence* SequentialHloOrdering::SequentialOrder( const HloComputation& computation) const { return schedule_.is_computation_scheduled(&computation) - ? &schedule_.sequence(&computation).instructions() + ? &schedule_.sequence(&computation) : nullptr; } diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index 66313492eb2dd10ac9a6000639ddb8991b367c0f..a07214c22c0989a438f12219e136a7e76ee0dcce 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" @@ -64,7 +65,7 @@ class HloOrdering { // Returns the sequential instruction order for the given computation, or // nullptr if the computation does not have a sequential ordering. - virtual const std::vector* SequentialOrder( + virtual const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const = 0; // Return the call graph of the module used to compute ordering. @@ -96,7 +97,7 @@ class PredecessorHloOrdering : public HloOrdering { // Returns nullptr indicating the computation does not have a sequential // ordering. - const std::vector* SequentialOrder( + const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const override { return nullptr; } @@ -185,7 +186,7 @@ class SequentialHloOrdering : public HloOrdering { ~SequentialHloOrdering() override = default; // Returns the sequential instruction order for the given computation. - const std::vector* SequentialOrder( + const HloInstructionSequence* SequentialOrder( const HloComputation& computation) const override; string ToString() const override; diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index b045adc9640ac0ca8cf4a127fea2fbfcbb1aaf3f..3ca77e60cd5275c22eb0e338cd5437fc44b49958 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -53,7 +53,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { // %c = Constant(42.0f) // // This results in a diamond-shaped callgraph. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder_c = HloComputation::Builder("C"); @@ -126,7 +126,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { // %constant = Constant(1.0) // return While(%constant, body, condition) // - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto body_builder = HloComputation::Builder("body"); @@ -176,7 +176,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) { // Entry parameter should always be defined before other instruction. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( @@ -209,7 +209,7 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { // %while = While(%constant, body, condition) // %add = Add(%constant, %while) // - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto body_builder = HloComputation::Builder("body"); @@ -407,7 +407,7 @@ TEST_F(HloOrderingTest, // %dead = Constant(123.0) // // %root should interfere with %dead. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto builder = HloComputation::Builder(TestName()); @@ -455,7 +455,7 @@ TEST_F(HloOrderingTest, // ROOT %call = call({%c}), subcomputation // // %root should interfere with %dead. 
- auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); auto subbuilder = HloComputation::Builder(TestName() + ".sub"); diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index dd62988bccf7a0b2daa0bd39fc642452c768fceb..9b5bb5d0bd6af104ef62eaa5d3e53cedbe0213d3 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -47,11 +47,11 @@ const double kF16max = 65504; // Creates and returns a schedule created using the order of the instructions in // the HloComputation::instructions() vectors in the module. -HloSchedule ScheduleFromInstructionOrder(const HloModule* module) { +HloSchedule ScheduleFromInstructionOrder(HloModule* module) { HloSchedule schedule(module); - for (const HloComputation* computation : module->computations()) { + for (HloComputation* computation : module->computations()) { if (!computation->IsFusionComputation()) { - for (const HloInstruction* instruction : computation->instructions()) { + for (HloInstruction* instruction : computation->instructions()) { schedule.GetOrCreateSequence(computation).push_back(instruction); } } @@ -108,7 +108,7 @@ class HloParser { bool ParseInstructionList(HloComputation** computation, const string& computation_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); - bool ParseInstruciontRhs(HloComputation::Builder* builder, const string& name, + bool ParseInstructionRhs(HloComputation::Builder* builder, const string& name, LocTy name_loc); bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(Literal* literal, const Shape& shape); @@ -174,6 +174,7 @@ class HloParser { kDistribution, kDomain, kPrecisionList, + kShapeList }; struct AttrConfig { @@ -240,6 +241,7 @@ class HloParser { bool ParseSliceRanges(SliceRanges* result); bool ParsePrecisionList(std::vector* result); + 
bool ParseShapeList(std::vector* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); @@ -416,6 +418,18 @@ std::pair* HloParser::FindInstruction( } return create_missing_instruction_(name, *shape); } + + if (instr != nullptr && shape.has_value() && + !ShapeUtil::Compatible(instr->first->shape(), shape.value())) { + Error( + lexer_.GetLoc(), + StrCat("The declared operand shape ", + ShapeUtil::HumanStringWithLayout(shape.value()), + " is not compatible with the shape of the operand instruction ", + ShapeUtil::HumanStringWithLayout(instr->first->shape()), ".")); + return nullptr; + } + return instr; } @@ -594,10 +608,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, *root_name = name; } - return ParseInstruciontRhs(builder, name, name_loc); + return ParseInstructionRhs(builder, name, name_loc); } -bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, +bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, const string& name, LocTy name_loc) { Shape shape; HloOpcode opcode; @@ -836,9 +850,16 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, } break; } + case HloOpcode::kAddDependency: { + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateAddDependency(operands[0], operands[1])); + break; + } case HloOpcode::kSort: { - auto loc = lexer_.GetLoc(); - optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; @@ -846,20 +867,10 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, dimensions->size() != 1) { return false; } - switch (operands.size()) { - case 1: - instruction = builder->AddInstruction(HloInstruction::CreateSort( - shape, dimensions->at(0), /*keys=*/operands[0])); - break; - case 2: - instruction = 
builder->AddInstruction(HloInstruction::CreateSort( - shape, dimensions->at(0), - /*keys=*/operands[0], /*values=*/operands[1])); - break; - default: - return Error(loc, StrCat("expects either 1 or 2 operands, but has ", - operands.size(), " operands")); - } + instruction = builder->AddInstruction(HloInstruction::CreateSort( + shape, dimensions->at(0), + /*keys=*/operands[0], + /*values=*/absl::Span(operands).subspan(1))); break; } case HloOpcode::kTuple: { @@ -1099,8 +1110,8 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, absl::Span(operands).subspan( 0, operands.size() / 2), /*init_values=*/ - absl::Span(operands).subspan( - operands.size() / 2, operands.size()), + absl::Span(operands).subspan(operands.size() / + 2), *dimensions_to_reduce, *reduce_computation)); break; } @@ -1341,6 +1352,7 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, optional window; optional dnums; optional feature_group_count; + optional> operand_layout_constraints; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque}; @@ -1349,12 +1361,52 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + attrs["operand_layout_constraints"] = { + /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateCustomCall(shape, operands, *custom_call_target, - opaque.has_value() ? 
*opaque : "")); + if (operand_layout_constraints.has_value()) { + if (!LayoutUtil::HasLayout(shape)) { + return Error(lexer_.GetLoc(), + "Layout must be set on layout-constrained custom call"); + } + if (operands.size() != operand_layout_constraints->size()) { + return Error(lexer_.GetLoc(), + StrCat("Expected ", operands.size(), + " operand layout constraints, ", + operand_layout_constraints->size(), " given")); + } + for (int64 i = 0; i < operands.size(); ++i) { + const Shape& operand_shape_with_layout = + (*operand_layout_constraints)[i]; + if (!LayoutUtil::HasLayout(operand_shape_with_layout)) { + return Error(lexer_.GetLoc(), + StrCat("Operand layout constraint shape ", + ShapeUtil::HumanStringWithLayout( + operand_shape_with_layout), + " for operand ", i, " does not have a layout")); + } + if (!ShapeUtil::Compatible(operand_shape_with_layout, + operands[i]->shape())) { + return Error( + lexer_.GetLoc(), + StrCat( + "Operand layout constraint shape ", + ShapeUtil::HumanStringWithLayout(operand_shape_with_layout), + " for operand ", i, + " is not compatible with operand shape ", + ShapeUtil::HumanStringWithLayout(operands[i]->shape()))); + } + } + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target, *operand_layout_constraints, + opaque.has_value() ? *opaque : "")); + } else { + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target, + opaque.has_value() ? 
*opaque : "")); + } if (window.has_value()) { instruction->set_window(*window); } @@ -1504,6 +1556,18 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); + case HloOpcode::kGetDimensionSize: + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateGetDimensionSize( + shape, operands[0], (*dimensions)[0])); + break; } instruction->SetAndSanitizeName(name); @@ -1763,6 +1827,10 @@ bool HloParser::SetValueInLiteral(tensorflow::int64 value, case U64: return SetValueInLiteralHelper(value, linear_index, literal); + case PRED: + // Bool type literals with rank >= 1 are printed in 0s and 1s. + return SetValueInLiteralHelper(static_cast(value), + linear_index, literal); default: LOG(FATAL) << "unknown integral primitive type " << PrimitiveType_Name(shape.element_type()); @@ -2017,14 +2085,13 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { } if (lexer_.GetKind() == TokKind::kw_true || lexer_.GetKind() == TokKind::kw_false) { - // TODO(congliu): bool type literals with rank >= 1 are actually - // printed in a compact form instead of "true" or "false". Fix that. 
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, linear_index++, literal)) { return false; } lexer_.Lex(); - } else if (primitive_util::IsIntegralType(shape.element_type())) { + } else if (primitive_util::IsIntegralType(shape.element_type()) || + shape.element_type() == PRED) { LocTy loc = lexer_.GetLoc(); tensorflow::int64 value; if (!ParseInt64(&value)) { @@ -2533,6 +2600,15 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kShapeList: { + std::vector result; + if (!ParseShapeList(&result)) { + return false; + } + static_cast>*>(attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { @@ -2653,7 +2729,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( // The str is expected to have 3 items, lhs, rhs, out, and it must look like // lhs_rhs->out, that is, the first separator is "_" and the second is "->". - std::vector split1 = absl::StrSplit(str, "_"); + std::vector split1 = absl::StrSplit(str, '_'); if (split1.size() != 2) { LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " << str; @@ -2825,6 +2901,23 @@ bool HloParser::ParsePrecisionList( parse_and_add_item); } +// shapelist ::= '{' shapes '}' +// precision_elements +// ::= /*empty*/ +// ::= shape (',' shape)* +bool HloParser::ParseShapeList(std::vector* result) { + auto parse_and_add_item = [&]() { + Shape shape; + if (!ParseShape(&shape)) { + return false; + } + result->push_back(std::move(shape)); + return true; + }; + return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item); +} + // int64list ::= start int64_elements end // int64_elements // ::= /*empty*/ @@ -2832,23 +2925,15 @@ bool HloParser::ParsePrecisionList( bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result) { - if (!ParseToken(start, StrCat("expects an int64 list starting with ", - TokKindToString(start)))) { - return false; - } - if (lexer_.GetKind() == end) { - 
// empty - } else { - do { - tensorflow::int64 i; - if (!ParseInt64(&i)) { - return false; - } - result->push_back(i); - } while (EatIfPresent(delim)); - } - return ParseToken( - end, StrCat("expects an int64 list to end with ", TokKindToString(end))); + auto parse_and_add_item = [&]() { + tensorflow::int64 i; + if (!ParseInt64(&i)) { + return false; + } + result->push_back(i); + return true; + }; + return ParseList(start, end, delim, parse_and_add_item); } bool HloParser::ParseList(const TokKind start, const TokKind end, @@ -2933,7 +3018,8 @@ bool HloParser::ParseShape(Shape* result) { } if (lexer_.GetKind() != TokKind::kShape) { - return TokenError("expects shape"); + return TokenError(absl::StrCat("expected shape, saw ", + TokKindToString(lexer_.GetKind()))); } *result = lexer_.GetShapeVal(); lexer_.Lex(); @@ -3324,7 +3410,7 @@ bool HloParser::ParseSingleInstruction(HloModule* module) { // e.g. // // f32[10] fusion(...), calls={...} - if (!ParseInstruciontRhs(&builder, module->name(), lexer_.GetLoc())) { + if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) { return false; } } else { diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 81eeb9f13bf7f06123c0b35e9f3352c197866a7a..d830fa61438239005875f785f85cf2486123ebc9 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -44,7 +44,9 @@ Status ParseHloString(absl::string_view str, HloModule* module); // creates a HloModule with default config. StatusOr> ParseHloString(absl::string_view str); -// Parses the result of HloSharding::ToString(), e.g. "{replicated}". +// ParseHloString sharding from str. str is supposed to contain the body of the +// sharding, i.e. just the rhs of the "sharding={...}" attribute string, +// e.g., "{replicated}". StatusOr ParseSharding(absl::string_view str); // Parses the result of window_util::ToString(const Window&). 
@@ -55,10 +57,6 @@ StatusOr ParseWindow(absl::string_view str); StatusOr ParseConvolutionDimensionNumbers( absl::string_view str); -// ParseHloString sharding from str. str is supposed to contain the body of the -// sharding, i.e. just the rhs of the "sharding={...}" attribute string. -StatusOr ParseSharding(absl::string_view str); - // Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". StatusOr ParsePaddingConfig(absl::string_view str); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 255123d331c91b1c862980b9248afe9a03d564c8..ab71f011ac9d77d00ddfb41aca7a224d26d416b7 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -21,7 +21,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -29,7 +30,7 @@ limitations under the License. 
namespace xla { namespace { -namespace op = ::xla::testing::opcode_matchers; +namespace m = ::xla::match; using absl::string_view; struct TestData { @@ -75,6 +76,18 @@ ENTRY %constant_pred () -> pred[] { )" }, +// pred array constant +{ +"ConstantPredArray", +R"(HloModule module + +ENTRY %constant_pred_array () -> pred[2,3] { + ROOT %constant = pred[2,3]{1,0} constant(pred[2,3] { { 0, 1, 0 }, { 1, 0, 1 } }) +} + +)" +}, + // s32 constant { "ConstantS32", @@ -183,7 +196,7 @@ ENTRY %add_constants () -> f32[] { R"(HloModule TupleConstant_module ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { - ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) + ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { {1}, {2} }, {2, 42} )) } )" @@ -575,7 +588,7 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ R"(HloModule BasicTraining_module ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { - %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ {1, 2} }, { /*i1=1*/ {3, 4} } }, { /*i0=1*/ { /*i1=0*/ {5, 6} }, { /*i1=1*/ {7, 8} } } }) + %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } }) %constant.1 = f32[2]{0} constant({2, 3}) %constant.2 = f32[2]{0} constant({1, 2}) ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3 @@ -802,6 +815,43 @@ ENTRY %ConstantUnsignedNoOverflow () -> u64[] { ROOT %constant = u64[] constant(9223372036854775807) } +)" +}, +// CustomCallWithLayoutConstraints +{ +"CustomCallWithLayoutConstraints", +R"(HloModule CustomCallWithLayoutConstraints + +ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> 
f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}} +} + +)" +}, +// CustomCallWithLayoutConstraintsNoOperands +{ +"CustomCallWithLayoutConstraintsNoOperands", +R"(HloModule CustomCallWithLayoutConstraintsNoOperands + +ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] { + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={} +} + +)" +}, +// CustomCallWithLayoutConstraintsTupleShapes +{ +"CustomCallWithLayoutConstraintsTupleShapes", +R"(HloModule CustomCallWithLayoutConstraintsTupleShapes + +ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) { + %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}} +} + )" }, }); @@ -966,6 +1016,21 @@ ENTRY Sort { ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0} } +)" +}, +// Sort (Key, Value, Value, Value) +{ +"SortManyValues", +R"(HloModule sort + +ENTRY Sort { + keys = f32[1024,16]{0,1} parameter(0) + values.0 = s32[1024,16]{0,1} parameter(1) + values.1 = u32[1024,16]{0,1} parameter(2) + values.2 = f32[1024,16]{0,1} parameter(3) + ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0} +} + )" }, // Conditional @@ -1086,6 +1151,25 @@ ENTRY CrossReplicaSumWithSubgroups { ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_groups={{0,1},{2,3}}, barrier="abc", 
to_apply=add } +)" +}, +// cross-replica-sum with all-reduce-id +{ +"CrossReplicaSumAllReduce", +R"(HloModule CRS + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CRS { + input = f32[8]{0} parameter(0) + crs.1 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add + ROOT crs.0 = f32[8]{0} cross-replica-sum(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add +} + )" }, // all-to-all @@ -1158,7 +1242,38 @@ ENTRY Sort { } )" + }, +// AfterAll with multiple operands +{ +"AfterAllWithMultipleOperands", +R"(HloModule AfterAllWithMultipleOperands + +ENTRY AfterAllWithMultipleOperands { + p0 = f32[] parameter(0) + token0 = token[] after-all() + token1 = token[] after-all() + ROOT after-all = token[] after-all(p0, token0, token1) +} + +)" +}, +// AddDependency +// A dependency chain is created from 'neg' to 'exp' using tokens. +{ +"AddDependency", +R"(HloModule AddDependency + +ENTRY AddDependency { + p = f32[] parameter(0) + neg = f32[] negate(p) + token = token[] after-all(neg) + p_after_token = f32[] add-dependency(p, token) + exp = f32[] exponential(p_after_token) + ROOT sum = f32[] add(neg, exp) } + +)" +}, }); // clang-format on } @@ -1779,7 +1894,8 @@ ENTRY ReduceR3ToR2 { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original)); ASSERT_NE(module->entry_computation(), nullptr); - EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Reduce())); } TEST_F(HloParserTest, ParseSharding) { @@ -1839,7 +1955,7 @@ TEST(HloParserSingleOpTest, SingleOp) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1)))); } TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) { 
@@ -1867,7 +1983,7 @@ TEST(HloParserSingleOpTest, SingleOpNoNames) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1)))); } TEST(HloParserSingleOpTest, CanonicalOp) { @@ -1876,7 +1992,7 @@ TEST(HloParserSingleOpTest, CanonicalOp) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1)))); EXPECT_EQ( computation->root_instruction()->ToString(HloPrintOptions::Canonical()), text); @@ -1930,7 +2046,11 @@ TEST(HloParserSingleOpTest, SingleOpWithNested) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Fusion(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Op() + .WithOpcode(HloOpcode::kFusion) + .WithNumOperands(2) + .WithOperand(0, m::Parameter(0)) + .WithOperand(1, m::Parameter(1)))); } TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) { @@ -1974,7 +2094,7 @@ TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), - op::Convolution(op::Parameter(0), op::Parameter(1))); + GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1)))); auto* convolution = Cast(computation->root_instruction()); EXPECT_EQ(convolution->feature_group_count(), 1); @@ -2038,8 +2158,10 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { module->schedule().is_computation_scheduled(module->entry_computation())); EXPECT_THAT( module->schedule().sequence(module->entry_computation()).instructions(), - 
::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(), - op::Multiply(), op::Parameter(), op::Add())); + ::testing::ElementsAre( + GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()), + GmockMatch(m::Parameter()), GmockMatch(m::Multiply()), + GmockMatch(m::Parameter()), GmockMatch(m::Add()))); } TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) { @@ -2065,9 +2187,69 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { module->schedule().is_computation_scheduled(module->entry_computation())); EXPECT_THAT( module->schedule().sequence(module->entry_computation()).instructions(), - ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), - op::Broadcast(), op::Multiply(), op::Add())); + ::testing::ElementsAre( + GmockMatch(m::Parameter()), GmockMatch(m::Parameter()), + GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()), + GmockMatch(m::Multiply()), GmockMatch(m::Add()))); +} + +TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) { + const string original = R"(HloModule CustomCallWrongNumberofOperandConstraints + +ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}} +} + +)"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Expected 2 operand layout constraints, 1 given"); +} + +TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) { + const string original = R"(HloModule CustomCallIncompatibleOperandConstraints + +ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, 
f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}} +} + +)"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "operand 1 is not compatible with operand shape"); } +TEST_F(HloParserTest, AllowShapeWhitespace) { + const string text = R"( +HloModule module + +ENTRY entry { + ROOT root = f32[ 1, 2,3, 4, 5]{0, 1, 2,3, 4 } parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); +} + +TEST_F(HloParserTest, ShapeMismatchInOperand) { + const string text = R"( +HloModule foobar + +ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] { + %p = f32[2,2] parameter(0) + %constant.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}}) + ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1) +} +)"; + + ExpectHasSubstr(ParseHloString(text).status().error_message(), + "The declared operand shape f32[2,5]{1,0} is not compatible" + " with the shape of the operand instruction f32[2,2]{1,0}."); +} + +// custom call incompatible shape. 
+ } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 5e004ce78ac1fd6da18ab2a54d23ef27e9586cf6..51177f24f5ee702be96fc8b4530ed38a5798109f 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -113,9 +113,10 @@ void HloPassPipeline::MaybeDumpHlo(const HloModule& module, } const string message = - StrCat("after ", after_pass_name, ", before ", before_pass_name); + absl::StrCat("after ", after_pass_name, ", before ", before_pass_name); hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(3) << "HLO " << message << ":"; + VLOG(3) << module.entry_computation_layout().ToString(); XLA_VLOG_LINES(3, module.ToString()); } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index 09e7033ea4ed88849d2f3665d04f74f3f388b3f5..60d72b9d296d71f7bc2f1637bcbec1675513e5df 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -105,8 +105,6 @@ class HloPassPipeline : public HloPassInterface { std::vector> passes_; std::vector> invariant_checkers_; bool run_called_ = false; - - TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc index ee8cb12b231718e09f6ac0d05d7a6887f4c4d746..20384b9da6be4bab447b474f0e2240bcb277a620 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc @@ -19,14 +19,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class HloPassPipelineTest : public HloVerifiedTestBase { +class HloPassPipelineTest : public HloTestBase { protected: StatusOr ParseModuleGroup( absl::Span hlo_strings) { diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.cc b/tensorflow/compiler/xla/service/hlo_profile_printer.cc index dcc22793015147aaf3229875078b2989e4ef7559..5eb707a957e49d86cdb2f72b72ce750bf29b8fd2 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer.cc +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_profile_printer.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" namespace xla { @@ -25,6 +26,11 @@ string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data, string result; + for (const auto& item : hlo_profile_printer_data.extra_metrics()) { + absl::StrAppend(&result, "Extra metric ", item.first, ": ", + counters[item.second], "\n"); + } + for (const HloComputationInfo& computation_info : hlo_profile_printer_data.computation_infos()) { const auto& instruction_infos = computation_info.instruction_infos(); @@ -41,8 +47,9 @@ string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data, // Once we start using this in AOT for real, we will probably need a more // minimal version of HumanReadableProfileBuilder. 
HumanReadableProfileBuilder builder( - computation_info.name(), counters[computation_info.profile_index()], - clock_rate_ghz); + computation_info.name(), + hlo_profile_printer_data.entry_computation() == computation_info.name(), + counters[computation_info.profile_index()], clock_rate_ghz); for (const auto& instruction_info : instruction_infos) { builder.AddOp( diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto b/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto index 9f22b733fe1d676b177039a9d7a3064b8638d7bc..ee66c86ffcb4fb74a24033e05f588a2f4d27dfe4 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto +++ b/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto @@ -57,4 +57,10 @@ message HloProfilePrinterData { // The size of the profile counters array we will pretty-print. int64 profile_counters_size = 2; + + // Maps extra metric name to the index into the profile counters array. + map extra_metrics = 3; + + // Name of the entry computation. + string entry_computation = 4; } diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index b9c0b0c4ee1957fce48641230cef6391bcc9180e..981d06ce101644ecce587c4bd2f7a12c8edf6548 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include @@ -36,35 +37,47 @@ HloProto MakeHloProto(const HloModule& module) { return proto; } -StatusOr> EntryComputationParameterShapes( +StatusOr> CreateModuleFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config) { + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(proto, module_config)); + TF_RETURN_IF_ERROR( + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status()); + return std::move(module); +} + +StatusOr> EntryComputationParameterShapes( const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { return NotFound("HloProto missing HloModuleProto."); } - if (!hlo_proto.hlo_module().has_program_shape()) { + if (!hlo_proto.hlo_module().has_host_program_shape()) { return NotFound("HloProto missing program shape."); } - std::vector parameter_shapes; - const auto& program_shape = hlo_proto.hlo_module().program_shape(); - for (const Shape& shape : program_shape.parameters()) { + std::vector parameter_shapes; + const auto& program_shape = hlo_proto.hlo_module().host_program_shape(); + for (const ShapeProto& shape : program_shape.parameters()) { parameter_shapes.push_back(&shape); } return parameter_shapes; } -StatusOr EntryComputationOutputShape(const HloProto& hlo_proto) { +StatusOr EntryComputationOutputShape( + const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { return NotFound("HloProto missing HloModuleProto."); } - if (!hlo_proto.hlo_module().has_program_shape()) { + if (!hlo_proto.hlo_module().has_host_program_shape()) { return NotFound("HloProto missing program shape."); } - if (!hlo_proto.hlo_module().program_shape().has_result()) { + if (!hlo_proto.hlo_module().host_program_shape().has_result()) { return NotFound("HloProto missing result in its 
program shape"); } - return &hlo_proto.hlo_module().program_shape().result(); + return &hlo_proto.hlo_module().host_program_shape().result(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h index 3d9c375cd5d26f92cf8316f78789daf4fc08c927..31ea2aaffd9cdb76d21edbd0d4a03aa5f865f4f0 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.h +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -35,14 +35,21 @@ HloProto MakeHloProto(const HloModule& module, // will not be included in the output. HloProto MakeHloProto(const HloModule& module); +// Create an HLO state from serialized representation. In addition to +// creating the proto with HloModule::CreateFromProto(...) it also +// uses HloVerifier to ensure basic invariants are held. +StatusOr> CreateModuleFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config); + // Returns the shapes of the parameters of the entry computation. Shape pointers // refer to shapes inside of the given HloProto. -StatusOr> EntryComputationParameterShapes( +StatusOr> EntryComputationParameterShapes( const HloProto& hlo_proto); // Returns the shape of the output of the entry computation. The shape pointer // refers to the output shape inside of the given HloProto. 
-StatusOr EntryComputationOutputShape(const HloProto& hlo_proto); +StatusOr EntryComputationOutputShape( + const HloProto& hlo_proto); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index 2a07b6fcbc243d955e136ccdf097c8155a115845..f968a4a94453f678f5c17e0b8d1df4aea70c93ea 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -24,7 +24,7 @@ namespace hlo_query { bool IsConstantR0F32(HloInstruction* instruction, float* out) { if (instruction->opcode() == HloOpcode::kConstant && - ShapeUtil::IsScalarF32(instruction->shape())) { + ShapeUtil::IsScalarWithElementType(instruction->shape(), F32)) { *out = instruction->literal().Get({}); return true; } @@ -104,5 +104,20 @@ bool IsScalarConstant(const HloInstruction* instruction) { return instruction->IsConstant() && ShapeUtil::IsScalar(instruction->shape()); } +bool ContainsInstrWithOpcode(const HloComputation* comp, + const absl::flat_hash_set& opcodes) { + for (const auto* instr : comp->instructions()) { + if (opcodes.count(instr->opcode())) { + return true; + } + for (const HloComputation* subcomp : instr->called_computations()) { + if (ContainsInstrWithOpcode(subcomp, opcodes)) { + return true; + } + } + } + return false; +} + } // namespace hlo_query } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h index c0826a6aee1f693484207a86ec258c6604d92318..215051f8834fc94eb9e32b508f34b13626ac9349 100644 --- a/tensorflow/compiler/xla/service/hlo_query.h +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_ +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -41,6 +43,12 @@ bool AllOperandsAreConstants(const HloInstruction& instruction); // Returns whether the instruction is a scalar constant. bool IsScalarConstant(const HloInstruction* instruction); +// Determines whether the given computation contains an instruction with one of +// the given opcodes. Checks both comp's instructions and the instructions of +// any computations nested within it. +bool ContainsInstrWithOpcode(const HloComputation* comp, + const absl::flat_hash_set& opcodes); + // Returns an operand of an instruction with the given opcode. If there are // multiple matching operands, then the first matching operand is returned. If // there are no matching operands then nullptr is returned. diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 961930f0a888e90f86e4354fa1373a303af8ec2f..4aa8067752481ffab29e1a573ffa49d4aa046f1f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "tensorflow/compiler/xla/service/hlo_reachability.h" namespace xla { @@ -22,7 +24,7 @@ HloReachabilityMap::HloReachabilityMap( : size_(instructions.size()) { bit_vectors_.reserve(size_); for (const HloInstruction* hlo : instructions) { - indices_[hlo] = bit_vectors_.size(); + indices_[GetKey(hlo)] = bit_vectors_.size(); bit_vectors_.emplace_back(size_); } CHECK_EQ(size_, indices_.size()); // instructions should be unique @@ -71,4 +73,70 @@ bool HloReachabilityMap::IsConnected(const HloInstruction* a, return IsReachable(a, b) || IsReachable(b, a); } +std::unique_ptr HloReachabilityMap::Build( + const HloComputation* computation) { + const auto& all = computation->MakeInstructionPostOrder(); + auto result = absl::make_unique(all); + auto channel_dependency_map = computation->ComputeChannelDependencies(); + + std::vector inputs; + for (const HloInstruction* hlo : all) { + inputs.assign(hlo->operands().begin(), hlo->operands().end()); + inputs.insert(inputs.end(), hlo->control_predecessors().begin(), + hlo->control_predecessors().end()); + + switch (hlo->opcode()) { + case HloOpcode::kRecvDone: { + auto it = channel_dependency_map.find(hlo->channel_id()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = hlo->all_reduce_id(); + if (all_reduce_id) { + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + } + break; + } + default: + break; + } + + result->FastSetReachabilityToUnion(inputs, hlo); + } + return result; +} + +void HloReachabilityMap::UpdateReachabilityThroughInstruction( + const HloInstruction* instruction) { + std::queue worklist; + worklist.push(instruction); + + std::vector inputs; + + while (!worklist.empty()) { + 
const HloInstruction* item = worklist.front(); + worklist.pop(); + + inputs.assign(item->operands().begin(), item->operands().end()); + inputs.insert(inputs.end(), item->control_predecessors().begin(), + item->control_predecessors().end()); + + if (SetReachabilityToUnion(inputs, item)) { + // Add immediate successors to worklist. + for (const HloInstruction* user : item->users()) { + worklist.push(user); + } + for (const HloInstruction* succ : item->control_successors()) { + worklist.push(succ); + } + } + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 5a5f01f8fd647c74217c80ce4a7633b8957e335f..7823b06a41b3052f6f50f7ffa358de5b23ba679f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -16,27 +16,30 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ +#include #include #include +#include "absl/base/casts.h" #include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" namespace xla { -class HloInstruction; - // A class for representing reachability between HloInstructions. // -// !!! THIS CLASS DOES NOT COMPUTE REACHABILITY !!! It has an adjacency matrix -// and it is up to the user of the class to set the adjacency matrix such that -// it represents reachability, i.e. such that it is transitive. 
That the graph -// be transitive is thus not an invariant of this class, but it is required for -// the name of the class and its methods to make sense. +// It has an adjacency matrix and it is up to the user of the class to set the +// adjacency matrix such that it represents reachability, i.e. such that it is +// transitive. That the graph be transitive is thus not an invariant of this +// class, but it is required for the name of the class and its methods to make +// sense. class HloReachabilityMap { public: // Sets up a graph with no edges and where the nodes correspond to the given @@ -44,6 +47,15 @@ class HloReachabilityMap { explicit HloReachabilityMap( absl::Span instructions); + // Computes and returns the reachability between HLO instructions in the + // computation. The returned HloReachabilityMap is constructed such that + // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a + // directed path (from producer to consumer) from 'a' to 'b'. Both data + // dependencies (operands) and control dependencies are considered for + // reachability. Trivially an instruction is reachable from itself. + static std::unique_ptr Build( + const HloComputation* computation); + // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true @@ -70,6 +82,10 @@ class HloReachabilityMap { // adjacency matrix. void SetReachable(const HloInstruction* a, const HloInstruction* b); + // Updates the given reachability map after the immediate predecessor set + // (operands and control predecessors) of 'instruction' has changed. 
+ void UpdateReachabilityThroughInstruction(const HloInstruction* instruction); + // Returns true if "b" is reachable from "a" // // Note that this function only correctly answers queries about reachability @@ -82,6 +98,11 @@ class HloReachabilityMap { // if the set of edges that have been provided to this class are transitive. bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; + // Checks if an instruction is in the Reachability map. + bool IsPresent(const HloInstruction* a) const { + return indices_.contains(GetKey(a)); + } + private: // A bit-vector implementation specialized for this use case which provides a // fast bitwise OR operation not available in tensorflow::gtl::BitMap. @@ -143,18 +164,24 @@ class HloReachabilityMap { absl::Span inputs, const HloInstruction* instruction, BitVector* bit_vector); + uint64 GetKey(const HloInstruction* instruction) const { + uint64 unique_id = absl::bit_cast(instruction->unique_id()); + uint64 module_id = + absl::bit_cast(instruction->parent()->parent()->unique_id()); + return (module_id << 32) | unique_id; + } // Return the index of the given instruction. The value is used to index into // the vector of BitVectors and the BitVectors themselves. int GetIndex(const HloInstruction* instruction) const { - return FindOrDie(indices_, instruction); + return FindOrDie(indices_, GetKey(instruction)); } // The number of instructions in the reachability map. const size_t size_; - // Dense assignment from HloInstruction* to number. These numbers index - // into the bit_vectors_ vector and into the bits within a BitVector. - absl::flat_hash_map indices_; + // Dense assignment from HloInstruction::unique_id to number. These numbers + // index into the bit_vectors_ vector and into the bits within a BitVector. + absl::flat_hash_map indices_; // Bitvectors holding the reachability to each instruction. The bit vector for // instruction X includes ones for each instruction which X is reachable from. 
diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index d9848cee0bfa904a90aea4626c3ee62c2cbb45b6..595176709806d54fc7c7c5ea301654717096b2d6 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -20,13 +20,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { namespace { -class HloReachabilityTest : public HloVerifiedTestBase {}; +class HloReachabilityTest : public HloTestBase {}; TEST_F(HloReachabilityTest, Reachability) { // Construct and test a reachability graph of the following form: @@ -48,7 +48,8 @@ TEST_F(HloReachabilityTest, Reachability) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); auto e = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); - builder.Build(); + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); HloReachabilityMap reachability({a, b, c, d, e}); reachability.SetReachable(a, a); @@ -81,6 +82,130 @@ TEST_F(HloReachabilityTest, Reachability) { EXPECT_FALSE(reachability.SetReachabilityToUnion({b, c}, d)); } +TEST_F(HloReachabilityTest, NonTrivialReachability) { + // Test reachability of a non-trivial computation: + // + // const1 const2 + // | | + // | +-------+ + // | | | + // add .. negate + // | . | + // | .... exp + // | | + // +---+ +-+---+ + // | | | + // multiply copy + // + // There is a control dependency from 'add' to 'exp'. 
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kAdd, constant1, constant2)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kNegate, constant2)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, negate)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, add, exp)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32, HloOpcode::kCopy, exp)); + + auto module = CreateNewVerifiedModule(); + auto computation = + module->AddEntryComputation(builder.Build(/*root_instruction=*/mul)); + + TF_CHECK_OK(add->AddControlDependencyTo(exp)); + auto reachability = HloReachabilityMap::Build(computation); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + EXPECT_TRUE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_TRUE(reachability->IsReachable(constant1, copy)); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_TRUE(reachability->IsReachable(constant2, negate)); + EXPECT_TRUE(reachability->IsReachable(constant2, exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_TRUE(reachability->IsReachable(constant2, copy)); + + 
EXPECT_FALSE(reachability->IsReachable(exp, constant1)); + EXPECT_FALSE(reachability->IsReachable(exp, constant2)); + EXPECT_FALSE(reachability->IsReachable(exp, add)); + EXPECT_FALSE(reachability->IsReachable(exp, negate)); + EXPECT_TRUE(reachability->IsReachable(exp, exp)); + EXPECT_TRUE(reachability->IsReachable(exp, mul)); + EXPECT_TRUE(reachability->IsReachable(exp, copy)); + + EXPECT_FALSE(reachability->IsReachable(mul, constant1)); + EXPECT_FALSE(reachability->IsReachable(mul, constant2)); + EXPECT_FALSE(reachability->IsReachable(mul, add)); + EXPECT_FALSE(reachability->IsReachable(mul, negate)); + EXPECT_FALSE(reachability->IsReachable(mul, exp)); + EXPECT_TRUE(reachability->IsReachable(mul, mul)); + EXPECT_FALSE(reachability->IsReachable(mul, copy)); + + EXPECT_TRUE(reachability->IsConnected(constant1, copy)); + EXPECT_TRUE(reachability->IsConnected(copy, constant1)); + EXPECT_FALSE(reachability->IsConnected(negate, add)); + EXPECT_FALSE(reachability->IsConnected(add, negate)); + + // Remove the control dependency then update and verify the reachability map + ASSERT_IS_OK(add->RemoveControlDependencyTo(exp)); + reachability->UpdateReachabilityThroughInstruction(exp); + + EXPECT_TRUE(reachability->IsReachable(constant1, constant1)); + EXPECT_FALSE(reachability->IsReachable(constant1, constant2)); + EXPECT_TRUE(reachability->IsReachable(constant1, add)); + EXPECT_FALSE(reachability->IsReachable(constant1, negate)); + EXPECT_FALSE(reachability->IsReachable(constant1, exp)); + EXPECT_TRUE(reachability->IsReachable(constant1, mul)); + EXPECT_FALSE(reachability->IsReachable(constant1, copy)); + + // Change a use within the graph then update and verify the reachability map + ASSERT_IS_OK(constant2->ReplaceUseWith(negate, constant1)); + reachability->UpdateReachabilityThroughInstruction(negate); + + EXPECT_FALSE(reachability->IsReachable(constant2, constant1)); + EXPECT_TRUE(reachability->IsReachable(constant2, constant2)); + 
EXPECT_TRUE(reachability->IsReachable(constant2, add)); + EXPECT_FALSE(reachability->IsReachable(constant2, negate)); + EXPECT_FALSE(reachability->IsReachable(constant2, exp)); + EXPECT_TRUE(reachability->IsReachable(constant2, mul)); + EXPECT_FALSE(reachability->IsReachable(constant2, copy)); +} + +TEST_F(HloReachabilityTest, ChannelReachability) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); + HloComputation::Builder builder("ChannelReachability"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); + auto send = + builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); + auto recv = + builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build(recv_done)); + auto reachability = HloReachabilityMap::Build(computation); + EXPECT_TRUE(reachability->IsReachable(param, recv_done)); + EXPECT_FALSE(reachability->IsReachable(send, recv)); + EXPECT_FALSE(reachability->IsReachable(send_done, recv)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 5ac43808ee2945eaa5003baad24d5d331419db83..48add75523f02005c70bc6baf69a6b7d5aa4f7ef 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -130,10 +130,10 @@ using ItemList = absl::InlinedVector; // before arbitrary elements. 
class InstructionList { public: - explicit InstructionList(const std::vector& order) { + explicit InstructionList(const HloInstructionSequence& order) { int64 position = 0; Item* last = nullptr; - for (const HloInstruction* inst : order) { + for (HloInstruction* inst : order.instructions()) { // Add a new item to the linked list. Item* item = new Item; item->next = nullptr; @@ -151,7 +151,7 @@ class InstructionList { // to be monotonically increasing through the list, and so is still useful // for quickly(-ish) determining the order of arbitrary instructions in // the list. - item->instruction = const_cast(inst); + item->instruction = inst; item->position = position; position++; @@ -927,7 +927,7 @@ Item* PickRematerializationCandidate( StatusOr HloRematerialization::ComputePeakMemory( const HloComputation* computation, - const std::vector& order) const { + const HloInstructionSequence& order) const { InstructionList instruction_list(order); MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_, instruction_list); @@ -971,8 +971,7 @@ StatusOr HloRematerialization::RematerializeComputation( << HumanReadableNumBytes(computation_peak_memory_.at(computation)); CHECK(!ContainsKey(rematerialized_computations_, computation)); - InstructionList instruction_list( - schedule->sequence(computation).instructions()); + InstructionList instruction_list(schedule->sequence(computation)); MemoryUsageTracker memory_tracker(computation, size_function_, *points_to_analysis_, instruction_list); bool changed = false; @@ -1184,7 +1183,7 @@ StatusOr HloRematerialization::RematerializeComputation( sequence.clear(); for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { - const HloInstruction* instruction = item->instruction; + HloInstruction* instruction = item->instruction; sequence.push_back(instruction); } rematerialized_computations_.insert(computation); @@ -1215,7 +1214,7 @@ StatusOr HloRematerialization::Run(HloModule* 
module) { // by the caller. int64 module_output_size = 0; ShapeUtil::ForEachSubshape( - module->entry_computation()->root_instruction()->shape(), + module->result_shape(), [&module_output_size, this](const Shape& subshape, const ShapeIndex& /*index*/) { module_output_size += size_function_(subshape); @@ -1235,10 +1234,8 @@ StatusOr HloRematerialization::Run(HloModule* module) { if (node.context() == CallContext::kSequential) { TF_ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], - ComputePeakMemory(node.computation(), - module->schedule() - .sequence(node.computation()) - .instructions())); + ComputePeakMemory(node.computation(), module->schedule().sequence( + node.computation()))); } return Status::OK(); }, diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 70d83c04f07ca7fd0139f586869e8fe688f958f4..a07d348041b72bba45c6fd1f726f2a0065d01e53 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -87,9 +87,8 @@ class HloRematerialization : public HloModulePass { // peak memory is the maximum total size of all live HLO instruction values at // any program point. 'order' is the order in which the HLO instructions will // be emitted which is used to determine lifespans of HLO values. - StatusOr ComputePeakMemory( - const HloComputation* computation, - const std::vector& order) const; + StatusOr ComputePeakMemory(const HloComputation* computation, + const HloInstructionSequence& order) const; // Returns the peak memory usage of the called computations for the given // instruction. Zero is returned if the instruction calls no computations. 
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index f7e82fb1f88e856305f6f481a451d4cd64ba4acf..22c3c40a93a1ddcd36659483fcc79fede32dd2c3 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers; using ::testing::_; -class HloRematerializationTest : public HloVerifiedTestBase { +class HloRematerializationTest : public HloTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: @@ -162,7 +162,7 @@ class HloRematerializationTest : public HloVerifiedTestBase { // Test rematerialization of a single computation produced by // MakeRematerializableComputation. TEST_F(HloRematerializationTest, SingleComputation) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeRematerializableComputation()); @@ -177,7 +177,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { // with rematerialization so pick a memory limit between these values (14KB). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/14 * 1024, module)); + /*memory_limit_bytes=*/14 * 1024, module.get())); EXPECT_TRUE(changed); // Root should not have changed. 
@@ -203,7 +203,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { // MakeRematerializableComputation but with a sufficiently high memory limit // such that no instructions are rematerialized. TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* computation = module->AddEntryComputation(MakeRematerializableComputation()); @@ -211,7 +211,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/20 * 1024, module)); + /*memory_limit_bytes=*/20 * 1024, module.get())); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -225,7 +225,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { // computation should be the one chosen because rematerialization in the while // will presumably be more expensive. TEST_F(HloRematerializationTest, RematerializeAroundWhile) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); cond_builder.AddInstruction( @@ -249,7 +249,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // bit lower (17KB) to force rematerialization of the entry computation. TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/17 * 1024, module)); + /*memory_limit_bytes=*/17 * 1024, module.get())); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -261,7 +261,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // while. Both the entry computation and while body computation should have // computations rematerialized. 
TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); cond_builder.AddInstruction( @@ -282,7 +282,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/15 * 1024, module)); + /*memory_limit_bytes=*/15 * 1024, module.get())); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -293,7 +293,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { // Test rematerialization of a doubly nested computation. All computations // should have an instruction rematerialized. TEST_F(HloRematerializationTest, RematerializeNestedComputations) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); cond_builder.AddInstruction( @@ -321,7 +321,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // ~12K so pick something slightly larger. TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/13 * 1024, module)); + /*memory_limit_bytes=*/13 * 1024, module.get())); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. 
@@ -346,7 +346,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { // // F32[1024] add_2 = add(rng, add(tanh, add_1)) // LIVE: add_2 + add_1 + // // rng + tanh + exp - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( @@ -390,7 +390,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { TF_ASSERT_OK_AND_ASSIGN( bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module)); + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module.get())); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -420,7 +420,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // The value %bcast is live across each call of Subcomputation (which requires // 8KB) though the value is not used in the calls. Rematerializing %bcast // across these calls reduces peak memory use from ~20KB down to ~16KB. - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* subcomputation = nullptr; { @@ -482,7 +482,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, module)); + /*memory_limit_bytes=*/22 * 1024, module.get())); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -533,7 +533,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // (ie %bcast is used indirectly by %negate), otherwise the %negate operand // aliases %add_2. const bool indirectly_used = GetParam(); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloComputation* subcomputation = nullptr; { @@ -576,7 +576,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // rematerialization). 
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, module)); + /*memory_limit_bytes=*/22 * 1024, module.get())); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index fa7f216321988137dcf9104a324f5f7789869aa5..5a9b820a9d7f58695383b21c9e2126cf98970c83 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -204,6 +205,40 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( /*profile=*/profile); } +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile) { + // Get service run options. 
+ se::Stream stream(backend().default_stream_executor()); + stream.Init(); + ServiceExecutableRunOptions service_run_options = + GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, + nullptr); + + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer retval, + executable->ExecuteOnStreamWrapper(&service_run_options, + /*profile=*/profile, arguments)); + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + return std::move(retval); +} + +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile) { + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(&argument); + } + return ExecuteWithDeviceBuffers( + /*executable=*/std::move(executable), + /*arguments=*/argument_pointers, + /*profile=*/profile); +} + StatusOr> HloRunner::ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options) { @@ -324,10 +359,13 @@ StatusOr> HloRunner::ExecuteReplicated( StatusOr> HloRunner::CreateExecutable( std::unique_ptr module, bool run_hlo_passes) { if (run_hlo_passes) { + auto module_group = absl::make_unique(std::move(module)); TF_ASSIGN_OR_RETURN( - module, backend().compiler()->RunHloPasses( - std::move(module), backend().default_stream_executor(), - backend().memory_allocator())); + auto executables, + backend().compiler()->Compile(std::move(module_group), + {{backend().default_stream_executor()}}, + backend().memory_allocator())); + return std::move(executables[0]); } return backend().compiler()->RunBackend(std::move(module), backend().default_stream_executor(), diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 2e934bf66ae43ea412f242030b874dddb6d3722d..bb792cf8c9825ff67ca33bbcf2c3c32b1a0ecb85 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -136,6 +136,21 
@@ class HloRunner { const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile = nullptr); + + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr executable, + const absl::Span arguments, + ExecutionProfile* profile = nullptr); + + // Creates an executable object given an HLO module. If run_hlo_passes is + // true, the HLO passes will be run as part of compilation. + StatusOr> CreateExecutable( + std::unique_ptr module, bool run_hlo_passes); + // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as // value. @@ -152,11 +167,6 @@ class HloRunner { const Backend& backend() const; private: - // Creates an executable object given an HLO module. If run_hlo_passes is - // true, the HLO passes will be run before. - StatusOr> CreateExecutable( - std::unique_ptr module, bool run_hlo_passes); - // Creates a ServiceExecutableRunOptions object to configure a run on device, // using the provided stream object. If device_assignment is not nullptr, it // will be used to configure the replication parameters. 
Replicated executions diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index 9972eb20774550817143cb27dd94667364cf68ec..8f6eb974c5179b420c8f961393ca923e0a3b3530 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -46,8 +46,8 @@ namespace xla { << "No computation exists in HLO module with id " << computation_id; const HloComputation* computation = comp_it->second; - absl::flat_hash_map id_to_instruction; - for (const HloInstruction* instruction : computation->instructions()) { + absl::flat_hash_map id_to_instruction; + for (HloInstruction* instruction : computation->instructions()) { id_to_instruction[instruction->unique_id()] = instruction; } @@ -81,9 +81,8 @@ StatusOr HloSchedule::ToProto() const { return std::move(proto); } -void HloSchedule::set_sequence( - const HloComputation* computation, - absl::Span sequence) { +void HloSchedule::set_sequence(const HloComputation* computation, + absl::Span sequence) { set_sequence(computation, HloInstructionSequence(sequence)); } @@ -114,8 +113,8 @@ Status HloSchedule::UpdateComputationSchedule( const HloComputation* computation) { // Map from unique ID to HloInstruction pointer for instructions in the // computation. - absl::flat_hash_map id_to_instruction; - for (const HloInstruction* instruction : computation->instructions()) { + absl::flat_hash_map id_to_instruction; + for (HloInstruction* instruction : computation->instructions()) { InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); } @@ -128,7 +127,7 @@ Status HloSchedule::UpdateComputationSchedule( // Map from HloInstruction X to newly added instructions (instruction is in // computation, but not in schedule) which use X. If an instruction is not in // the map, then it has no users which are newly added instructions. 
- absl::flat_hash_map> + absl::flat_hash_map> new_instruction_uses; // For each newly added instruction, this is the count of the instruction's @@ -138,9 +137,9 @@ Status HloSchedule::UpdateComputationSchedule( // Create a worklist of newly added instructions which are ready to be added // to the schedule. Initialize worklist with those that have zero operands. - std::queue worklist; + std::queue worklist; - for (const HloInstruction* instruction : computation->instructions()) { + for (HloInstruction* instruction : computation->instructions()) { if (ids_in_schedule.count(instruction->unique_id()) == 0) { // This is a newly added instruction which is not in the schedule. if (instruction->operands().empty()) { @@ -161,17 +160,17 @@ Status HloSchedule::UpdateComputationSchedule( // Lambda which schedules all instructions on the worklist. auto schedule_worklist = [&]() { while (!worklist.empty()) { - const HloInstruction* instruction = worklist.front(); + HloInstruction* instruction = worklist.front(); worklist.pop(); new_sequence.push_back(instruction); - std::vector* new_users = + std::vector* new_users = tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); if (new_users != nullptr) { // This just-scheduled instruction has users which are newly added to // the module. Update the number of unscheduled operands and push the // newly added instruction to the worklist if it is ready to // schedule. 
- for (const HloInstruction* new_user : *new_users) { + for (HloInstruction* new_user : *new_users) { unscheduled_operand_count.at(new_user)--; CHECK_GE(unscheduled_operand_count.at(new_user), 0); if (unscheduled_operand_count.at(new_user) == 0) { @@ -235,7 +234,6 @@ Status HloSchedule::Update() { Status HloSchedule::Verify() const { VLOG(2) << "VerifySchedule()"; - XLA_VLOG_LINES(3, module_->ToString()); XLA_VLOG_LINES(2, ToString()); // Verify schedule contains exactly the same set of non-fusion computations as @@ -265,7 +263,10 @@ Status HloSchedule::Verify() const { } TF_RET_CHECK(instruction_position.size() == - computation->instruction_count()); + computation->instruction_count()) + << "Schedule for computation " << computation->name() << " has " + << instruction_position.size() << " instructions, expected " + << computation->instruction_count(); for (const HloInstruction* instruction : computation->instructions()) { TF_RET_CHECK(instruction_position.count(instruction) == 1) << "Instruction " << instruction->name() << " is not in schedule"; diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 0a714101ee587aa847fa674bbde5586287c51f33..486ddbf499de80c634bc497158cd79ca066cc866 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -35,14 +35,14 @@ class HloInstructionSequence { public: HloInstructionSequence() = default; explicit HloInstructionSequence( - absl::Span instructions) { - for (const HloInstruction* instruction : instructions) { + absl::Span instructions) { + for (HloInstruction* instruction : instructions) { push_back(instruction); } } // Adds the instruction to the end of the sequence. 
- void push_back(const HloInstruction* instruction) { + void push_back(HloInstruction* instruction) { instruction_sequence_.push_back(instruction); id_sequence_.push_back(instruction->unique_id()); } @@ -56,7 +56,7 @@ class HloInstructionSequence { int64 size() const { return instruction_sequence_.size(); } // Returns the sequence of HLO instructions. - const std::vector& instructions() const { + const std::vector& instructions() const { return instruction_sequence_; } @@ -65,7 +65,7 @@ class HloInstructionSequence { private: // The sequence as HloInstructions. - std::vector instruction_sequence_; + std::vector instruction_sequence_; // The sequence of HLO instructions, represented by their unique IDs. The // sequence is stored as both HloInstructions and unique IDs because the @@ -98,7 +98,7 @@ class HloSchedule { // Sets the sequence for the given computation to the given sequence. void set_sequence(const HloComputation* computation, - absl::Span sequence); + absl::Span sequence); void set_sequence(const HloComputation* computation, HloInstructionSequence sequence); diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc index 1424569ac1f62e4b965876141f1eb40be4f15bea..0e56e6f760e35ddcb45c6f58771d78405a09acfe 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -56,10 +56,10 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); - const std::vector& entry_schedule = + const auto& entry_schedule = schedule.sequence(module->entry_computation()).instructions(); EXPECT_EQ(entry_schedule.size(), 6); @@ -90,7 +90,7 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - 
ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); @@ -139,7 +139,7 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); @@ -183,7 +183,7 @@ ENTRY main { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); @@ -244,7 +244,7 @@ ENTRY %WhileLoop () -> s32[] { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/sizeof(void*)); })); @@ -313,7 +313,7 @@ ENTRY %WhileLoop () -> s32[] { ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN( HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + ScheduleModule(module.get(), [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/sizeof(void*)); })); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 188f4acc7945f3ec98065eae5a87a41c39730432..70a860c356ca2fb1c4c973ea3d96c50fabc2c7c2 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -469,6 +469,9 @@ absl::optional HloSharding::ExtractSingleSharding() const { if (!IsTuple()) { return *this; } + if (tuple_elements_.empty()) { + return absl::nullopt; + } for (int64 i = 1; i < tuple_elements_.size(); ++i) { if (tuple_elements_[0] != tuple_elements_[i]) { return 
absl::nullopt; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index e3f4a9852ace86c20610362aa6ad3c3d9c78de30..f5061304456e04ab40448861343ef201c9450dcf 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -169,14 +169,14 @@ Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, // If user is a tuple instruction, return the tuple subsharding corresponding to // the operand matching the instruction argument, because that is the // subsharding corresponding to instruction. -ShapeTree GetShardingTreeFromUser( +StatusOr> GetShardingTreeFromUser( const HloInstruction& instruction, const HloInstruction& user) { if (user.opcode() == HloOpcode::kTuple) { return user.sharding() .GetSubSharding(user.shape(), {user.operand_index(&instruction)}) - .GetAsShapeTree(instruction.shape()); + .AsShapeTree(instruction.shape()); } - return user.sharding().GetAsShapeTree(user.shape()); + return user.sharding().AsShapeTree(user.shape()); } // Assign rhs to lhs. If rhs is unassigned (assigned to kUnassignedDevice) @@ -253,7 +253,7 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice)); for (HloInstruction* user : instruction->users()) { if (user->opcode() == HloOpcode::kDomain && - domain.exit_domains.count(const_cast(user)) > 0) { + domain.exit_domains.count(user) > 0) { // If a user is a domain and it is registered in the domain exits, then // the instruction sharding is taken directly from the domain, and no // further users need to be visited. 
@@ -264,8 +264,8 @@ StatusOr ApplyShardingFromUsers(HloInstruction* instruction, continue; } AssignmentKind sub_assigned = AssignmentKind::kUnassigned; - ShapeTree user_sharding_tree = - GetShardingTreeFromUser(*instruction, *user); + TF_ASSIGN_OR_RETURN(ShapeTree user_sharding_tree, + GetShardingTreeFromUser(*instruction, *user)); if (ShapeUtil::IsTuple(instruction->shape())) { // For tuple-shaped instructions collect individual tuple subshardings // from the uses, and then combine them into the tuple sharding. diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index 45c684d66752862eec301b8943d350804f070309..c1073911ea9dc3811c195e27bcbae9b00929ad17 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -66,7 +66,7 @@ class HloSubcomputationUnificationTest : public HloTestBase { }; TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -103,7 +103,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { } TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -143,7 +143,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { // Do not unify subcomputations with different parameter shapes. TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto callee1 = @@ -184,7 +184,7 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { // Regression test for b/31466798. 
Checks that entry_computation is still valid // after unification. TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) { - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); for (int i = 0; i < 2; ++i) { HloComputation::Builder builder("pow"); auto x = diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 6fd734a2b9e6c8c9fca76a944ca3df4c3b8a212f..1e2b31a1f2bb4865faafc3d14e2b194e3aa171a1 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -24,7 +24,7 @@ namespace { using ::tensorflow::GraphDef; -class HloTfGraphBuilderTest : public HloVerifiedTestBase { +class HloTfGraphBuilderTest : public HloTestBase { protected: HloTfGraphBuilderTest() {} HloTfGraphBuilder generator_; diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index b6670d409b92e8be42f5cdb40fba8d662ae83958..1f01b0bb365450a933da9cc443db5223c06903f0 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -166,9 +166,6 @@ class HloValue : public BufferValue { // Whether this value is live out of the HLO module. bool live_out_of_module_ = false; - - // Whether this value is live out of its computation. 
- bool live_out_of_computation_ = false; }; std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index b5498bb936cb422d7a9dfa3d647266fa8b024b97..77db7b098a38ff4efdcc7447935fae61561c9ff4 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" @@ -27,6 +28,68 @@ limitations under the License. namespace xla { +Status VerifyNotSparse(const Shape& shape) { + return ShapeUtil::ForEachSubshapeWithStatus( + shape, [](const Shape& subshape, const ShapeIndex&) -> Status { + if (LayoutUtil::IsSparseArray(subshape)) { + return InternalError("Sparse arrays are not yet fully supported: %s", + ShapeUtil::HumanStringWithLayout(subshape)); + } + return Status::OK(); + }); +} + +bool IsCallerInstruction(HloInstruction* hlo) { + switch (hlo->opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kWhile: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kMap: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kFusion: + return true; + default: + return false; + } +} + +Status ShapeVerifier::Preprocess(HloInstruction* hlo) { + if (!hlo->called_computations().empty() && !IsCallerInstruction(hlo)) { + return InternalError( + "Called computations specified for non-caller instruction %s", + hlo->ToString()); + } + return VerifyNotSparse(hlo->shape()); +} + 
+namespace { + +Status CheckOperandCount(const HloInstruction* hlo, int expected) { + if (hlo->operand_count() != expected) { + return InternalError("Expected %d operands for %s instruction: %s", + expected, HloOpcodeString(hlo->opcode()), + hlo->ToString()); + } + return Status::OK(); +} + +Status CheckParameterCount(const HloInstruction* calling_instruction, + const HloComputation* computation, int expected) { + if (computation->num_parameters() != expected) { + return InternalError( + "Expected computation %s called from %s to have %d parameters, has %d", + computation->name(), calling_instruction->name(), expected, + computation->num_parameters()); + } + return Status::OK(); +} + +} // namespace + Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) { return CheckUnaryShape(hlo); } @@ -58,12 +121,14 @@ Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) { } Status ShapeVerifier::HandleConvert(HloInstruction* convert) { + TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1)); return CheckShape(convert, ShapeInference::InferConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); } Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) { + TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1)); return CheckShape(convert, ShapeInference::InferBitcastConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); @@ -74,6 +139,7 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) { } Status ShapeVerifier::HandleDot(HloInstruction* dot) { + TF_RETURN_IF_ERROR(CheckOperandCount(dot, 2)); TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferDotOpShape( dot->operand(0)->shape(), dot->operand(1)->shape(), @@ -82,6 +148,7 @@ Status ShapeVerifier::HandleDot(HloInstruction* dot) { } Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { + TF_RETURN_IF_ERROR(CheckOperandCount(convolution, 2)); TF_ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferConvolveShape( 
@@ -92,6 +159,7 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { } Status ShapeVerifier::HandleFft(HloInstruction* fft) { + TF_RETURN_IF_ERROR(CheckOperandCount(fft, 1)); TF_ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(), @@ -118,11 +186,13 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { } Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1)); return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( hlo->operand(0)->shape())); } Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { + TF_RETURN_IF_ERROR(CheckOperandCount(reduce_precision, 1)); return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), reduce_precision->exponent_bits(), @@ -156,6 +226,7 @@ Status ShapeVerifier::CheckOperandAndParameter( } Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1)); HloInfeedInstruction* infeed = Cast(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); @@ -166,6 +237,7 @@ Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { } Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); HloOutfeedInstruction* outfeed = Cast(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); @@ -192,10 +264,7 @@ bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0, } Status ShapeVerifier::HandleRng(HloInstruction* instruction) { - if (instruction->operand_count() != 2) { - return InternalError("Expected two operands for Rng instruction: %s", - instruction->ToString()); - } + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); const Shape& shape_0 = instruction->operand(0)->shape(); const Shape& shape_1 = instruction->operand(1)->shape(); 
@@ -244,29 +313,42 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { } Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { + TF_RETURN_IF_ERROR(CheckOperandCount(reverse, 1)); return CheckShape( reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(), reverse->dimensions())); } Status ShapeVerifier::HandleSort(HloInstruction* sort) { - if (sort->operand_count() == 2 && - !ShapeUtil::SameDimensions(sort->operand(0)->shape(), - sort->operand(1)->shape())) { - return InternalError( - "Expected sort to have to have the same dimensions for the keys and " - "the values. Keys shape is: %s\n, Values shape is: %s", - StringifyShape(sort->operand(0)->shape()), - StringifyShape(sort->operand(1)->shape())); + if (sort->operand_count() < 1) { + return InternalError("Expected at least 1 operand for %s instruction: %s", + HloOpcodeString(sort->opcode()), sort->ToString()); + } + for (int64 operand = 1; operand < sort->operand_count(); ++operand) { + if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(), + sort->operand(operand)->shape())) { + return InternalError( + "Expected sort to have to have the same dimensions for the keys " + "and the values. 
Keys shape is: %s\n, Values shape (operand index " + "%lld) is: %s", + StringifyShape(sort->operand(0)->shape()), operand, + StringifyShape(sort->operand(operand)->shape())); + } } return CheckVariadicShape(sort); } Status ShapeVerifier::HandleConstant(HloInstruction* constant) { + TF_RETURN_IF_ERROR(CheckOperandCount(constant, 0)); + if (!Cast(constant)->HasLiteral()) { + return InternalError("Constant is required to have a valid literal: %s", + constant->ToString()); + } return CheckShape(constant, constant->literal().shape()); } Status ShapeVerifier::HandleIota(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0)); auto* iota = Cast(instruction); const int64 rank = ShapeUtil::Rank(iota->shape()); if (rank == 0) { @@ -281,6 +363,7 @@ Status ShapeVerifier::HandleIota(HloInstruction* instruction) { } Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { + TF_RETURN_IF_ERROR(CheckOperandCount(get_tuple_element, 1)); return CheckShape(get_tuple_element, ShapeInference::InferGetTupleElementShape( get_tuple_element->operand(0)->shape(), @@ -288,6 +371,12 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { } Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { + if (reduce->operand_count() % 2 != 0) { + return InternalError( + "Expected an even number of operands for %s instruction: %s", + HloOpcodeString(reduce->opcode()), reduce->ToString()); + } + std::vector operand_shapes; for (const HloInstruction* operand : reduce->operands()) { operand_shapes.push_back(&operand->shape()); @@ -298,48 +387,64 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { + TF_RETURN_IF_ERROR(CheckOperandCount(bitcast, 1)); return Status::OK(); } Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { + TF_RETURN_IF_ERROR(CheckOperandCount(broadcast, 1)); // HLO broadcast has no exact analog at the proto 
level so there is no // ShapeInference method. Check the output shape explicitly. const Shape& operand_shape = broadcast->operand(0)->shape(); // Check for mixed precision. - TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape())); + TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape)); TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == broadcast->dimensions().size()); for (int64 operand_dimension = 0; operand_dimension < ShapeUtil::Rank(operand_shape); ++operand_dimension) { int64 output_dimension = broadcast->dimensions()[operand_dimension]; - TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) == - operand_shape.dimensions(operand_dimension)) + TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) && + output_dimension >= 0 && + (broadcast->shape().dimensions(output_dimension) == + operand_shape.dimensions(operand_dimension))) << broadcast->ToString() << " operand shape " << operand_shape; } return Status::OK(); } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { + TF_RETURN_IF_ERROR(CheckOperandCount(reshape, 1)); // Check for mixed precision. 
- TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); + const Shape& operand_shape = reshape->operand(0)->shape(); + TF_RET_CHECK(SameElementType(reshape->shape(), operand_shape)); TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == - ShapeUtil::ElementsIn(reshape->operand(0)->shape())); + ShapeUtil::ElementsIn(operand_shape)); return Status::OK(); } Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { + TF_RETURN_IF_ERROR(CheckOperandCount(transpose, 1)); return CheckShape( transpose, ShapeInference::InferTransposeShape( transpose->operand(0)->shape(), transpose->dimensions())); } Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 0)); return Status::OK(); } Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { - for (HloInstruction* fused_param : fusion->fused_parameters()) { + auto& fused_parameters = fusion->fused_parameters(); + if (fused_parameters.size() != fusion->operand_count()) { + return InternalError( + "Fused parameter count (%d) does not match the number of operands (%d)" + " passed to the fusion instruction in: %s.", + fused_parameters.size(), fusion->operand_count(), + fusion->ToString().c_str()); + } + for (HloInstruction* fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) { return InternalError( @@ -352,6 +457,8 @@ Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { } Status ShapeVerifier::HandleCall(HloInstruction* call) { + TF_RETURN_IF_ERROR( + CheckParameterCount(call, call->to_apply(), call->operand_count())); for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) { TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i)); } @@ -359,9 +466,30 @@ Status ShapeVerifier::HandleCall(HloInstruction* call) { return CheckShape(call, call->to_apply()->root_instruction()->shape()); } -Status 
ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); } +Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + TF_RET_CHECK(custom_call != nullptr); + if (custom_call->layout_constrained()) { + // If the layout is constrained, verify all the respective shapes have + // layouts and that the constrained operand shapes match the shapes of the + // operands. + TF_RET_CHECK(LayoutUtil::HasLayout(custom_call->shape())); + TF_RET_CHECK(custom_call->operand_count() == + custom_call->operand_shapes_with_layout().size()); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + const Shape& operand_shape_with_layout = + custom_call->operand_shapes_with_layout()[i]; + TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(), + operand_shape_with_layout)); + TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout)); + } + } + return Status::OK(); +} Status ShapeVerifier::HandleSlice(HloInstruction* slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(slice, 1)); return CheckShape(slice, ShapeInference::InferSliceShape( slice->operand(0)->shape(), slice->slice_starts(), @@ -369,6 +497,7 @@ Status ShapeVerifier::HandleSlice(HloInstruction* slice) { } Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_slice, 2)); return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( dynamic_slice->operand(0)->shape(), dynamic_slice->operand(1)->shape(), @@ -377,6 +506,7 @@ Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { Status ShapeVerifier::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_update_slice, 3)); return CheckShape(dynamic_update_slice, ShapeInference::InferDynamicUpdateSliceShape( dynamic_update_slice->operand(0)->shape(), @@ -406,6 +536,7 @@ Status 
ShapeVerifier::HandleMap(HloInstruction* map) { } Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { + TF_RETURN_IF_ERROR(CheckOperandCount(reduce_window, 2)); return CheckShape( reduce_window, ShapeInference::InferReduceWindowShape( @@ -415,6 +546,7 @@ Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { } Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3)); return CheckShape( instruction, ShapeInference::InferSelectAndScatterShape( @@ -425,6 +557,11 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { } Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { + TF_RETURN_IF_ERROR(CheckOperandCount(xla_while, 1)); + TF_RETURN_IF_ERROR( + CheckParameterCount(xla_while, xla_while->while_body(), 1)); + TF_RETURN_IF_ERROR( + CheckParameterCount(xla_while, xla_while->while_condition(), 1)); TF_RETURN_IF_ERROR( CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0)); TF_RETURN_IF_ERROR( @@ -444,6 +581,11 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { } Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { + TF_RETURN_IF_ERROR(CheckOperandCount(conditional, 3)); + TF_RETURN_IF_ERROR( + CheckParameterCount(conditional, conditional->true_computation(), 1)); + TF_RETURN_IF_ERROR( + CheckParameterCount(conditional, conditional->false_computation(), 1)); TF_RETURN_IF_ERROR(CheckOperandAndParameter( conditional, 1, conditional->true_computation(), 0)); TF_RETURN_IF_ERROR(CheckOperandAndParameter( @@ -458,12 +600,14 @@ Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { } Status ShapeVerifier::HandlePad(HloInstruction* pad) { + TF_RETURN_IF_ERROR(CheckOperandCount(pad, 2)); return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(), pad->operand(1)->shape(), pad->padding_config())); } Status ShapeVerifier::HandleSend(HloInstruction* send) { 
+ TF_RETURN_IF_ERROR(CheckOperandCount(send, 2)); return CheckShape(send, ShapeUtil::MakeTupleShape({send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {}), @@ -471,10 +615,12 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) { } Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { + TF_RETURN_IF_ERROR(CheckOperandCount(send_done, 1)); return CheckShape(send_done, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleRecv(HloInstruction* recv) { + TF_RETURN_IF_ERROR(CheckOperandCount(recv, 1)); return CheckShape( recv, ShapeUtil::MakeTupleShape( {ShapeUtil::GetTupleElementShape(recv->shape(), 0), @@ -482,6 +628,7 @@ Status ShapeVerifier::HandleRecv(HloInstruction* recv) { } Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { + TF_RETURN_IF_ERROR(CheckOperandCount(recv_done, 1)); return CheckShape( recv_done, ShapeUtil::MakeTupleShape( @@ -491,6 +638,7 @@ Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { Status ShapeVerifier::HandleBatchNormTraining( HloInstruction* batch_norm_training) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_training, 3)); return CheckShape(batch_norm_training, ShapeInference::InferBatchNormTrainingShape( batch_norm_training->operand(0)->shape(), @@ -501,6 +649,7 @@ Status ShapeVerifier::HandleBatchNormTraining( Status ShapeVerifier::HandleBatchNormInference( HloInstruction* batch_norm_inference) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_inference, 5)); return CheckShape(batch_norm_inference, ShapeInference::InferBatchNormInferenceShape( batch_norm_inference->operand(0)->shape(), @@ -512,6 +661,7 @@ Status ShapeVerifier::HandleBatchNormInference( } Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_grad, 5)); return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape( batch_norm_grad->operand(0)->shape(), batch_norm_grad->operand(1)->shape(), @@ -548,6 +698,7 @@ Status 
CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kTupleSelect: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kSort: case HloOpcode::kTuple: case HloOpcode::kWhile: break; @@ -579,6 +730,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { } // namespace Status ShapeVerifier::HandleGather(HloInstruction* gather) { + TF_RETURN_IF_ERROR(CheckOperandCount(gather, 2)); return CheckShape( gather, ShapeInference::InferGatherShape( @@ -587,6 +739,7 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { } Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { + TF_RETURN_IF_ERROR(CheckOperandCount(scatter, 3)); return CheckShape( scatter, ShapeInference::InferScatterShape( scatter->operand(0)->shape(), scatter->operand(1)->shape(), @@ -600,7 +753,19 @@ Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { for (const HloInstruction* operand : token->operands()) { operand_shapes.push_back(&operand->shape()); } - return CheckShape(token, ShapeInference::InferAfterAllShape(operand_shapes)); + return CheckShape(token, ShapeUtil::MakeTokenShape()); +} + +Status ShapeVerifier::HandleAddDependency(HloInstruction* add_dependency) { + TF_RETURN_IF_ERROR(CheckOperandCount(add_dependency, 2)); + TF_RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1)); + return CheckShape(add_dependency, add_dependency->operand(0)->shape()); +} + +Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) { + return CheckShape(get_size, + ShapeInference::InferGetDimensionSizeShape( + get_size->operand(0)->shape(), get_size->dimension())); } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, @@ -674,12 +839,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1)); return CheckShape(instruction, 
ShapeInference::InferUnaryOpShape(instruction->opcode(), instruction->operand(0))); } Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); return CheckShape( instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(), instruction->operand(0), @@ -687,6 +854,7 @@ Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { } Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3)); return CheckShape(instruction, ShapeInference::InferTernaryOpShape( instruction->opcode(), instruction->operand(0), @@ -699,6 +867,50 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { instruction->opcode(), instruction->operands())); } +Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) { + const HloComputation* computation = module.entry_computation(); + const auto& layout = module.entry_computation_layout(); + const ShapeLayout& result_layout = layout.result_layout(); + + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape())); + + TF_RETURN_IF_ERROR(VerifyNotSparse(result_layout.shape())); + + if (!ShapeUtil::Compatible(computation->root_instruction()->shape(), + result_layout.shape())) { + return InternalError( + "Shape of the root instruction of entry computation (%s) should be " + "compatible to one specified in module's entry computation layout (%s)", + ShapeUtil::HumanString(computation->root_instruction()->shape()), + ShapeUtil::HumanString(result_layout.shape())); + } + + if (computation->num_parameters() != layout.parameter_count()) { + return InternalError( + "Number of parameters in entry computation layout (%d) must be same " + "as number of parameters of entry computation computation (%d)", + layout.parameter_count(), computation->num_parameters()); + } + + for (int i = 0; i < computation->num_parameters(); 
++i) { + const HloInstruction* parameter = computation->parameter_instruction(i); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i))); + TF_RETURN_IF_ERROR(VerifyNotSparse(layout.parameter_shape(i))); + if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) { + return InternalError( + "Shape of the entry computation parameter %d is %s should be " + "compatible to the one specified in module's entry computation " + "layout %s", + i, ShapeUtil::HumanString(parameter->shape()), + ShapeUtil::HumanString(layout.parameter_shape(i))); + } + } + + return Status::OK(); +} + string ComputationsToString(absl::Span computations) { return absl::StrJoin(computations, ",", [](string* s, const HloComputation* computation) { @@ -1041,7 +1253,10 @@ Status CheckElementwiseInstruction(HloInstruction* instruction) { // not check result shape as that is checked in the ShapeVerifier. class InstructionVerifier : public DfsHloVisitorWithDefault { public: - InstructionVerifier() {} + explicit InstructionVerifier(std::function + instruction_can_change_layout_func) + : instruction_can_change_layout_func_( + instruction_can_change_layout_func) {} Status DefaultAction(HloInstruction*) override { return Status::OK(); } @@ -1129,6 +1344,15 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleCrossReplicaSum(HloInstruction* crs) override { + if (crs->all_reduce_id().has_value()) { + TF_RET_CHECK(crs->all_reduce_id().value() > 0) + << "All reduce id must be greater than 0 for " + << crs->ToShortString(); + } + return Status::OK(); + } + Status Preprocess(HloInstruction* instruction) override { auto previous = instructions_by_name_.find(instruction->name()); TF_RET_CHECK(previous == instructions_by_name_.end()) @@ -1142,26 +1366,59 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { return Status::OK(); } + Status Postprocess(HloInstruction* instruction) override { + if 
(instruction_can_change_layout_func_ && + LayoutUtil::IsDenseArray(instruction->shape()) && + !instruction_can_change_layout_func_(instruction)) { + const Shape& result_shape = instruction->shape(); + const Layout& result_layout = result_shape.layout(); + for (HloInstruction* operand : instruction->operands()) { + const Shape& operand_shape = operand->shape(); + if (LayoutUtil::IsDenseArray(operand_shape) && + ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) { + const Layout& operand_layout = operand_shape.layout(); + TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout)) + << "Instruction shouldn't change layouts " + << instruction->ToString() << " From " << result_shape << " To " + << operand_shape; + } + } + } + + return Status::OK(); + } + private: absl::flat_hash_map instructions_by_name_; + // Determines whether an instruction can change layouts. + std::function + instruction_can_change_layout_func_; }; } // namespace StatusOr HloVerifier::Run(HloModule* module) { TF_RET_CHECK(!module->name().empty()); + + if (module->entry_computation()->IsFusionComputation()) { + return InvalidArgument( + "Module entry computation cannot be a fusion computation"); + } + TF_RETURN_IF_ERROR(VerifyHloStructure(module)); TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module)); - + std::unique_ptr shape_verifier = + target_metadata_->GetVerifier(); for (auto* computation : module->computations()) { - std::unique_ptr shape_verifier = shape_verifier_factory_(); TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); - InstructionVerifier instruction_verifier; + InstructionVerifier instruction_verifier( + instruction_can_change_layout_func_); TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier)); } + TF_RETURN_IF_ERROR(shape_verifier->VerifyEntryComputationLayout(*module)); TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); // If the module has a schedule, it must be valid. 
@@ -1169,6 +1426,13 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(module->schedule().Verify()); } + TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify( + *module, [this](const Shape& shape) { + return target_metadata_->ShapeSize(shape); + })); + + TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module)); + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 6d16586c2c062d407e37392e3fe50be4fd29120b..e4d0c3d6957885f1d719fedb5a900de601e397f8 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ +#include #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "absl/memory/memory.h" @@ -28,10 +29,16 @@ namespace xla { // TODO(b/26024837): Check output shape for all instruction types. class ShapeVerifier : public DfsHloVisitor { public: - explicit ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision) + ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision) : layout_sensitive_(layout_sensitive), allow_mixed_precision_(allow_mixed_precision) {} + // Verifies that entry computation layout matches parameters and root shape of + // the module's entry computation. 
+ virtual Status VerifyEntryComputationLayout(const HloModule& module); + + Status Preprocess(HloInstruction* hlo) override; + Status HandleElementwiseUnary(HloInstruction* hlo) override; Status HandleElementwiseBinary(HloInstruction* hlo) override; Status HandleClamp(HloInstruction* clamp) override; @@ -87,6 +94,8 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleGather(HloInstruction* gather) override; Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* token) override; + Status HandleGetDimensionSize(HloInstruction* get_size) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } @@ -121,6 +130,13 @@ class ShapeVerifier : public DfsHloVisitor { : ShapeUtil::HumanString(s); } + // Helpers that switch on allow_mixed_precision_. + bool SameElementType(const Shape& a, const Shape& b) { + return allow_mixed_precision_ + ? ShapeUtil::SameElementTypeIgnoringFpPrecision(a, b) + : ShapeUtil::SameElementType(a, b); + } + // Checks that the given operand of the given instruction is of type TOKEN. Status CheckIsTokenOperand(const HloInstruction* instruction, int64 operand_no); @@ -149,21 +165,64 @@ class ShapeVerifier : public DfsHloVisitor { bool allow_mixed_precision_; }; +// An interface used to encapsulate target-specific verification quirks. +class TargetVerifierMetadata { + public: + // Returns a target-specific shape size. + virtual int64 ShapeSize(const Shape& shape) const = 0; + + virtual std::unique_ptr GetVerifier() const = 0; + + TargetVerifierMetadata() {} + virtual ~TargetVerifierMetadata() {} + + TargetVerifierMetadata(const TargetVerifierMetadata&) = delete; + TargetVerifierMetadata& operator=(const TargetVerifierMetadata&) = delete; +}; + +// The default implementation of TargetVerifierMetadata, used unless the target +// needs to override it. 
+class DefaultVerifierMetadata : public TargetVerifierMetadata { + public: + DefaultVerifierMetadata(bool layout_sensitive, bool allow_mixed_precision) + : layout_sensitive_(layout_sensitive), + allow_mixed_precision_(allow_mixed_precision) {} + + int64 ShapeSize(const Shape& shape) const override { + return ShapeUtil::ByteSizeOf(shape); + } + + // Creates a ShapeVerifier that checks that shapes match inferred + // expectations. This creates a new verifier every time because ShapeVerifier, + // being a DfsHloVisitor, is stateful. We want a clean object for each run of + // the verifier. + std::unique_ptr GetVerifier() const override { + return absl::make_unique(layout_sensitive_, + allow_mixed_precision_); + } + + private: + bool layout_sensitive_; + bool allow_mixed_precision_; +}; + // HLO pass that verifies invariants of HLO instructions for each computation in // the module. class HloVerifier : public HloModulePass { public: - using ShapeVerifierFactory = std::function()>; - - explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision) - : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] { - return absl::make_unique(layout_sensitive, - allow_mixed_precision); - }) {} + explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision, + std::function + instruction_can_change_layout_func = {}) + : target_metadata_(absl::make_unique( + layout_sensitive, allow_mixed_precision)), + instruction_can_change_layout_func_( + std::move(instruction_can_change_layout_func)) { + CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive); + } - // Uses custom shape verification. 
- explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory) - : shape_verifier_factory_(std::move(shape_verifier_factory)) {} + // Uses custom target metadata + explicit HloVerifier(std::unique_ptr target_metadata) + : target_metadata_(std::move(target_metadata)) {} ~HloVerifier() override = default; absl::string_view name() const override { return "verifier"; } @@ -172,11 +231,11 @@ class HloVerifier : public HloModulePass { StatusOr Run(HloModule* module) override; private: - // Creates a ShapeVerifier that checks that shapes match inferred - // expectations. This is a factory function because ShapeVerifier, - // being a DfsHloVisitor, is stateful. We want a clean object - // for each run of the verifier. - ShapeVerifierFactory shape_verifier_factory_; + std::unique_ptr target_metadata_; + + // Determines whether an instruction can change layouts. + std::function + instruction_can_change_layout_func_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 8f0423bb1c72ceb209437116a898d027f4d2c657..4bc557e4e62e7df4e25fda86fe417e84129b464c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -34,7 +35,11 @@ namespace { using ::testing::HasSubstr; -// This class cannot be converted to use HloVerifiedTestBase. 
It explicitly +std::unique_ptr CreateUnverifiedModule() { + return absl::make_unique("module", HloModuleConfig()); +} + +// This class cannot be converted to use HloTestBase. It explicitly // uses HloTestBase to create and test malformed HLOs. class HloVerifierTest : public HloTestBase { public: @@ -50,6 +55,14 @@ class HloVerifierTestAllowMixedPrecision : public HloTestBase { /*allow_mixed_precision_in_hlo_verifier=*/true) {} }; +class HloVerifierTestLayoutSensitive : public HloTestBase { + public: + HloVerifierTestLayoutSensitive() + : HloTestBase(/*verifier_layout_sensitive=*/true, + /*allow_mixed_precision_in_hlo_verifier=*/false, + LayoutAssignment::InstructionCanChangeLayout) {} +}; + TEST_F(HloVerifierTest, NullInstructionParent) { HloComputation::Builder builder(TestName()); const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -57,7 +70,7 @@ TEST_F(HloVerifierTest, NullInstructionParent) { HloInstruction::CreateParameter(0, scalar_shape, "param")); HloInstruction* negate = builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); TF_ASSERT_OK(verifier().Run(module.get()).status()); @@ -76,7 +89,7 @@ TEST_F(HloVerifierTest, NullComputationParent) { HloInstruction::CreateParameter(0, scalar_shape, "param")); builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto module = CreateUnverifiedModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); TF_ASSERT_OK(verifier().Run(module.get()).status()); @@ -95,7 +108,7 @@ TEST_F(HloVerifierTest, DifferentOperandParents) { HloInstruction::CreateParameter(0, scalar_shape, "param")); HloInstruction* negate = builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - auto module = CreateNewModule(); + auto 
module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); HloComputation::Builder emb_builder(TestName()); @@ -129,7 +142,7 @@ TEST_F(HloVerifierTest, ResetsShapeVerifierState) { builder.AddInstruction( HloInstruction::CreateBinary(s2, HloOpcode::kMultiply, add, add)); - auto module = CreateNewModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); // Run the verifier twice. It should fail both times, because it shouldn't @@ -294,7 +307,7 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) { HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), padding_config)); - auto module = CreateNewModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); auto status = verifier().Run(module.get()).status(); @@ -318,7 +331,7 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) { HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())), padding_config)); - auto module = CreateNewModule(); + auto module = CreateUnverifiedModule(); module->AddEntryComputation(builder.Build()); EXPECT_THAT(verifier().Run(module.get()).status().error_message(), @@ -358,5 +371,63 @@ TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) { HasSubstr("non-positive base area dilation factor")); } +static const char* const kAddWithLayoutChangeHlo = R"( + HloModule AddWithLayoutChange + ENTRY AddWithLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[3,4]{0,1} parameter(1) + ROOT add0 = f32[3,4]{1,0} add(par0,par1) + } + )"; + +TEST_F(HloVerifierTest, AddWithLayoutChange) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + 
ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} + +TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) { + const char* const kSliceWithLayoutChangeHlo = R"( + HloModule SliceWithLayoutChange + ENTRY SliceWithLayoutChange { + par0 = f32[4,5]{0,1} parameter(0) + par1 = s32[2] parameter(1) + ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1), + dynamic_slice_sizes={3,4} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kSliceWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} + +TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) { + const char* const kConcatWithLayoutChangeHlo = R"( + HloModule ConcatWithLayoutChange + ENTRY ConcatWithLayoutChange { + par0 = f32[3,5]{0,1} parameter(0) + par1 = f32[3,3]{1,0} parameter(1) + ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1), + dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kConcatWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index e76b93107c923b41666f6b0a388dda143a8cb50a..90904ac00110457bcc3b8974816a7080c4ab89fc 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -90,20 +90,29 @@ string HumanReadableProfileBuilder::ToString() const { op.optimal_seconds < 0 ? "" : StrFormat("(%12.1f optimal)", op.optimal_seconds * 1e6), - op.flop_count <= 0 ? 
"" : HumanReadableNumFlops(op.flop_count, nsecs), - op.transcendental_count <= 0 - ? "" - : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs), + op.flop_count > 0 && nsecs > 0 + ? HumanReadableNumFlops(op.flop_count, nsecs) + : "", + op.transcendental_count > 0 && nsecs > 0 + ? HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs) + : "", bytes_per_sec, bytes_per_cycle, op.name); }; - float optimal_seconds_sum = 0.0; + double optimal_seconds_sum = 0; int64 total_flops = 0.; int64 total_transcendentals = 0.; int64 total_bytes = 0; for (const auto& op : op_infos_) { if (op.optimal_seconds > 0) { - optimal_seconds_sum += op.optimal_seconds; + // An op can run faster than the estimated optimum. For example, we might + // estimate a fusion's speed by looking at the size of its operands and + // result, but perhaps the fusion doesn't read the entirety of all of its + // inputs. For the purposes of summing the instructions' optimal speeds, + // we treat the "optimum" as the smallest of either the estimated optimum + // and the actual speed. + optimal_seconds_sum += + std::min(double{op.optimal_seconds}, CyclesToSeconds(op.cycles)); } total_flops += std::max(op.flop_count, int64{0}); total_transcendentals += std::max(op.transcendental_count, int64{0}); @@ -112,8 +121,9 @@ string HumanReadableProfileBuilder::ToString() const { VLOG(1) << "Total floating point ops: " << total_flops; - print_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, - total_transcendentals, total_bytes, optimal_seconds_sum}, + print_op({is_entry_computation_ ? "[total] [entry]" : "[total]", "[total]", + /*category=*/"", total_cycles_, total_flops, total_transcendentals, + total_bytes, static_cast(optimal_seconds_sum)}, /*is_total=*/true); // Sort ops in decreasing order of cycles, and print them. 
@@ -154,8 +164,10 @@ string HumanReadableProfileBuilder::ToString() const { entry.text = op.name; entry.short_text = op.short_name; entry.category_text = op.category; - entry.metric = - CyclesToMicroseconds(op.cycles) - op.optimal_seconds * 1e6; + // Ignore ops that run faster than the estimated optimal here, as we do + // when calculating optimal_seconds_sum. + entry.metric = std::max( + 0., CyclesToMicroseconds(op.cycles) - op.optimal_seconds * 1e6); total_discrepancy_in_microseconds += entry.metric; table.AddEntry(std::move(entry)); } diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h index 925111fa1f1e48650b0089f402d92e431043eabe..d4e5cbbe27418ddf3c81ebe00bc8aa979d3c2d5e 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -30,9 +30,11 @@ namespace xla { class HumanReadableProfileBuilder { public: explicit HumanReadableProfileBuilder(absl::string_view computation_name, + bool is_entry_computation, int64 total_cycles, double clock_rate_ghz) : computation_name_(computation_name), + is_entry_computation_(is_entry_computation), total_cycles_(total_cycles), clock_rate_ghz_(clock_rate_ghz) { CHECK_GE(clock_rate_ghz, 1e-9); @@ -75,6 +77,7 @@ class HumanReadableProfileBuilder { } string computation_name_; + bool is_entry_computation_; int64 total_cycles_; double clock_rate_ghz_; std::vector op_infos_; diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc index f85d31d5225b8012b68f851b2bfec219d736ba0d..cf6cf897fe11eda01ba6b22119bba34ac2bef8fe 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -18,19 +18,20 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase { +class ImplicitBroadcastRemoverTest : public HloTestBase { protected: ImplicitBroadcastRemover remover_; }; TEST_F(ImplicitBroadcastRemoverTest, NoImplicitBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -41,15 +42,16 @@ TEST_F(ImplicitBroadcastRemoverTest, NoImplicitBroadcast) { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_FALSE(remover_.Run(&module()).ValueOrDie()); + EXPECT_FALSE(remover_.Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Parameter(), op::Parameter())); } TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); @@ -60,13 +62,13 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kPower, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_FALSE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); - 
EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Power(op::Broadcast(op::Parameter()), op::Parameter())); @@ -76,6 +78,7 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { } TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); @@ -86,9 +89,9 @@ TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { builder.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kSubtract, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Subtract(op::Parameter(), @@ -98,6 +101,7 @@ TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { } TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {1, 4, 1}); @@ -108,9 +112,9 @@ TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { builder.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kSubtract, param0, param1)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, @@ -120,6 +124,7 @@ TEST_F(ImplicitBroadcastRemoverTest, 
ScalarBroadcastToDegenerateDimensions) { } TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6, 8}); @@ -132,9 +137,9 @@ TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, param0, param1, param2)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Reshape(op::Parameter())), @@ -147,6 +152,7 @@ TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { TEST_F(ImplicitBroadcastRemoverTest, TernaryScalarAndDegenerateDimensionBroadcast) { + auto m = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); @@ -159,9 +165,9 @@ TEST_F(ImplicitBroadcastRemoverTest, builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, param0, param1, param2)); - HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); - EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + EXPECT_TRUE(remover_.Run(m.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Parameter()), diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 2d03aebc1aca4c55cca588072233b7a18e70a306..98246d5403e4aebc2f4d81e52145706355ddd9a9 100644 --- 
a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -16,12 +16,12 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" namespace xla { namespace { -class IndexedArrayAnalysisTest : public HloVerifiedTestBase { +class IndexedArrayAnalysisTest : public HloTestBase { protected: void AssertArrayForRootExpressionIs(const string& hlo_text, const string& root_expression) { @@ -61,12 +61,12 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase { const string& root_expression, bool print_constants) { IndexedArrayAnalysis indexed_tensor_analysis; - ParseAndVerifyModule(hlo_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN( - IndexedArrayAnalysis::Array* const array_result, - indexed_tensor_analysis.GetArrayFor( - module().entry_computation()->root_instruction())); + TF_ASSERT_OK_AND_ASSIGN(IndexedArrayAnalysis::Array* const array_result, + indexed_tensor_analysis.GetArrayFor( + m->entry_computation()->root_instruction())); string string_result = CanonicalizeWhitespace( indexed_tensor_analysis.ToString(array_result, print_constants)); LOG(INFO) << string_result; @@ -481,8 +481,8 @@ ENTRY main { const char* expected_root_expression = R"( (scalar-indexed-const (constant s32[2,1,1,1,6] s32[2,1,1,1,6] { - { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } }, - { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } } }) + { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ { 1, 2, 3, 4, 5, 6 } } } }, + { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ { 1, 2, 3, 4, 5, 6 } } } } }) (reshape %indices to s32[]) 0->[]) )"; @@ -512,8 +512,8 @@ ENTRY main { const char* expected_root_expression = R"( (scalar-indexed-const 
(constant s32[2,1,1,6] s32[2,1,1,6] { - { /*i0=0*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } }, - { /*i0=1*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } } }) + { /*i0=0*/ { /*i1=0*/ { 1, 2, 3, 4, 5, 6 } } }, + { /*i0=1*/ { /*i1=0*/ { 1, 2, 3, 4, 5, 6 } } } }) (reshape %indices to s32[5]) 0->[2]) )"; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 69a4c160ee5c4539272c3085338dc6de1b9347ff..7559ed1bab84b21a4d51bc38db999900befcfad7 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -26,7 +26,9 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/fusion_queue.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -101,7 +103,6 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: - case HloOpcode::kAfterAll: case HloOpcode::kTranspose: case HloOpcode::kTuple: case HloOpcode::kTupleSelect: @@ -114,7 +115,10 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kSin: return ShapeUtil::ElementIsComplex(instruction.shape()); - // Expensive instructions. + // Expensive instructions or unusual instructions for which fusion is + // nonsensical. 
+ case HloOpcode::kAddDependency: + case HloOpcode::kAfterAll: case HloOpcode::kAtan2: case HloOpcode::kBatchNormGrad: case HloOpcode::kBatchNormInference: @@ -153,6 +157,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kTanh: case HloOpcode::kTrace: case HloOpcode::kWhile: + case HloOpcode::kGetDimensionSize: return true; } @@ -437,8 +442,7 @@ class ReversePostOrderFusionQueue : public FusionQueue { } // namespace std::unique_ptr InstructionFusion::GetFusionQueue( - HloComputation* computation, - const std::function& skip_producer) { + HloComputation* computation) { return absl::make_unique(computation); } @@ -451,14 +455,16 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (auto* computation : module->MakeNonfusionComputations()) { CHECK(!computation->IsFusionComputation()); computation_ = computation; - reachability_ = computation_->ComputeReachability(); - - HloInstructionSet do_not_duplicate = - ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); - auto fusion_queue = - GetFusionQueue(computation_, [&](HloInstruction* producer) { - return do_not_duplicate.count(producer) > 0; - }); + reachability_ = HloReachabilityMap::Build(computation_); + + HloInstructionSet do_not_duplicate; + // If we allow duplications, we need to compute which instructions we do not + // want to duplicate based on a global analysis of the graph. + if (may_duplicate_) { + do_not_duplicate = + ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + } + auto fusion_queue = GetFusionQueue(computation_); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all @@ -489,9 +495,8 @@ StatusOr InstructionFusion::Run(HloModule* module) { HloInstruction* fusion_instruction; // Try "regular" fusion if the operand may be duplicated. Otherwise, // perform multi-output fusion, unless this creates a cycle. 
- // TODO(tjoerg): Consider making multi-output fusion the default. - if (ShouldFuse(instruction, i) && - do_not_duplicate.count(operand) == 0) { + if (do_not_duplicate.count(operand) == 0 && + ShouldFuse(instruction, i)) { fusion_queue->PreFusion(operand, instruction); fusion_instruction = Fuse(operand, instruction); } else if (ShouldFuseIntoMultiOutput(instruction, i) && @@ -565,15 +570,19 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { - return absl::c_any_of( - consumer->operands(), [&](const HloInstruction* consumer_operand) { - // The fusion algorithm traverses the HLO graph in reverse post order. - // Thus `cosumers` is visited before its operands (including - // `producer`). Therefore, consumer operands cannot have been fused yet. - // It is thus safe to use the pre-computed reachability map. - return consumer_operand != producer && - reachability_->IsReachable(producer, consumer_operand); - }); + auto is_reachable = [&](const HloInstruction* a, const HloInstruction* b) { + // A consumer operand may have been multi-output fused into a parallel + // consumer and thus be missing from the original reachability map. 
+ if (!reachability_->IsPresent(a) || !reachability_->IsPresent(b)) { + reachability_ = HloReachabilityMap::Build(consumer->parent()); + } + return reachability_->IsReachable(a, b); + }; + return absl::c_any_of(consumer->operands(), + [&](const HloInstruction* consumer_operand) { + return consumer_operand != producer && + is_reachable(producer, consumer_operand); + }); } bool InstructionFusion::ShouldFuse(HloInstruction* consumer, diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index f14c6675208c72112aea0179c238b58709d625b5..198bd7fce5f392e5e895b959523d4fe9cf208ba2 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -54,8 +55,7 @@ class InstructionFusion : public HloModulePass { // fused. The default implementation processes consumers in reverse post // order. virtual std::unique_ptr GetFusionQueue( - HloComputation* computation, - const std::function& skip_producer); + HloComputation* computation); // Returns whether the given producer instruction should be fused into the // given consumer instruction. producer is necessarily an operand of consumer. @@ -111,6 +111,10 @@ class InstructionFusion : public HloModulePass { return is_expensive_(instruction); } + // Whether multi-output fusion would introduce a cycle into the HLO graph. + bool MultiOutputFusionCreatesCycle(HloInstruction* producer, + HloInstruction* consumer); + // Current HloComputation instance the loop fuser is traversing. 
HloComputation* computation_; HloModule* module_; @@ -145,10 +149,6 @@ class InstructionFusion : public HloModulePass { // duplicated. std::function is_expensive_; - // Whether multi-output fusion would introduce a cycle into the HLO graph. - bool MultiOutputFusionCreatesCycle(HloInstruction* producer, - HloInstruction* consumer); - // Returns whether we may duplicate an instruction if we want to fuse it. bool may_duplicate_; diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index da1ad90959dc0ab1a840b3390281ce9d4999651e..58b7135cea7419f13d60ed510ecf7a88126aee48 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -117,7 +117,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( @@ -133,7 +133,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) { auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(reshape1, computation->root_instruction()); EXPECT_FALSE( @@ -149,7 +149,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {}), param0, {})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); 
EXPECT_EQ(transpose1, computation->root_instruction()); EXPECT_FALSE( @@ -172,7 +172,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusible) { HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(unary, computation->root_instruction()); EXPECT_FALSE( @@ -361,7 +361,7 @@ TEST_F(InstructionFusionTest, AllowUnaryDuplication) { HloInstruction* unary2 = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(unary2, computation->root_instruction()); EXPECT_TRUE( @@ -385,7 +385,7 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { HloInstruction* unary = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(unary, computation->root_instruction()); EXPECT_TRUE( @@ -394,6 +394,56 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY Test { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add = f32[100] add(p0, p1) + slice1 = f32[99] slice(add), slice={[0:99:1]} + slice2 = f32[99] slice(add), slice={[1:100:1]} + ROOT add2 = f32[99] add(slice1, slice2) + })") + .ValueOrDie(); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); 
+ // 'add' would originally need to be duplicated if fused. However after its + // two users 'slice1' and 'slice2' are fused into 'add2', 'add' has only one + // user and can now be also fused. + EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter())); +} + +TEST_F(InstructionFusionTest, FuseDiamondGraphsAllowDuplication) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY Test { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add = f32[100] add(p0, p1) + slice1 = f32[99] slice(add), slice={[0:99:1]} + slice2 = f32[99] slice(add), slice={[1:100:1]} + ROOT add2 = f32[99] add(slice1, slice2) + })") + .ValueOrDie(); + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + // 'add' would originally need to be duplicated if fused. However after its + // two users 'slice1' and 'slice2' are fused into 'add2', 'add' has only one + // user and can now be also fused. 
+ EXPECT_THAT(root, op::Fusion(op::Parameter(), op::Parameter())); +} + TEST_F(InstructionFusionTest, WideningConvertsAreAlwaysDuplicableIntoConsumers) { auto module = ParseHloString(R"( diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 1484e14df10d94841c5a2e849761779f5800392d..a981d94a999e3d322986bc2bfd56a5b0b5d175fc 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -1,4 +1,4 @@ -licenses(["restricted"]) +licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 7c79eb7d791bc9a0743605d3171ff69c6ef41d58..3a5177c418e3af8253df228a51f2fc0901d10041 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -57,6 +57,13 @@ StatusOr> InterpreterCompiler::RunHloPasses( return std::move(hlo_module); } +Status InterpreterCompiler::RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, + absl::Span executors, + DeviceMemoryAllocator* device_allocator) { + return Unimplemented("Module group compilation not supported on Interpreter"); +} + StatusOr> InterpreterCompiler::RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* /*device_allocator*/) { @@ -76,17 +83,45 @@ StatusOr> InterpreterCompiler::RunBackend( return std::move(executable); } +StatusOr>> +InterpreterCompiler::RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) { + return Unimplemented( + "Module group compilation is not supported on Interpreter."); +} + StatusOr>> InterpreterCompiler::Compile( - std::vector> /*hlo_modules*/, - std::vector> /*stream_execs*/, - DeviceMemoryAllocator* /*device_allocator*/) { - return 
tensorflow::errors::Unimplemented( - "Compilation of multiple HLO modules is not supported on Interpreter."); + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) { + if (module_group->empty()) { + return std::vector>(); + } + if (module_group->size() > 1) { + return tensorflow::errors::Unimplemented( + "Compilation of multiple HLO modules is not supported on Interpreter."); + } + if (stream_exec.size() != 1 || stream_exec[0].size() != 1) { + return tensorflow::errors::Unimplemented( + "Unexpected number of StreamExecutor's."); + } + auto hlo_modules = module_group->ConsumeModules(); + TF_ASSIGN_OR_RETURN(auto module, + RunHloPasses(std::move(hlo_modules[0]), stream_exec[0][0], + device_allocator)); + TF_ASSIGN_OR_RETURN( + auto executable, + RunBackend(std::move(module), stream_exec[0][0], device_allocator)); + std::vector> ret; + ret.push_back(std::move(executable)); + return std::move(ret); } StatusOr>> InterpreterCompiler::CompileAheadOfTime( - std::vector> hlo_modules, + std::unique_ptr module_group, const AotCompilationOptions& aot_options) { return tensorflow::errors::InvalidArgument( "AOT compilation not supported on Interpreter"); diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index e90ae3e818522e6e4fd9d9f5acb846800bc899ca..591272951a01a3e2aa3b615673dceced8e94f674 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -46,18 +46,26 @@ class InterpreterCompiler : public Compiler { StatusOr> RunHloPasses( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; + Status RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, + absl::Span executors, + DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, 
DeviceMemoryAllocator* device_allocator) override; + StatusOr>> RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr>> Compile( - std::vector> hlo_modules, + std::unique_ptr module_group, std::vector> stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr>> - CompileAheadOfTime(std::vector> hlo_modules, + CompileAheadOfTime(std::unique_ptr module_group, const AotCompilationOptions& aot_options) override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index a06d6113e84630df14ff68280c248cccb9afaf06..de9204011ce5ba8a9fc2871c6bd7120b6ed371b5 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -37,7 +37,7 @@ namespace xla { namespace interpreter { InterpreterExecutable::InterpreterExecutable( - std::unique_ptr hlo_module, + std::unique_ptr hlo_module, std::unique_ptr evaluator) : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr, /*hlo_profile_index_map=*/nullptr), @@ -85,6 +85,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( Literal result_literal; { tensorflow::mutex_lock lock(evaluator_lock_); + evaluator_->ResetVisitStates(); TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate( *computation, arg_literals)); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 3b1ebce0c75457d65e6834c809fe488a9c4a159a..bda13d376360306c81230e41b01cefc6caff230d 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -42,7 +42,7 @@ namespace interpreter { // buffer allocation. Refer to interpreter/README.md for more. 
class InterpreterExecutable : public Executable { public: - InterpreterExecutable(std::unique_ptr hlo_module, + InterpreterExecutable(std::unique_ptr hlo_module, std::unique_ptr evaluator); ~InterpreterExecutable() override; diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 4fb67bd0b72fc591c1ffa76ebb0513bf14ed3737..e3e5fa71543baa309b3a68888b1b9bdfd43cfbd5 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -78,9 +78,14 @@ port::Status XlaInterpreterExecutor::SynchronousMemcpy( return port::Status::OK(); } -bool XlaInterpreterExecutor::HostCallback(Stream *stream, - std::function callback) { - AsExecutorStream(stream)->EnqueueTask(callback); +bool XlaInterpreterExecutor::HostCallback( + Stream *stream, std::function callback) { + AsExecutorStream(stream)->EnqueueTask([callback]() { + port::Status s = callback(); + if (!s.ok()) { + LOG(WARNING) << "Host callback failed: " << s; + } + }); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index fbb99457847dca69a1901006d5d8ff713882f918..400c30515464ed5b00251fba303fef303a26b97b 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -125,7 +125,8 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { return port::Status{port::error::UNIMPLEMENTED, ""}; } - bool HostCallback(Stream *stream, std::function callback) override; + bool HostCallback(Stream *stream, + std::function callback) override; port::Status AllocateEvent(Event *event) override { return port::Status{port::error::UNIMPLEMENTED, ""}; diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index 
c9b40d3c6195f80a19272a0d98890049d02315b9..b0fc1af8b89d7327a00f77f471e90d143a92de7c 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -110,3 +110,5 @@ REGISTER_MODULE_INITIALIZER( // open-source project, so this will be a no-op there. REGISTER_MODULE_INITIALIZER_SEQUENCE(interpreter_platform, multi_platform_manager); +REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener, + interpreter_platform); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index cc4a342e9d38415599256a5eaf3f5cf757652659..eddef850cf5250b85b564c1e6c92d1cc8ecd1a43 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -419,6 +419,16 @@ Status LayoutAssignment::BuildHostChannelConstraints( return Status::OK(); } +namespace { + +bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + return custom_call != nullptr && custom_call->layout_constrained(); +} + +} // namespace + Status LayoutAssignment::AddMandatoryConstraints( const ComputationLayout* computation_layout, ChannelLayoutConstraints* channel_constraints, HloComputation* computation, @@ -434,13 +444,11 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain layouts of instructions which define values with pre-existing // layouts. for (auto* instruction : computation->instructions()) { - Shape const* shape_with_layout = nullptr; if (instruction->opcode() == HloOpcode::kInfeed) { // Infeed layouts must match the layout of the original inserted // instruction. // TODO(b/31425034): Change infeeds to be more like parameters, with // shapes in the ComputationLayout. 
- DCHECK(!LayoutUtil::IsPadded(instruction->shape())); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(instruction->shape(), instruction)); } else if (instruction->opcode() == HloOpcode::kOutfeed) { @@ -456,17 +464,21 @@ Status LayoutAssignment::AddMandatoryConstraints( if (parameter_layout.LayoutIsSet()) { // Parameter layouts must match the respective layout in // ComputationLayout, if there is one. - shape_with_layout = ¶meter_layout.shape(); + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + parameter_layout.shape(), instruction)); } } - } - if (shape_with_layout != nullptr) { + } else if (IsLayoutConstrainedCustomCall(instruction)) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(*shape_with_layout, instruction)); - } - - if (instruction->opcode() == HloOpcode::kSend || - instruction->opcode() == HloOpcode::kRecv) { + constraints->SetInstructionLayout(custom_call->shape(), custom_call)); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + custom_call->operand_shapes_with_layout()[i], custom_call, i)); + } + } else if (instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kRecv) { CHECK(get_channel_constraints(instruction)) << "Multi-module layout assignment requires ChannelLayoutConstraints"; int64 channel_id = instruction->channel_id(); @@ -621,31 +633,6 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetOperandLayout( false_computation_layout.parameter_shape(0), instruction, 2, /*mandatory=*/true)); - } else if (instruction->opcode() == HloOpcode::kCustomCall) { - if (!CustomCallRequiresMajorFirstLayout(instruction)) { - continue; - } - // Add constraints for kCustomCall instruction operands and instructions. - // For now we only support major-first layouts for all inputs and outputs. 
- Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout( - instruction->shape().element_type(), - AsInt64Slice(instruction->shape().dimensions())); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(result_shape, instruction)); - for (int64 i = 0; i < instruction->operand_count(); ++i) { - const Shape& operand_shape = instruction->operand(i)->shape(); - // Opaque operands don't get a layout constraint. - if (ShapeUtil::IsOpaque(operand_shape)) { - continue; - } - - Shape row_major_operand_shape = - ShapeUtil::MakeShapeWithDescendingLayout( - operand_shape.element_type(), - AsInt64Slice(operand_shape.dimensions())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - row_major_operand_shape, instruction, i)); - } } } // Finally set the result layout to match ComputationLayout, if there is one. @@ -676,16 +663,18 @@ Status CheckCallLayout(HloInstruction* call, return Status::OK(); } -// Custom calls have fixed input and output layouts. -Status CheckCustomCallLayout(HloInstruction* custom_call) { - for (const HloInstruction* operand : custom_call->operands()) { - TF_RET_CHECK( - ShapeUtil::IsOpaque(operand->shape()) || - LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); +// Operands of layout-constrained custom calls must match the expected +// constrained layouts. 
+Status CheckCustomCallLayout(HloInstruction* instruction) { + if (IsLayoutConstrainedCustomCall(instruction)) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( + custom_call->operand(i)->shape(), + custom_call->operand_shapes_with_layout()[i])); + } } - TF_RET_CHECK( - ShapeUtil::IsOpaque(custom_call->shape()) || - LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout())); return Status::OK(); } @@ -932,9 +921,7 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { FindOrDie(computation_layouts_, instruction->to_apply()))); break; case HloOpcode::kCustomCall: - if (CustomCallRequiresMajorFirstLayout(instruction)) { - TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); - } + TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); break; case HloOpcode::kFusion: TF_RETURN_IF_ERROR(CheckFusionLayout(instruction)); @@ -971,9 +958,8 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { FindOrDie(computation_layouts_, module->entry_computation()) .result_layout(); if (result_layout.LayoutIsSet()) { - TF_RET_CHECK(ShapeUtil::Equal( - module->entry_computation()->root_instruction()->shape(), - result_layout.shape())); + TF_RET_CHECK( + ShapeUtil::Equal(module->result_shape(), result_layout.shape())); } return Status::OK(); } @@ -1002,10 +988,8 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( const Layout& output_layout, const HloInstruction* instruction, int64 operand_no) { const HloInstruction* operand = instruction->operand(operand_no); - CHECK(ShapeUtil::IsArray(instruction->shape())); CHECK(ShapeUtil::IsArray(operand->shape())); - if (!ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(instruction->shape()) && @@ -1264,12 +1248,20 @@ Status LayoutAssignment::PropagateOperandConstraint( operand_constraint.operand(), constraints)); // For 
array-shaped operands and user instructions try to pick a minimum cost - // layout. For example, if the operand of a elementwise instruction is - // constained to a certain layout we want the output of the instruction to + // layout. For example, if the operand of an elementwise instruction is + // constrained to a certain layout we want the output of the instruction to // have the same layout. + // + // If the user is not array-shaped, we still want to propagate the layout + // to siblings if the instruction can't change layout. This is to represent + // the information that non-layout-changing instructions should have the same + // layout for the operands with the same ranks. const HloInstruction* operand = operand_constraint.operand(); const HloInstruction* user = operand_constraint.instruction(); - if (!ShapeUtil::IsArray(operand->shape()) || + if (!ShapeUtil::IsArray(operand->shape())) { + return Status::OK(); + } + if (instruction_can_change_layout_func_(user) && !ShapeUtil::IsArray(user->shape())) { return Status::OK(); } @@ -1280,52 +1272,183 @@ Status LayoutAssignment::PropagateOperandConstraint( operand_constraint.operand_no())) { return Status::OK(); } - TF_ASSIGN_OR_RETURN( - const LogicalBuffer* buffer, - constraints->points_to_analysis().GetBufferDefinedAt(user, /*index=*/{})); - if (constraints->BufferLayout(*buffer) == nullptr) { - std::unique_ptr layout = ChooseOutputLayoutFromOperandLayout( - operand_constraint.shape_layout().layout(), user, - operand_constraint.operand_no()); - if (layout != nullptr) { - TF_RETURN_IF_ERROR( - constraints->SetBufferLayout(*layout, *buffer, /*mandatory=*/false)); + int64 operand_rank = ShapeUtil::Rank(operand->shape()); + if (operand_rank <= 1) { + return Status::OK(); + } + + // Propagate layouts between operands of the same instruction. This is a + // constraint on non-layout-changing instructions. + if (!instruction_can_change_layout_func_(user)) { + // Make sure all siblings have the same layout as the operand. 
+ for (int64 operand_no = 0; operand_no < user->operand_count(); + ++operand_no) { + if (user->operand(operand_no) == operand) { + continue; + } + const HloInstruction* sibling = user->operand(operand_no); + const int64 sibling_rank = ShapeUtil::Rank(sibling->shape()); + if (sibling_rank <= 1) { + continue; + } + if (operand_rank != sibling_rank) { + continue; + } + const OperandLayoutConstraint* constraint = + constraints->GetOperandLayoutConstraint(user, operand_no); + if (constraint != nullptr) { + // Due to the DFS of the propagation we can end up here when operand_no + // has a layout set that hasn't been propagated yet (is still on the + // stack of layouts to propagate). + // We can continue here and leave the operands with different layouts, + // as we will either: + // - overwrite the current operand when the DFS gets back to propagating + // operand(operand_no) to its siblings + // - overwrite operand(operand_no)'s layout with a mandatory layout if + // we continue to propagate our layout to the result, and then + // backwards into all operands (if the result is an array of rank > 1) + continue; + } + TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( + operand_constraint.shape_layout().layout(), user, operand_no, + /*mandatory=*/false)); } + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + user->shape(), + [&](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsTuple(subshape)) { + return Status::OK(); + } + if (ShapeUtil::Rank(subshape) <= 1) { + return Status::OK(); + } + + // Assign the right layout to input fusion of higher rank reduce + // operations. + if (ShapeUtil::Rank(subshape) != ShapeUtil::Rank(operand->shape())) { + return Status::OK(); + } + // TODO(b/67641796): Are there cases except fusion that use this code + // path? 
+ TF_ASSIGN_OR_RETURN( + const LogicalBuffer* buffer, + constraints->points_to_analysis().GetBufferDefinedAt( + user, shape_index)); + // Make sure the output has the same layout as the operand. + const BufferLayoutConstraint* constraint = + constraints->GetBufferLayoutConstraint(*buffer); + // If we already have a constraint for the buffer it was assigned but + // hasn't propagated yet. This can happen with diamond-shaped graphs + // where one path is first evaluated in depth-first order (we're here) + // and the other path is propagated later. We don't set the layout + // here as it will always be overwritten later. + if (constraint == nullptr) { + TF_RETURN_IF_ERROR(constraints->SetBufferLayout( + operand_constraint.shape_layout().layout(), *buffer, + /*mandatory=*/false)); + } + return Status::OK(); + })); + return Status::OK(); } + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsTuple(subshape)) { + return Status::OK(); + } + if (ShapeUtil::Rank(subshape) <= 1) { + return Status::OK(); + } + TF_ASSIGN_OR_RETURN( + const LogicalBuffer* buffer, + constraints->points_to_analysis().GetBufferDefinedAt(user, + shape_index)); + if (constraints->BufferLayout(*buffer) == nullptr || + !constraints->GetBufferLayoutConstraint(*buffer)->mandatory()) { + std::unique_ptr layout = ChooseOutputLayoutFromOperandLayout( + operand_constraint.shape_layout().layout(), user, + operand_constraint.operand_no()); + if (layout != nullptr) { + TF_RETURN_IF_ERROR(constraints->SetBufferLayout( + *layout, *buffer, + /*mandatory=*/user->opcode() == HloOpcode::kReduce, + /*dfs=*/false)); + } + } + return Status::OK(); + })); return Status::OK(); } -Status LayoutAssignment::PropagateBufferConstraint( +Status LayoutAssignment::PropagateBufferConstraintToOperands( const BufferLayoutConstraint& buffer_constraint, LayoutConstraints* constraints) { - // Only propagate array layouts. 
+ VLOG(5) << "PropagateBufferConstraintToOperands: " + << buffer_constraint.ToString(); const LogicalBuffer& buffer = buffer_constraint.buffer(); - if (!buffer.IsArray()) { + + const HloInstruction* instruction = buffer.instruction(); + if (IsAtMostRank1(instruction->shape())) { return Status::OK(); } - // If this buffer is the result of an array-shaped op (as opposed to an array - // element in a tuple) try to propagate the layout to its operands. - if (buffer.IsTopLevel()) { - const HloInstruction* instruction = buffer.instruction(); - // Propagate the def-constraint on an instruction to the use-constraints on - // its operands (use-def propagation). - for (int64 operand_no = 0; operand_no < instruction->operand_count(); - ++operand_no) { - if (constraints->OperandLayout(instruction, operand_no) == nullptr && - ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + for (int64 operand_no = 0; operand_no < instruction->operand_count(); + ++operand_no) { + const HloInstruction* operand = instruction->operand(operand_no); + if (IsAtMostRank1(operand->shape())) { + continue; + } + if (!instruction_can_change_layout_func_(instruction)) { + // Copy the layout to the operand. + if (buffer.IsArray() && ShapeUtil::IsArray(operand->shape()) && + ShapeUtil::Rank(operand->shape()) == + LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) { + TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( + buffer_constraint.layout(), instruction, operand_no, + /*mandatory=*/true)); + } + } else { + if (!buffer.IsTopLevel() || + !ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + continue; // Don't touch buffers that are internal to a tuple. + } + VLOG(6) << "Propagating constraint to operand " << operand_no << " of " + << instruction->ToShortString(); + // Assign a layout if there is no constraint already. 
+ const OperandLayoutConstraint* constraint = + constraints->GetOperandLayoutConstraint(instruction, operand_no); + if (constraint == nullptr || !constraint->mandatory()) { std::unique_ptr operand_layout = ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(), instruction, operand_no); if (operand_layout != nullptr) { + // Do not propagate operand constraints of transposes and reshapes, it + // tends to create really bad layouts. TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( - *operand_layout, instruction, operand_no, /*mandatory=*/true)); + *operand_layout, instruction, operand_no, /*mandatory=*/false, + /*dfs=*/false)); } + } else { + VLOG(6) << "Operand already has a constraint " + << constraint->ToString(); } } } - return PropagateBufferConstraintToUses(buffer_constraint, constraints); + return Status::OK(); +} + +Status LayoutAssignment::PropagateBufferConstraint( + const BufferLayoutConstraint& buffer_constraint, + LayoutConstraints* constraints) { + // Only propagate array layouts. + const LogicalBuffer& buffer = buffer_constraint.buffer(); + if (!buffer.IsArray()) { + return Status::OK(); + } + TF_RETURN_IF_ERROR( + PropagateBufferConstraintToUses(buffer_constraint, constraints)); + return PropagateBufferConstraintToOperands(buffer_constraint, constraints); } Status LayoutAssignment::PropagateBufferConstraintToUses( @@ -1353,12 +1476,12 @@ Status LayoutAssignment::PropagateBufferConstraintToUses( } Status LayoutAssignment::PropagateResultConstraint( - const ResultLayoutConstraint& result_constraint, + const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints) { // Propagate the use constraint of the root instruction up to the logical // buffers which make up the result. 
return PropagateUseConstraintToDefs( - result_constraint.shape_layout(), + layout_constraint.shape_layout(), constraints->computation()->root_instruction(), constraints); } @@ -1536,6 +1659,10 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, // Execute extra verification step once the layout has been finalized. TF_RETURN_IF_ERROR(Verify(instruction)); + // Shape must be valid. + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape())); + // Verify all layouts in the shape have been set. TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); } @@ -1554,11 +1681,11 @@ Status LayoutAssignment::CalculateComputationLayout( Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // Clear existing layouts of the instructions. All layouts must be assigned - // by the LayoutAssignment pass, except for those on infeeds, parameters, - // and the computation result. The latter two are specified in - // computation_layout, so we only need to keep the existing layouts for - // infeeds. Clearing the layouts here avoids hiding potential bugs in the - // layout assignment pass that may accidentally use the existing layout. + // by the LayoutAssignment pass, except for those on parameters, the + // computation result, and a couple special cases. The former two are + // specified in computation_layout. Clearing the layouts here avoids hiding + // potential bugs in the layout assignment pass that may accidentally use the + // existing layout. 
for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kBitcast) { // bitcasts are inherently layout sensitive and so a bitcast instruction @@ -1567,7 +1694,9 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { "Unexpected bitcast operation seen during layout assignment: %s.", instruction->ToString()); } - if (instruction->opcode() != HloOpcode::kInfeed) { + // Some instructions carry mandatory layouts in their shape. + if (instruction->opcode() != HloOpcode::kInfeed && + !IsLayoutConstrainedCustomCall(instruction)) { LayoutUtil::ClearLayout(instruction->mutable_shape()); } } @@ -1802,6 +1931,18 @@ StatusOr LayoutAssignment::Run(HloModule* module) { } TF_RETURN_IF_ERROR(Init()); + // Verify computation layout is sane. + const HloComputation* entry = module->entry_computation(); + TF_RET_CHECK(entry_computation_layout_->parameter_count() == + entry->num_parameters()); + for (int64 i = 0; i < entry->num_parameters(); ++i) { + TF_RET_CHECK( + ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i), + entry->parameter_instruction(i)->shape())); + } + TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(), + entry->root_instruction()->shape())); + // We do two passes. 
The first one we pass a nullptr ComputationLayout to // the RunOnComputation() calls (for non entry computations), and we register // the ComputationLayout which are naturally flowing in DFS fashion to the @@ -1859,6 +2000,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( switch (instruction->opcode()) { case HloOpcode::kAbs: case HloOpcode::kAdd: + case HloOpcode::kAddDependency: case HloOpcode::kAnd: case HloOpcode::kAtan2: case HloOpcode::kBitcastConvert: @@ -1873,7 +2015,6 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: - case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: @@ -1907,6 +2048,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kRemainder: case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: + case HloOpcode::kScatter: case HloOpcode::kSelect: case HloOpcode::kSelectAndScatter: case HloOpcode::kShiftLeft: @@ -1930,6 +2072,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kConstant: case HloOpcode::kConvolution: case HloOpcode::kCopy: + case HloOpcode::kCustomCall: case HloOpcode::kDomain: case HloOpcode::kDot: case HloOpcode::kFusion: @@ -1944,17 +2087,27 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kReduce: case HloOpcode::kReshape: case HloOpcode::kRng: - case HloOpcode::kScatter: case HloOpcode::kSend: case HloOpcode::kSendDone: case HloOpcode::kAfterAll: case HloOpcode::kTrace: case HloOpcode::kTranspose: case HloOpcode::kTuple: + case HloOpcode::kGetDimensionSize: return true; } } +/* static */ +bool LayoutAssignment::IsAtMostRank1(const Shape& shape) { + if (ShapeUtil::IsArray(shape)) { + return ShapeUtil::Rank(shape) <= 1; + } + return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) { + return IsAtMostRank1(subshape); + }); +} + Status LayoutAssignment::Init() { 
computation_layouts_.clear(); *entry_computation_layout_ = saved_entry_computation_layout_; diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 2d48e122637c080fc2bcf7bce1c2a2521f51e41f..3b081de3c7826c3c11a7d87d542835d0ecce1b7e 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -315,6 +315,10 @@ class LayoutAssignment : public HloModulePass { // rank as the output to have the same layout as the output. static bool InstructionCanChangeLayout(const HloInstruction* instruction); + // In case of an array shape returns true iff it is at most rank 1. In case of + // a tuple shape returns true iff all leaf shapes are at most rank 1. + static bool IsAtMostRank1(const Shape& shape); + protected: // These methods, invoked by PropagateConstraints, propagate a layout // constraint to its neighbors (i.e. operands and users) in order to minimize @@ -333,19 +337,6 @@ class LayoutAssignment : public HloModulePass { const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints); - // By default LayoutAssignment ensures that inputs and outputs of CustomCalls - // have the "major-first" layout (i.e. {n, n-1, ..., 0}). - // - // If this function returns true, LayoutAssignment does not set a layout for - // the given CustomCall. It's up to the backend to set one in - // AddBackendConstraints, if necessary. - // - // Precondition: instruction->opcode() == HloOpcode::kCustomCall. - virtual bool CustomCallRequiresMajorFirstLayout( - const HloInstruction* /*instruction*/) { - return true; - } - // Called after layouts of an instruction have been finalized to allow // subclasses to check for platform specific assumptions. virtual Status Verify(const HloInstruction* instruction) { @@ -375,7 +366,7 @@ class LayoutAssignment : public HloModulePass { // `user` that minimizes its cost on that operand. 
Returns null if it can't // decide the best layout. // Precondition: `user` and the operand are array-shaped. - std::unique_ptr ChooseOutputLayoutFromOperandLayout( + virtual std::unique_ptr ChooseOutputLayoutFromOperandLayout( const Layout& operand_layout, const HloInstruction* user, int64 operand_no); @@ -421,6 +412,10 @@ class LayoutAssignment : public HloModulePass { // required for correctness. Status PropagateConstraints(LayoutConstraints* constraints); + Status PropagateBufferConstraintToOperands( + const BufferLayoutConstraint& buffer_constraint, + LayoutConstraints* constraints); + // Check that all layouts in the module have been set and satisfy all // necessary conditions. Status CheckLayouts(HloModule* module); diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 2c549cd872b35e55cc00527b6579f79d8516b66c..5c661bfacb08fe27f3cbdc1fb9db083315166008 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -27,50 +27,70 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -namespace op = xla::testing::opcode_matchers; - namespace xla { namespace { +namespace m = xla::match; using ::testing::ElementsAre; -class LayoutAssignmentTest : public HloVerifiedTestBase { +class LayoutAssignmentTest : public HloTestBase { protected: - void AssignLayouts(HloModule* module, - ComputationLayout* entry_computation_layout, + void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr) { LayoutAssignment layout_assignment( entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout, /*channel_constraints=*/channel_constraints); - EXPECT_IS_OK(layout_assignment.Run(module).status()); + EXPECT_IS_OK(layout_assignment.Run(m).status()); } - std::vector LayoutOf(HloModule* module, absl::string_view name) { + std::vector LayoutOf(HloModule* m, absl::string_view name) { auto minor_to_major = 
- FindInstruction(module, name)->shape().layout().minor_to_major(); + FindInstruction(m, name)->shape().layout().minor_to_major(); return std::vector(minor_to_major.begin(), minor_to_major.end()); } + + void ExpectLayoutIs(const Shape& shape, + absl::Span minor_to_major) { + const Layout expected = LayoutUtil::MakeLayout(minor_to_major); + EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected)) + << "Expected layout " << expected << ", actual " << shape.layout(); + } + + void ExpectTupleLayoutIs( + const Shape& shape, + std::initializer_list> minor_to_majors) { + int i = 0; + for (const absl::Span minor_to_major : minor_to_majors) { + const Layout expected = LayoutUtil::MakeLayout(minor_to_major); + const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout(); + EXPECT_TRUE(LayoutUtil::Equal(actual, expected)) + << "Expected tuple element " << i << " layout " << expected + << ", actual " << actual; + ++i; + } + } }; TEST_F(LayoutAssignmentTest, ComputationLayout) { // Verify the layouts of the root and parameter instructions of a computation // match the ComputationLayout for two different layouts. 
- std::vector> minor_to_majors = {{0, 1}, {1, 0}}; + std::vector> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); @@ -80,8 +100,8 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { HloInstruction::CreateParameter(1, ashape, "param1")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build()); Layout layout = LayoutUtil::MakeLayout(minor_to_major); Shape shape(ashape); @@ -92,7 +112,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { *computation_layout.mutable_parameter_layout(0) = shape_layout; *computation_layout.mutable_parameter_layout(1) = shape_layout; *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout())); @@ -110,8 +130,8 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { HloInstruction::CreateParameter(1, ashape, "param1")); builder.AddInstruction( HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build()); Layout col_major_layout = LayoutUtil::MakeLayout({1, 0}); Shape col_major_shape(ashape); @@ -128,7 +148,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { 
*computation_layout.mutable_parameter_layout(1) = row_major; *computation_layout.mutable_result_layout() = col_major; - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal( @@ -139,7 +159,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { // Verify that the layout of the fused parameters in a fusion instruction // match that of the fusion operands. Other fused instructions should have no // layout. - std::vector> minor_to_majors = {{0, 1}, {1, 0}}; + std::vector> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); auto constant_literal1 = LiteralUtil::CreateR2WithLayout( @@ -159,8 +179,8 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build()); auto fusion = computation->CreateFusionInstruction( {negate2, negate1, add}, HloInstruction::FusionKind::kLoop); @@ -173,7 +193,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::Equal( layout, fusion->fused_parameter(0)->shape().layout())); @@ -208,13 +228,13 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { auto negate = builder.AddInstruction(HloInstruction::CreateUnary( constant0->shape(), HloOpcode::kNegate, get_element0)); - auto module = 
CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE( LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape())); @@ -246,17 +266,17 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1)); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape result_shape = ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()}); TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( result_shape)); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape())); } @@ -281,11 +301,11 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { auto nested_tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, inner_tuple})); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape result_shape = nested_tuple->shape(); *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); @@ -295,7 
+315,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { result_shape)); LayoutAssignment layout_assignment(&computation_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); // Layout assignment should have deep copied the result of the computation to // address the layout conflict. This results in several Tuple() and @@ -308,12 +328,11 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { // %tuple.1 = Tuple(%copy) layout=({0,1}) // %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1})) // - EXPECT_TRUE( - AlgebraicSimplifier(/*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return false; }) - .Run(module) - .ValueOrDie()); - HloInstruction* root = module->entry_computation()->root_instruction(); + AlgebraicSimplifierOptions options( + [](const Shape&, const Shape&) { return false; }); + options.set_is_layout_sensitive(true); + EXPECT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie()); + HloInstruction* root = m->entry_computation()->root_instruction(); // Verify layout of the root and the root's operands. EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape())); EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}), @@ -323,7 +342,8 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { // Verify the structure of the HLO graph. 
EXPECT_THAT(root, - op::Tuple(op::Tuple(constant), op::Tuple(op::Copy(constant)))); + GmockMatch(m::Tuple(m::Tuple(m::Op().Is(constant)), + m::Tuple(m::Copy(m::Op().Is(constant)))))); } TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { @@ -340,9 +360,8 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape)); - auto module = CreateNewModule(); - HloComputation* computation = - module->AddEntryComputation(builder.Build(tanh)); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build(tanh)); Shape ashape_with_layout(ashape); Shape bshape_with_layout(bshape); @@ -353,7 +372,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); auto log_minor_to_major = AsInt64Slice(log->shape().layout().minor_to_major()); @@ -382,8 +401,8 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { HloInstruction::CreateTranspose(bshape, log, {1, 0})); auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build(tanh)); + auto m = CreateNewVerifiedModule(); + auto computation = m->AddEntryComputation(builder.Build(tanh)); Shape ashape_with_layout(ashape); Shape bshape_with_layout(bshape); @@ -394,7 +413,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE( 
LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout())); @@ -418,9 +437,9 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { HloInstruction::CreateBroadcast(bshape, param, {1, 2})); auto transpose = builder.AddInstruction( HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0})); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); HloComputation* computation = - module->AddEntryComputation(builder.Build(transpose)); + m->AddEntryComputation(builder.Build(transpose)); Shape input_shape_with_layout(ashape); Shape output_shape_with_layout(cshape); @@ -433,7 +452,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { ShapeLayout(input_shape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(output_shape_with_layout); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1, 2)); @@ -467,9 +486,8 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2})); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({transpose, broadcast2})); - auto module = CreateNewModule(); - HloComputation* computation = - module->AddEntryComputation(builder.Build(tuple)); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build(tuple)); ComputationLayout computation_layout(computation->ComputeProgramShape()); Shape param_shape_with_layout(f32_4); @@ -486,7 +504,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { *computation_layout.mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape( {transpose_shape_with_layout, broadcast2_shape_with_layout})); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1)); 
EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0)); @@ -537,9 +555,8 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1)); auto reshape = builder.AddInstruction( HloInstruction::CreateReshape(cshape, concatenate)); - auto module = CreateNewModule(); - HloComputation* computation = - module->AddEntryComputation(builder.Build(reshape)); + auto m = CreateNewVerifiedModule(); + HloComputation* computation = m->AddEntryComputation(builder.Build(reshape)); Shape param0_shape_with_layout(ashape); Shape param1_shape_with_layout(ashape); @@ -552,7 +569,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { *computation_layout.mutable_parameter_layout(1) = ShapeLayout(param1_shape_with_layout); OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout); - EXPECT_IS_OK(layout_assignment.Run(module).status()); + EXPECT_IS_OK(layout_assignment.Run(m.get()).status()); EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode()); EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(), @@ -572,11 +589,11 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) { HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1})); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); HloComputation* computation = - module->AddEntryComputation(builder.Build(transpose)); + m->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -590,11 +607,11 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { 
HloInstruction::CreateBroadcast(input_shape, constant, {})); auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1})); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); HloComputation* computation = - module->AddEntryComputation(builder.Build(transpose)); + m->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -660,12 +677,12 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - ParseAndVerifyModule(module_str); - + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); std::unique_ptr compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); @@ -700,9 +717,10 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module().entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}), ShapeUtil::MakeTupleShape({ @@ -714,19 +732,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { param_shape)); computation_layout.mutable_result_layout()->ResetLayout( LayoutUtil::MakeLayout({2, 1, 0})); - AssignLayouts(&module(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); - EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 
1, 2)); - EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0)); - EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1)); - EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0)); - EXPECT_THAT(FindInstruction(&module(), "gte1") + EXPECT_THAT(LayoutOf(m.get(), "gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(LayoutOf(m.get(), "gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(LayoutOf(m.get(), "gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(FindInstruction(m.get(), "gte1") ->shape() .tuple_shapes(0) .layout() .minor_to_major(), ElementsAre(1, 2, 0)); - EXPECT_THAT(FindInstruction(&module(), "gte1") + EXPECT_THAT(FindInstruction(m.get(), "gte1") ->shape() .tuple_shapes(1) .layout() @@ -736,7 +754,7 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { auto builder = HloComputation::Builder(TestName()); - auto module = CreateNewModule(); + auto m = CreateNewVerifiedModule(); Shape shape = ShapeUtil::MakeShape(F32, {128, 8}); Shape tshape = ShapeUtil::MakeTupleShape({shape, shape}); Shape result_tshape = ShapeUtil::MakeTupleShape({shape}); @@ -763,7 +781,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { true_builder.AddInstruction(HloInstruction::CreateTuple({add})); } HloComputation* true_computation = - module->AddEmbeddedComputation(true_builder.Build()); + m->AddEmbeddedComputation(true_builder.Build()); auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch"); { @@ -779,14 +797,14 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data})); } HloComputation* false_computation = - module->AddEmbeddedComputation(false_builder.Build()); + m->AddEmbeddedComputation(false_builder.Build()); builder.AddInstruction(HloInstruction::CreateConditional( result_tshape, pred, tuple, true_computation, tuple, 
false_computation)); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + HloComputation* computation = m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module, &computation_layout); + AssignLayouts(m.get(), &computation_layout); const HloInstruction* true_root = true_computation->root_instruction(); const HloInstruction* false_root = false_computation->root_instruction(); @@ -807,13 +825,13 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); builder.AddInstruction(HloInstruction::CreateUnary( constant0->shape(), HloOpcode::kBitcast, constant0)); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); LayoutAssignment layout_assignment(&computation_layout); - Status error_status = layout_assignment.Run(module).status(); + Status error_status = layout_assignment.Run(m.get()).status(); EXPECT_FALSE(error_status.ok()); EXPECT_THAT( error_status.error_message(), @@ -840,9 +858,10 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module().entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); TF_ASSERT_OK( @@ -852,12 +871,12 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { LayoutUtil::MakeLayout({1, 0})); ChannelLayoutConstraints channel_constraints; - AssignLayouts(&module(), &computation_layout, &channel_constraints); + 
AssignLayouts(m.get(), &computation_layout, &channel_constraints); - EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "root"), ElementsAre(1, 0)); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}), + ShapeUtil::GetSubshape(FindInstruction(m.get(), "send")->shape(), {0}), ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); } @@ -876,17 +895,17 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { param = (f32[2,2]) parameter(0) gte = f32[2,2] get-tuple-element(param), index=0 ar.0 = f32[2,2] cross-replica-sum(gte), - all_reduce_id=0, replica_groups={{0}}, to_apply=add, + all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=0} const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}}) ROOT ar.1 = f32[2,2] cross-replica-sum(const), - all_reduce_id=0, replica_groups={{0}}, to_apply=add, + all_reduce_id=1, replica_groups={{0}}, to_apply=add, sharding={maximal device=1} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); TF_ASSERT_OK( @@ -896,12 +915,12 @@ TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { LayoutUtil::MakeLayout({1, 0})); ChannelLayoutConstraints channel_constraints; - AssignLayouts(module.get(), &computation_layout, &channel_constraints); + AssignLayouts(m.get(), &computation_layout, &channel_constraints); - EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "ar.0"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), 
"ar.1"), ElementsAre(0, 1)); - const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "ar.0"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "ar.1"), ElementsAre(0, 1)); + const HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0)); } @@ -917,19 +936,22 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = compiled_module->entry_computation()->root_instruction(); Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); - EXPECT_THAT(root, op::Add(op::Parameter(), - op::Slice(AllOf(op::Copy(op::Parameter(1)), - op::ShapeWithLayout(shape_copy))))); + EXPECT_THAT( + root, + GmockMatch(m::Add( + m::Parameter(), + m::Slice(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy))))); } TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { @@ -945,21 +967,23 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = compiled_module->entry_computation()->root_instruction(); Shape shape_copy = 
ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); EXPECT_THAT(root, - op::Add(op::Parameter(), - op::DynamicSlice(AllOf(op::Copy(op::Parameter(1)), - op::ShapeWithLayout(shape_copy)), - op::Parameter(2)))); + GmockMatch(m::Add( + m::Parameter(), + m::DynamicSlice( + m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy), + m::Parameter(2))))); } TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { @@ -976,21 +1000,23 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = compiled_module->entry_computation()->root_instruction(); Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0}); - EXPECT_THAT(root, - op::Add(op::Parameter(), - op::Concatenate(AllOf(op::Copy(op::Parameter(1)), - op::ShapeWithLayout(shape_copy)), - op::Parameter(2)))); + EXPECT_THAT( + root, + GmockMatch(m::Add( + m::Parameter(), + m::Concatenate(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy), + m::Parameter(2))))); } TEST_F(LayoutAssignmentTest, @@ -1007,16 +1033,18 @@ TEST_F(LayoutAssignmentTest, } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = compiled_module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Convolution(op::Parameter(0), op::Parameter(1))); + EXPECT_THAT(root, + 
GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1)))); } TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { @@ -1029,18 +1057,20 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); auto compiled_module = backend() .compiler() - ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + ->RunHloPasses(m->Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); HloInstruction* root = compiled_module->entry_computation()->root_instruction(); Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1}); - EXPECT_THAT(root, op::Slice(AllOf(op::Copy(op::Parameter(0)), - op::ShapeWithLayout(shape_copy)))); + EXPECT_THAT(root, + GmockMatch(m::Slice( + m::Copy(m::Parameter(0)).WithShapeEqualTo(&shape_copy)))); } TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { @@ -1086,20 +1116,241 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { } )"; - ParseAndVerifyModule(module_str); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str)); ComputationLayout computation_layout( - module().entry_computation()->ComputeProgramShape()); + m->entry_computation()->ComputeProgramShape()); // Sanity check to verify that there's a layout mismatch. - EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0)); - AssignLayouts(&module(), &computation_layout); + AssignLayouts(m.get(), &computation_layout); // Make sure that layout assignment did not magically eliminate the mismatch, // in which case the test didn't prove anything. 
- EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); + EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0)); +} + +TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) { + const char* module_str = R"( +HloModule CustomCallNotLayoutConstrained + +ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] { + %p = f32[42,2,3] parameter(0) + ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz" +} +)"; + // Try with a couple different layouts. In each case the custom calls operand + // and result layout should match that of the computation. + { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = m->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1})); + AssignLayouts(m.get(), &computation_layout); + + HloInstruction* root = m->entry_computation()->root_instruction(); + ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter()))); + ExpectLayoutIs(root->shape(), {3, 2, 0, 1}); + ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1}); + } + { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = m->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1})); + AssignLayouts(m.get(), &computation_layout); + + 
HloInstruction* root = m->entry_computation()->root_instruction(); + ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter()))); + ExpectLayoutIs(root->shape(), {0, 2, 3, 1}); + ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2}); + } +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrained + +ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] { + %p0 = f32[4,4] parameter(0) + %p1 = f32[2,3] parameter(1) + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = m->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(m.get(), &computation_layout); + + // The custom call should be partially encapsulated in kCopy instructions + // because of the layout mismatches. 
+ ASSERT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Copy(m::CustomCall(m::Copy(), m::Parameter())))); + + const HloInstruction* custom_call = + m->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); + ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1}); + ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedZeroOperands + +ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] { + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = m->entry_computation_layout(); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(m.get(), &computation_layout); + + ASSERT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Copy(m::CustomCall()))); + + const HloInstruction* custom_call = + m->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedTupleOperand + +ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] { + %p0 = f32[4,4] parameter(0) + %p1 = f32[2,3] parameter(1) + %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1) + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str, 
GetModuleConfigForTest())); + ComputationLayout computation_layout = m->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(m.get(), &computation_layout); + + HloInstruction* root = m->entry_computation()->root_instruction(); + ExpectLayoutIs(root->shape(), {2, 1, 0, 3}); + + ASSERT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Copy(m::CustomCall(m::Tuple())))); + + const HloInstruction* custom_call = + m->entry_computation()->root_instruction()->operand(0); + ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1}); + ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}}); +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrainedTupleResult + +ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) { + %p0 = f32[4,4] parameter(0) + ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}} +} +)"; + // Try with a couple different layouts. In each case the custom calls operand + // and result layout should match that of the computation. 
+ TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr m, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = m->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})})); + AssignLayouts(m.get(), &computation_layout); + + ExpectTupleLayoutIs(m->result_shape(), {{1, 0}, {1, 0}}); + + const HloInstruction* custom_call = FindInstruction(m.get(), "custom-call"); + ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}}); +} + +Status AssignLayoutsToComputation( + HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) { + if (!m->entry_computation_layout().result_layout().LayoutIsSet()) { + m->mutable_entry_computation_layout() + ->mutable_result_layout() + ->SetToDefaultLayout(); + } + LayoutAssignment layout_assignment( + m->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout, channel_constraints); + return layout_assignment.Run(m).status(); +} + +TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) { + // Check that we handle a diamond-shaped graph correctly. 
+ // transpose + // / \ + // add | + // \ / + // tuple + + auto b = HloComputation::Builder(TestName()); + Shape ashape = ShapeUtil::MakeShape(F32, {12, 8}); + Shape bshape = ShapeUtil::MakeShape(F32, {8, 12}); + auto param0 = + b.AddInstruction(HloInstruction::CreateParameter(0, bshape, "input")); + auto param1 = + b.AddInstruction(HloInstruction::CreateParameter(1, ashape, "input")); + auto transpose = + b.AddInstruction(HloInstruction::CreateTranspose(ashape, param0, {1, 0})); + auto add = b.AddInstruction( + HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, transpose, param1)); + b.AddInstruction(HloInstruction::CreateTuple({add, transpose})); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(b.Build()); + Shape ashape_major = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {1, 0}); + Shape ashape_minor = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {0, 1}); + *m->mutable_entry_computation_layout()->mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeTupleShape({ashape_major, ashape_minor})); + const Layout r2_dim0major = LayoutUtil::MakeLayout({1, 0}); + ForceParameterLayout(m.get(), 0, r2_dim0major); + ForceParameterLayout(m.get(), 1, r2_dim0major); + TF_ASSERT_OK(AssignLayoutsToComputation(m.get())); + + EXPECT_THAT(add->shape().layout().minor_to_major(), ElementsAre(1, 0)); + EXPECT_THAT(add->operand(0)->shape().layout().minor_to_major(), + ElementsAre(1, 0)); + EXPECT_THAT(add->operand(1)->shape().layout().minor_to_major(), + ElementsAre(1, 0)); + + EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(0, 1)); } } // namespace diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc index b17c9d504501a907e27d5152e0082799e87443c7..382b575120277ffb0e63e693757591681a78479e 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.cc +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -21,8 +21,25 @@ limitations under the License. 
#endif namespace xla { +Status LLVMCompiler::RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, + absl::Span executors, + DeviceMemoryAllocator* device_allocator) { + return Unimplemented( + "Model partitioning not implemented for the CPU/GPU compilers!"); +} + +StatusOr>> +LLVMCompiler::RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) { + return Unimplemented( + "Model partitioning not implemented for the CPU/GPU compilers!"); +} + StatusOr>> LLVMCompiler::Compile( - std::vector> modules, + std::unique_ptr module_group, std::vector> stream_execs, DeviceMemoryAllocator* device_allocator) { // Tensorflow tries to enable the following behaviors in all its threads: @@ -38,6 +55,8 @@ StatusOr>> LLVMCompiler::Compile( tensorflow::port::ScopedDontFlushDenormal dont_flush_denormals; std::vector> result; + std::vector> modules = + module_group->ConsumeModules(); for (size_t i = 0; i < modules.size(); i++) { if (stream_execs[i].size() != 1) { return Unimplemented( diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index f1c623508c5307f2b1c036d3ec6823b75c7eda13..182d8edbe30da292f28aeab53be646ce6651839f 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -69,8 +69,18 @@ class LLVMCompiler : public Compiler { using Compiler::RunBackend; using Compiler::RunHloPasses; + Status RunHloPassesOnModuleGroup( + HloModuleGroup* module_group, + absl::Span executors, + DeviceMemoryAllocator* device_allocator) override; + + StatusOr>> RunBackendOnModuleGroup( + std::unique_ptr module_group, + std::vector> stream_exec, + DeviceMemoryAllocator* device_allocator) override; + StatusOr>> Compile( - std::vector> modules, + std::unique_ptr module_group, std::vector> stream_execs, DeviceMemoryAllocator* device_allocator) override; diff --git 
a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 6223a34b1258961944a3ac64cd10876d1272c94e..728a66b388f0f9af480ff88b5e96990a26e36af5 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -72,6 +72,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/core:lib", + "@com_google_absl//absl/base", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm//:core", @@ -168,6 +169,7 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm//:core", ], @@ -196,15 +198,17 @@ cc_library( hdrs = ["sort_util.h"], deps = [ ":ir_array", + ":kernel_support_library", ":llvm_loop", ":llvm_util", ":loop_emitter", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", "//tensorflow/compiler/xla/service/gpu:partition_assignment", "//tensorflow/core:lib", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", ], diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index cc2e862f2eb9a49099c5f90efe1b29fb77c8f106..4d7f36d9f8b565a819edf0631efc5c7a58c4f87f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -130,7 +130,8 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, // // Emits a sequential loop if launch_dimensions is null. 
static Status EmitFusedDynamicUpdateSliceInPlaceImpl( - HloInstruction* fusion, absl::Span fusion_operand_arrays, + HloInstruction* fusion, + GeneratorForOperandIrArrays operand_arrays_generator, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); @@ -160,7 +161,8 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape)); // Create element generators for update and start_indices. - FusedIrEmitter fused_emitter(fusion_operand_arrays, elemental_emitter); + FusedIrEmitter fused_emitter(std::move(operand_arrays_generator), + elemental_emitter); TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter)); ElementGenerator update_array_generator = fused_emitter.GetGenerator(update); ElementGenerator start_indices_generator = @@ -173,21 +175,24 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( } Status EmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, absl::Span fusion_operand_arrays, + HloInstruction* fusion, + GeneratorForOperandIrArrays operand_arrays_generator, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( - fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter, + fusion, std::move(operand_arrays_generator), fusion_output_array, + elemental_emitter, /*launch_dimensions=*/nullptr, b); } Status EmitParallelFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, absl::Span fusion_operand_arrays, + HloInstruction* fusion, + GeneratorForOperandIrArrays operand_arrays_generator, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( - fusion, fusion_operand_arrays, fusion_output_array, 
elemental_emitter, - &launch_dimensions, b); + fusion, std::move(operand_arrays_generator), fusion_output_array, + elemental_emitter, &launch_dimensions, b); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h index fb3e4eb97cae06f2a0c87dd7118b8332048df56e..7fe803d1f8da5251c99f0a8fd97f99e9ca031175 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -27,6 +27,9 @@ limitations under the License. namespace xla { namespace llvm_ir { +using GeneratorForOperandIrArrays = + std::function()>; + // Checks if we can emit code for the given DynamicUpdateSlice node that updates // its input in place. Returns true if the dynamic-update-slice's // array-to-be-updated and output share the same BufferAllocation::Slice. @@ -73,14 +76,16 @@ Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, // (sequential) code for a fusion node that does the dynamic-update-slice in // place. Status EmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, absl::Span fusion_operand_arrays, + HloInstruction* fusion, + GeneratorForOperandIrArrays operand_arrays_generator, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, llvm::IRBuilder<>* b); // Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with // the given launch dimensions. 
Status EmitParallelFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, absl::Span fusion_operand_arrays, + HloInstruction* fusion, + GeneratorForOperandIrArrays operand_arrays_generator, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b); diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index b606c993a2d58a6d177af10de7b214de130c2279..38f2b5da23a7b92e4547dceaba011ce654977da3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -33,7 +33,7 @@ namespace xla { using llvm_ir::IrArray; Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { - generators_[hlo] = + indexed_generators_[hlo] = [=](const IrArray::Index& index) -> StatusOr { if (generated_value_cache_[hlo].count(index.multidim()) > 0) { llvm::Value* generated_value = @@ -63,25 +63,26 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { << llvm_ir::AsString(b_->GetInsertBlock()->getName()) << ")."; } - TF_ASSIGN_OR_RETURN( - generated_value_cache_[hlo][index.multidim()], - elemental_emitter_->MakeElementGenerator(hlo, generators_)(index)); + TF_ASSIGN_OR_RETURN(generated_value_cache_[hlo][index.multidim()], + elemental_emitter_->MakeElementGenerator( + hlo, indexed_generators_)(index)); return generated_value_cache_[hlo][index.multidim()]; }; return Status::OK(); } Status FusedIrEmitter::HandleConstant(HloInstruction* constant) { - const Literal& literal = constant->literal(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); - llvm::GlobalVariable* global = new llvm::GlobalVariable( - *b_->GetInsertBlock()->getModule(), initializer->getType(), - /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, - /*Name=*/""); - llvm::Constant* shape_constant = 
llvm::ConstantExpr::getBitCast( - global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); - generators_[constant] = [=](const IrArray::Index& index) { + indexed_generators_[constant] = [=](const IrArray::Index& index) { + const Literal& literal = constant->literal(); + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, module_); + llvm::GlobalVariable* global = new llvm::GlobalVariable( + *b_->GetInsertBlock()->getModule(), initializer->getType(), + /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, + /*Name=*/""); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global, + llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); return IrArray(shape_constant, constant->shape()) .EmitReadArrayElement(index, b_); }; @@ -91,34 +92,47 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) { Status FusedIrEmitter::HandleGetTupleElement( HloInstruction* get_tuple_element) { - // Lookup ir value for 'operand'. - auto operand = get_tuple_element->operand(0); - auto it = gte_values_.find(operand); - if (it == gte_values_.end()) { - return Unimplemented( - "GetTupleElement fusion currently only supports" - " parameter operands, but found operand: %s", - operand->name()); - } - // Emit code to lookup tuple element pointer, and store it in 'gte_values_'. - llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement( - get_tuple_element->shape(), get_tuple_element->tuple_index(), - /*alignment=*/1, it->second, b_, module_); - gte_values_.insert(std::make_pair(get_tuple_element, tuple_element_ptr)); - // Emit code to read base tuple element array (if non-tuple shaped). 
+ auto emit_tuple_element_ptr = [=]() -> StatusOr { + const HloInstruction* tuple_operand = get_tuple_element->operand(0); + llvm::Value* tuple_ptr; + if (tuple_operand->opcode() == HloOpcode::kGetTupleElement) { + TF_ASSIGN_OR_RETURN(tuple_ptr, non_indexed_generators_[tuple_operand]()); + } else { + if (tuple_operand->opcode() != HloOpcode::kParameter) { + return Unimplemented( + "GetTupleElement fusion currently only supports parameter or " + "nested" + "GetTupleElement as tuple operand, found an exception: %s", + tuple_operand->name()); + } + tuple_ptr = + GetBasePointerForFusedParameter(tuple_operand->parameter_number()); + } + + // Lookup tuple element pointer. + return llvm_ir::EmitGetTupleElement( + get_tuple_element->shape(), get_tuple_element->tuple_index(), + /*alignment=*/1, tuple_ptr, b_, module_); + }; + if (!ShapeUtil::IsTuple(get_tuple_element->shape())) { - generators_[get_tuple_element] = + indexed_generators_[get_tuple_element] = [=](const IrArray::Index& index) -> StatusOr { // TODO(b/34080002) Add aliasing information to tuple element IrArray. 
+ TF_ASSIGN_OR_RETURN(llvm::Value * tuple_element_ptr, + emit_tuple_element_ptr()); return IrArray(tuple_element_ptr, get_tuple_element->shape()) .EmitReadArrayElement(index, b_); }; + } else { + non_indexed_generators_[get_tuple_element] = emit_tuple_element_ptr; } return Status::OK(); } Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { - generators_[parameter] = [=](const IrArray::Index& index) -> llvm::Value* { + indexed_generators_[parameter] = + [=](const IrArray::Index& index) -> llvm::Value* { if (tiled_parameter_info_) { if (llvm::Value* param_tile_buffer = tiled_parameter_info_->GetBufferForParameter( @@ -135,14 +149,9 @@ Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { "tiled_buffer"); } } - return parameter_arrays_[parameter->parameter_number()] + return GetIrArrayForFusedParameter(parameter->parameter_number()) .EmitReadArrayElement(index, b_); }; - // Store ir value for fusion operand associated with fusion parameter to be - // accessed by subsequent fused GetTupleElement instructions. 
- gte_values_.insert(std::make_pair( - parameter, - parameter_arrays_[parameter->parameter_number()].GetBasePointer())); return Status::OK(); } @@ -153,12 +162,13 @@ Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) { operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType( operand->shape().element_type(), module_)); } - generators_[tuple] = + indexed_generators_[tuple] = [=](const IrArray::Index& index) -> StatusOr { llvm::Value* ret = llvm::UndefValue::get( llvm::StructType::get(b_->getContext(), operand_elemental_ir_types)); for (size_t i = 0; i < ShapeUtil::TupleElementCount(tuple->shape()); ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value * val_i, generators_[operands[i]](index)); + TF_ASSIGN_OR_RETURN(llvm::Value * val_i, + indexed_generators_[operands[i]](index)); ret = b_->CreateInsertValue(ret, val_i, i); } return ret; @@ -171,15 +181,15 @@ Status FusedIrEmitter::FinishVisit(HloInstruction* root) { return Status::OK(); } -FusedIrEmitter::Generator FusedIrEmitter::GetRootGenerator() const { +FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetRootGenerator() const { CHECK_NE(nullptr, fused_root_) << "GetRootGenerator should be called after Accept."; - return generators_.at(fused_root_); + return indexed_generators_.at(fused_root_); } -FusedIrEmitter::Generator FusedIrEmitter::GetGenerator( +FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetGenerator( const HloInstruction* instruction) const { - return generators_.at(instruction); + return indexed_generators_.at(instruction); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 44d21fa750a532633f46614002d59c90fc0b5d40..1b9c61f6700e2a1309b21e499f4a9e2439ed3702 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/types/optional.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" @@ -52,11 +53,15 @@ namespace xla { // same length. class FusedIrEmitter : public DfsHloVisitorWithDefault { public: - using Generator = llvm_ir::ElementGenerator; + using IndexedGenerator = llvm_ir::ElementGenerator; + using NonIndexedGenerator = std::function()>; + using GeneratorForOperandIrArrays = + std::function()>; - FusedIrEmitter(absl::Span parameter_arrays, + FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator, ElementalIrEmitter* elemental_emitter) - : parameter_arrays_(parameter_arrays), + : operand_arrays_(), + operand_arrays_generator_(std::move(operand_arrays_generator)), tiled_parameter_info_(nullptr), elemental_emitter_(elemental_emitter), b_(elemental_emitter->b()), @@ -76,25 +81,34 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { Status FinishVisit(HloInstruction* root) override; // Returns the generator function for the root of the fused computation. - Generator GetRootGenerator() const; + IndexedGenerator GetRootGenerator() const; // Returns the generator function for the given instruction. - Generator GetGenerator(const HloInstruction* instruction) const; - - // Returns the ir value for instruction 'hlo'. - llvm::Value* GetIrValueForGTE(const HloInstruction* hlo) const { - auto it = gte_values_.find(hlo); - CHECK(it != gte_values_.end()); - return it->second; - } + IndexedGenerator GetGenerator(const HloInstruction* instruction) const; void SetTiledParameterInfo(const llvm_ir::TiledParameterInfo* info) { tiled_parameter_info_ = info; } + protected: + // Returns the IrArrays for the fusion instruction operands. 
+ llvm_ir::IrArray& GetIrArrayForFusedParameter(int64 parameter_number) { + if (!operand_arrays_.has_value()) { + operand_arrays_ = operand_arrays_generator_(); + } + return operand_arrays_.value()[parameter_number]; + } + + llvm::Value* GetBasePointerForFusedParameter(int64 parameter_number) { + return GetIrArrayForFusedParameter(parameter_number).GetBasePointer(); + } + private: - // Arrays of parameters of fusion instruction - absl::Span parameter_arrays_; + // IrArrays for the fusion instruction operands, whose base addresses are the + // base address of the corresponding parameters in the fused computation. + absl::optional> operand_arrays_; + GeneratorForOperandIrArrays operand_arrays_generator_; + const llvm_ir::TiledParameterInfo* tiled_parameter_info_; ElementalIrEmitter* elemental_emitter_; @@ -106,19 +120,23 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { llvm::IRBuilder<>* b_; llvm::Module* module_; - // Map from instruction pointers to functions to generate elements of their - // outputs - std::unordered_map generators_; + // Map from instructions to functions that generate code for the output + // elements. If an instruction is a GetTupleElement instruction, the + // instruction produces non-tuple result. + std::unordered_map + indexed_generators_; + + // Map from tuple-result-producing GetTupleELement instructions to functions + // that generate the base pointers for the output elements. This is used to + // support the translation of nested GetTupleElement instructions. + std::unordered_map + non_indexed_generators_; // Cache of generated values, lest we regenerate an element of a node with // multiple outgoing edges std::unordered_map, llvm::Value*>> generated_value_cache_; - - // Stores ir values required to emit fused (and possibly nested) - // GetTupleElement instructions. 
- std::unordered_map gte_values_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index f4b05f29c38529b3cce81b4c8ee6fae5c00cafcc..d6d84994ee147f4b8c1a333b0eaccdf6e0a2219b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" @@ -108,6 +109,14 @@ class IrArray { Index(absl::Span multidim, llvm::Value* linear, const Shape& shape); + // Returns an index that adds `addend` to the given `dim` of the object. + Index AddOffsetToDim(llvm::Value* addend, int64 dim, + llvm::IRBuilder<>* b) const { + IrArray::Index index = *this; + index[dim] = b->CreateAdd(index[dim], addend); + return index; + } + const std::vector& multidim() const { return multidim_; } llvm::Value* linear() const { return linear_; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index e5fbdbd51b8a9aa14decadedd1eeb3bdbf831738..c26711e526c9b89cdedcb6aed9f93d41dd25dc83 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -52,6 +52,29 @@ Shape MergeDimensions(absl::Span segs, const Shape& shape) { return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), dimensions); } + +// Given an index for a shape, return the equivalent new index if the shape is +// reshaped to another shape. 
+IrArray::Index GetReshapedIndex(const IrArray::Index& index, const Shape& shape, + const Shape& reshaped_shape, + llvm::IRBuilder<>* b) { + auto bounds = shape.dimensions(); + auto minor_to_major = shape.layout().minor_to_major(); + llvm::Value* linear_index = index.GetConstantWithIndexType(0); + int64 multiplier = 1; + for (int i = 0; i < index.size(); ++i) { + int64 dim = minor_to_major[i]; + llvm::Value* addend = b->CreateMul( + index[dim], index.GetConstantWithIndexType(multiplier), "linearizing", + /*HasNUW=*/true, /*HasNSW=*/true); + linear_index = b->CreateAdd(linear_index, addend, "", + /*HasNUW=*/true, /*HasNSW=*/true); + multiplier *= bounds[dim]; + } + + return IrArray::Index(linear_index, reshaped_shape, b); +} + } // namespace absl::optional > FindTranspose021(const Shape& a, @@ -60,28 +83,30 @@ absl::optional > FindTranspose021(const Shape& a, return absl::nullopt; } - std::vector perm(a.dimensions().size()); - { - auto layout_a_orig = LayoutUtil::MinorToMajor(a); - std::vector layout_a(layout_a_orig.rbegin(), layout_a_orig.rend()); - auto layout_b_orig = LayoutUtil::MinorToMajor(b); - std::vector layout_b(layout_b_orig.rbegin(), layout_b_orig.rend()); - for (size_t i = 0; i < perm.size(); ++i) { - perm[i] = PositionInContainer(layout_b, layout_a[i]); - } + std::vector permutation(a.dimensions().size()); + absl::Span minor_to_major_a = LayoutUtil::MinorToMajor(a); + std::vector major_to_minor_a(minor_to_major_a.rbegin(), + minor_to_major_a.rend()); + absl::Span minor_to_major_b = LayoutUtil::MinorToMajor(b); + std::vector major_to_minor_b(minor_to_major_b.rbegin(), + minor_to_major_b.rend()); + for (size_t i = 0; i < permutation.size(); ++i) { + permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]); } - auto segs = ConsecutiveSegments(perm); - if ((3 == segs.size() && 0 == perm[0]) || 2 == segs.size()) { - Shape norm_a = + + std::vector segments = ConsecutiveSegments(permutation); + if ((3 == segments.size() && 0 == 
permutation[0]) || 2 == segments.size()) { + Shape descending_layout_shape = ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a); - Shape reduced_a = MergeDimensions(segs, norm_a); - auto reduced_a_dims = reduced_a.dimensions(); + Shape normalized_shape = MergeDimensions(segments, descending_layout_shape); + absl::Span normalized_dims = + AsInt64Slice(normalized_shape.dimensions()); std::vector dims_021; - if (2 == segs.size()) { + if (2 == segments.size()) { // The logical component-0 is of size one. - dims_021 = {1, reduced_a_dims[1], reduced_a_dims[0]}; + dims_021 = {1, normalized_dims[1], normalized_dims[0]}; } else { - dims_021 = {reduced_a_dims[0], reduced_a_dims[2], reduced_a_dims[1]}; + dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]}; } return dims_021; @@ -90,27 +115,117 @@ absl::optional > FindTranspose021(const Shape& a, return absl::nullopt; } -IrArray::Index GetUnreducedOutputIndex( - const IrArray::Index& reduced_output_index, - const Shape& reduced_output_shape, const Shape& unreduced_output_shape, - llvm::IRBuilder<>* b) { - auto bounds = reduced_output_shape.dimensions(); - auto minor_to_major = reduced_output_shape.layout().minor_to_major(); - llvm::Value* linear_index = reduced_output_index.GetConstantWithIndexType(0); - int64 multiplier = 1; - for (int i = 0; i < reduced_output_index.size(); ++i) { - int64 dim = minor_to_major[i]; - llvm::Value* addend = - b->CreateMul(reduced_output_index[dim], - reduced_output_index.GetConstantWithIndexType(multiplier), - "linearizing", - /*HasNUW=*/true, /*HasNSW=*/true); - linear_index = b->CreateAdd(linear_index, addend, "", - /*HasNUW=*/true, /*HasNSW=*/true); - multiplier *= bounds[dim]; +KernelMappingScheme::KernelMappingScheme( + absl::Span dims_in_elems, int64 tile_size_y, int64 tile_size_x, + absl::Span req_block_sizes, int64 num_threads_y, + int64 num_threads_x, llvm::IRBuilder<>* b) + : b_(b), + dims_in_elems_(dims_in_elems), + tile_sizes_{1, tile_size_y, 
tile_size_x}, + num_threads_x_(num_threads_x), + num_threads_y_(num_threads_y) { + DCHECK_EQ(dims_in_elems_.size(), 3); + DCHECK_EQ(req_block_sizes.size(), 3); + + DCHECK_EQ(tile_size_y % num_threads_y_, 0); + DCHECK_EQ(tile_size_x % num_threads_x_, 0); + + dims_in_tiles_ = ElementWiseCeilOfRatio(dims_in_elems_, tile_sizes_); + block_sizes_.reserve(req_block_sizes.size()); + absl::c_transform(req_block_sizes, dims_in_tiles_, + std::back_inserter(block_sizes_), + [](const int64 requested_size, const int64 max_size) { + return std::min(requested_size, max_size); + }); + dims_in_blocks_ = ElementWiseCeilOfRatio(dims_in_tiles_, block_sizes_); + + VLOG(10) << "dims_in_elems_ = [" << absl::StrJoin(dims_in_elems_, ",") << "]"; + VLOG(10) << "dims_in_tiles_ = [" << absl::StrJoin(dims_in_tiles_, ",") << "]"; + VLOG(10) << "dims_in_blocks_ = [" << absl::StrJoin(dims_in_blocks_, ",") + << "]"; +} + +IrArray::Index KernelMappingScheme::GetUnnormalizedIndex( + const IrArray::Index& normalized_shape_index, + const Shape& unnormalized_shape) { + DCHECK_EQ(normalized_shape_index.size(), dims_in_elems_.size()); + Shape output_shape = ShapeUtil::MakeShapeWithDescendingLayout( + unnormalized_shape.element_type(), GetDimensionsInElements()); + return GetReshapedIndex(normalized_shape_index, output_shape, + unnormalized_shape, b_); +} + +IrArray::Index KernelMappingScheme::EmitBlockIndex(llvm::Type* index_ty) { + llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_); + llvm_ir::AddRangeMetadata(0, GetNumberOfBlocks(), + llvm::cast(block_id)); + llvm::Value* linear_block_id = + b_->CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x"); + return IrArray::Index(linear_block_id, + ShapeUtil::MakeShapeWithDescendingLayout( + PRED /*arbitrary*/, dims_in_blocks_), + b_); +} + +IrArray::Index KernelMappingScheme::GetTileIndexForBlockOrigin( + const IrArray::Index& block_index) { + IrArray::Index tile_index = block_index; + 
for (int i = 0; i < block_sizes_.size(); ++i) { + tile_index[i] = b_->CreateMul( + block_index[i], + llvm::ConstantInt::get(block_index[i]->getType(), block_sizes_[i]), + "block_origin." + std::to_string(i)); + } + return tile_index; +} + +IrArray::Index KernelMappingScheme::GetElementIndexForTileOrigin( + const IrArray::Index& tile_index) { + IrArray::Index elem_index = tile_index; + for (int i = DimY; i < DimTot; ++i) { + elem_index[i] = + b_->CreateMul(tile_index[i], + llvm::ConstantInt::get(tile_index[i]->getType(), + GetTileSizeForDimension(i)), + "tile_origin." + std::to_string(i)); } + return elem_index; +} + +llvm::GlobalVariable* KernelMappingScheme::GetSharedMemoryBufferForElementType( + llvm::Type* elem_ty, absl::string_view buffer_name) { + // If shared memory transpose is needed, we use square tiles. + CHECK_EQ(GetTileSizeForDimensionX(), GetTileSizeForDimensionY()); + + // For Nvidia GPUs, the warp size is 32 threads and the shared memory is + // organized into 32 banks. We usually use the warp size or a multiple of + // the warp size as the size for tiling. This may cause all elements in the + // same column of a tile to use the same memory bank and therefore shared memory + // bank conflicts. Adding 1 to the minor dimension of the shared memory buffer + // can reduce such shared memory bank conflicts. + llvm::Type* buffer_type = llvm::ArrayType::get( + llvm::ArrayType::get(elem_ty, GetTileSizeForDimension(DimX) + 1), + GetTileSizeForDimension(DimY)); + return llvm_ir::AllocateSharedMemoryTile(b_->GetInsertBlock()->getModule(), + buffer_type, buffer_name); +} - return IrArray::Index(linear_index, unreduced_output_shape, b); +std::tuple +KernelMappingScheme::EmitThreadYXCoordinate(llvm::Type* index_ty) { + // Calculate (y, x) coordinate of the thread in the 2D view of thread block + // defined by (num_thread_y, num_thread_x) from thread_id. 
+ llvm::CallInst* thread_id_raw = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_); + llvm_ir::AddRangeMetadata(0, GetThreadsPerTile(), thread_id_raw); + llvm::Value* thread_id_int = + b_->CreateIntCast(thread_id_raw, index_ty, + /*isSigned=*/true, "thread.id.x"); + llvm::Value* num_thread_x = + llvm::ConstantInt::get(index_ty, GetNumberOfThreadsForDimensionX()); + llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x); + llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x); + return std::make_tuple(y, x); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index 5ea05b3188a1c0881e4c0c41625d530aff1b1205..06002d57b0d7daa07f903feebe67a60a083c0e7c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -28,23 +28,160 @@ namespace llvm_ir { // If a shape can be viewed as three logical components 0-1-2 in the order of // major to minor, a 0-2-1-transpose changes the order of such logical // components to 0-2-1. We call the shape being transposed the input shape and -// the transposed shape the output shape. The logical view of the input and -// output shapes for the transpose are called the 0-1-2 shape or reduced input -// shape and the 0-2-1 shape or the reduced output shape respectively. The -// original input and output shapes are called the unreduced input and output -// shapes. - +// the transposed shape the output shape. The logical view of the input/output +// shapes for the transpose are called the 0-1-2/0-2-1 shapes or the normalized +// shapes. The original input/output shapes are called unnormalized shapes. +// // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the -// reduced shape of `b` or the 0-2-1 shape. +// normalized shape of `b` or the 0-2-1 shape. 
absl::optional > FindTranspose021(const Shape& a, const Shape& b); -// Return the unreduced output index corresponding to the given reduced output -// index. -IrArray::Index GetUnreducedOutputIndex( - const IrArray::Index& reduced_output_index, - const Shape& reduced_output_shape, const Shape& unreduced_output_shape, - llvm::IRBuilder<>* b); +// A tile is a spatial subdivision of a tensor. We group tensor elements into +// tiles so that we can launch kernels to process the tensor elements in blocks +// of tiles. +// +// A kernel mapping scheme describes a method to partition the tensors accessed +// by an unnested HLO instruction into tiles and blocks of tiles, and the +// associated information to use hardware threads to process the tensor elements +// in blocks of tiles. +// +// Currently, there are two main use cases for a tiling scheme. First, we +// implement kernels with 0-2-1 memory transpose using shared memory to improve +// memory access pattern. Second, we implement reduction to contiguous +// dimensions in layout, with or without memory transpose, to achieve better +// memory access pattern as well as to reduce the number of executed +// expensive instructions, such as thread synchronization related instructions +// and atomic operations. For both use cases, we can apply a normalization to +// the original tensors, to collapse contiguous dimensions for the same purpose +// and produce normalized three dimensional tensors. For this reason, the tiling +// scheme class only needs to handle normalized three dimensional tensors and +// two dimensional tiles. +// +// The current implementation of the class is somewhat NVIDIA GPU oriented. This +// situation can be improved when there is a need though. The idea of 0-2-1 +// transpose using shared memory can be found in the following CUDA algorithm in +// TensorFlow: https://goo.gl/MStRV6. 
+// +// We use a thread block to process a tile because we want to use the HW thread +// block synchronization primitives to synchronize the processing of all the +// elements in the same tile. A thread block can be viewed as a two dimensional +// array of threads, described by the number of threads for the Y and X +// dimensions. A thread block (num_threads_y, num_threads_x) processes a tile of +// (tile_size_y, tile_size_x) as follows: each thread in the thread block +// processes one element in the tile so that all the threads in the thread block +// together process a subdivision of the tile that has the same dimension as the +// thread block array. Then the thread block moves on to process the next +// subdivision of the tile until the whole tile is processed. Therefore, each +// thread in the thread block processes +// tile_size_x/num_threads_x * tile_size_y/num_threads_y elements in a tile. +// +// There are situations where we want a thread block to process multiple +// tiles. We can't group those tiles into a bigger tile because we limit a tile +// to a two dimensional spatial subdivision of a tensor. For example, when we +// use tiling to implement reduction with transpose, we want the partial sum +// produced by each thread to accumulate values for more elements before using +// shfl_down and atomic_add instructions for further reduction, to amortize the +// cost of such expensive instructions. The concept of tile block is introduced +// for this purpose. A tile block is a three dimensional array of tiles, of +// which some dimensions may be degenerated to only one tile. +class KernelMappingScheme { + public: + enum { DimZ = 0, DimY, DimX, DimTot }; + + public: + // dims_in_elems: the normalized tensor dimensions. + // req_block_sizes: the requested block size in number of tiles for each + // dimension. The actual block size is set to min(req_block_size, + // dims_in_number_of_tiles). 
+ explicit KernelMappingScheme(absl::Span dims_in_elems, + int64 tile_size_y, int64 tile_size_x, + absl::Span req_block_sizes, + int64 num_threads_y, int64 num_threads_x, + llvm::IRBuilder<>* b); + + absl::Span GetDimensionsInElements() const { + return dims_in_elems_; + } + absl::Span GetDimensionsInTiles() const { + return dims_in_tiles_; + } + absl::Span GetDimensionsInBlocks() const { + return dims_in_blocks_; + } + + int64 GetNumberOfTilesInTotal() const { + return absl::c_accumulate(dims_in_tiles_, 1LL, std::multiplies()); + } + int64 GetNumberOfTilesInOneBlock() const { + return absl::c_accumulate(block_sizes_, 1, std::multiplies()); + } + + int64 GetNumberOfBlocks() const { + return absl::c_accumulate(dims_in_blocks_, 1, std::multiplies()); + } + + int64 GetTileSizeForDimension(int d) const { + DCHECK(d >= DimZ && d <= DimX); + return tile_sizes_[d]; + } + int64 GetTileSizeForDimensionX() const { + return GetTileSizeForDimension(DimX); + } + int64 GetTileSizeForDimensionY() const { + return GetTileSizeForDimension(DimY); + } + + absl::Span GetBlockSizes() const { return block_sizes_; } + + int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; } + int64 GetNumberOfThreadsForDimensionY() const { return num_threads_y_; } + + int64 GetThreadsPerTile() const { + return GetNumberOfThreadsForDimensionX() * + GetNumberOfThreadsForDimensionY(); + } + + IrArray::Index EmitBlockIndex(llvm::Type* index_ty); + // Returns the index for the first tile in the block with the given block + // index. + IrArray::Index GetTileIndexForBlockOrigin(const IrArray::Index& block_index); + // Returns the index for the first element in the tile with the given tile + // index. 
+ IrArray::Index GetElementIndexForTileOrigin(const IrArray::Index& tile_index); + + std::tuple EmitThreadYXCoordinate( + llvm::Type* index_ty); + + IrArray::Index GetUnnormalizedIndex( + const IrArray::Index& normalized_shape_index, + const Shape& unnormalized_shape); + + llvm::GlobalVariable* GetSharedMemoryBufferForElementType( + llvm::Type* elem_ty, absl::string_view buffer_name); + + private: + llvm::IRBuilder<>* b_; + // The number of elements in each dimension. + absl::Span dims_in_elems_; + + // The number of elements for each dimension of a tile. + std::vector tile_sizes_; + // The number of tiles in each dimension. It is computed from dims_in_elems_ + // and tile_sizes_. + std::vector dims_in_tiles_; + + // The number of tiles for each dimension of a tile block. + std::vector block_sizes_; + // The number of tile blocks in each dimension. It is computed from + // dims_in_tiles_ and block_sizes_. + std::vector dims_in_blocks_; + + // Number of threads used to process elements in the X direction of a tile. + int64 num_threads_x_; + // Number of threads used to process elements in the Y direction of a tile. + int64 num_threads_y_; +}; // A class to represent information for tiled parameters to support IR emission // for 021 transpose. diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 1a53c026be340ca3bec3a49b11666d6124728130..ceea24685af566e02340664f0a40c398c62b5ab0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -19,10 +19,12 @@ limitations under the License. 
#include #include +#include "absl/base/casts.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Operator.h" #include "llvm/Target/TargetOptions.h" @@ -33,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/byte_order.h" @@ -83,10 +84,9 @@ string DumpModuleToString(const llvm::Module& module) { return AsString(buffer_string); } -llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, - absl::Span operands, - absl::Span overloaded_types, - llvm::IRBuilder<>* b) { +llvm::CallInst* EmitCallToIntrinsic( + llvm::Intrinsic::ID intrinsic_id, absl::Span operands, + absl::Span overloaded_types, llvm::IRBuilder<>* b) { llvm::Module* module = ModuleFromIRBuilder(b); llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( module, intrinsic_id, AsArrayRef(overloaded_types)); @@ -244,10 +244,11 @@ StatusOr EncodeSelfDescribingShapeConstant(const Shape& shape, StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, int32 size_bytes) { - Shape shape; - TF_RET_CHECK(shape.ParseFromArray(shape_ptr, size_bytes)); + ShapeProto shape_proto; + TF_RET_CHECK(shape_proto.ParseFromArray(shape_ptr, size_bytes)); + Shape shape(shape_proto); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); - return shape; + return std::move(shape); } llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, @@ -260,6 +261,17 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, /*AddNull=*/false); } +llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module, + llvm::Type* tile_type, + 
absl::string_view name) { + const int kNVPTXSharedMemoryAddrSpace = 3; + return new llvm::GlobalVariable( + *module, tile_type, + /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, + llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr, + llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace); +} + llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, absl::string_view name, llvm::IRBuilder<>* b, @@ -362,11 +374,10 @@ static void LogS64(const char* tag, int64 value) { void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) { llvm::FunctionType* log_function_type = llvm::FunctionType::get( b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false); - b->CreateCall( - log_function_type, - b->CreateIntToPtr(b->getInt64(tensorflow::bit_cast(&LogS64)), - log_function_type->getPointerTo()), - {b->getInt64(tensorflow::bit_cast(tag)), value}); + b->CreateCall(log_function_type, + b->CreateIntToPtr(b->getInt64(absl::bit_cast(&LogS64)), + log_function_type->getPointerTo()), + {b->getInt64(absl::bit_cast(tag)), value}); } void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index f59baff263fe7184c6b0821c9dbd9eee205586a6..c604c7c870adf734a29017e6accbd159317a9548 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -101,10 +102,9 @@ string SanitizeFunctionName(string function_name); // intrinsics (for example, "minnum") must include a type in overloaded_types // for each overloaded type. 
Typically, overloaded intrinsics have only a single // overloaded type. -llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, - absl::Span operands, - absl::Span overloaded_types, - llvm::IRBuilder<>* b); +llvm::CallInst* EmitCallToIntrinsic( + llvm::Intrinsic::ID intrinsic_id, absl::Span operands, + absl::Span overloaded_types, llvm::IRBuilder<>* b); // Emit float max. Emit maxnum intrinsic is fast math is disabled, or // fcmp+select otherwise @@ -155,6 +155,11 @@ StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module); +// Allocates a tile of shared memory. +llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module, + llvm::Type* tile_type, + absl::string_view name); + // Inserts an allocate of the requested type at the entry point of the // function that the builder is currently building. The insert point // of the builder is set to the same place after calling this function diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index 944c79580c133906cd431722fd6b29e6aee5f918..e22c2173c271fc9571be1ddb0759d2b31562dc98 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -15,9 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" +#include + // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/types/optional.h" +#include "absl/types/span.h" #include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -27,10 +30,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -38,148 +43,365 @@ namespace xla { namespace llvm_ir { namespace { -// Adds the inner comparison loop where we compare elements pointed to by -// 'keys_index' and 'compare_keys_index'. -void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, - const IrArray::Index& compare_keys_index, - const IrArray& keys_array, - const absl::optional& values_array, - llvm::IRBuilder<>* b) { - // if (is_smaller_index && - // compare_keys[dimension_to_sort] < dimension_to_sort_bound) - llvm::Value* is_smaller_index = b->CreateICmpSLT( - keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]); - int64 dimension_to_sort_bound = - keys_array.GetShape().dimensions(dimension_to_sort); - auto if_data = EmitIfThenElse( - b->CreateAnd(is_smaller_index, - b->CreateICmpSLT(compare_keys_index[dimension_to_sort], - keys_index.GetConstantWithIndexType( - dimension_to_sort_bound))), - "smaller_comparison_index", b, /*emit_else=*/false); - SetToFirstInsertPoint(if_data.true_block, b); - auto key1 = keys_array.EmitReadArrayElement(keys_index, b); - auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, b); - auto compare_key1 = key1; - auto compare_key2 = key2; - auto key_type = keys_array.GetShape().element_type(); - bool is_signed_comparison = true; - if (primitive_util::IsFloatingPointType(key_type)) { - 
// We would like a total order of floating point numbers so that the sort - // has a predictable behavior in the presence of NaNs. Rather than using - // floating point comparison, we use the following trick: - // If f is a float, and - // x = bit_cast(f); - // y = x < 0 ? 0x7FFFFFFF - x : x; - // then y is ordered as an int32 such that finite values have the obvious - // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning - // and end of the ordering. - auto k = b->getInt(llvm::APInt::getSignedMaxValue( - key1->getType()->getPrimitiveSizeInBits())); - auto comparison_type = k->getType(); - auto zero = llvm::ConstantInt::get(comparison_type, 0); - auto maybe_flip = [&](llvm::Value* v) { - return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), - b->CreateSub(k, v), v); - }; - compare_key1 = b->CreateBitCast(key1, comparison_type); - compare_key2 = b->CreateBitCast(key2, comparison_type); - compare_key1 = maybe_flip(compare_key1); - compare_key2 = maybe_flip(compare_key2); - } else if (!primitive_util::IsSignedIntegralType(key_type)) { - is_signed_comparison = false; + +// Adds the inner comparison loop body where we compare elements. +void EmitCompareLoopBody( + int64 iteration_bound, PrimitiveType key_type, int64 num_values, + int64 iota_values_parameter_index, llvm::Value* element_pair_index, + int64 xor_mask, llvm::Type* index_type, + std::function read_element, + std::function + write_element, + llvm::IRBuilder<>* b, bool needs_bounds_checks = true) { + auto index_typed_constant = [&](int64 value) { + return llvm::ConstantInt::get(index_type, value); + }; + // The 'xor_mask' determines which elements are compared against each other. + // Index 'current_keys_index' will be compared with 'current_keys_index' xor + // 'xor_mask'. This means that we will always compare a block of consecutive + // elements against elements from the adjacent block of the same size. 
When + // 'xor_mask' is a power of 2, it immediately identifies the size of such a + // block. We can also have 'xor_mask' being 2^k - 1 (for some value of k). In + // that case, we essentially flip the last 'k' - 1 bits when computing the + // position of the element to compare to, so the block size is 2^(k - 1). + int64 block_size = xor_mask; + // Check if it is a value 2^k - 1. + if (xor_mask > 1 && (xor_mask & (xor_mask + 1)) == 0) { + block_size = (xor_mask + 1) / 2; + } + auto current_keys_index = element_pair_index; + if (block_size == 1) { + // If the block size is 1, we take every second element and compare it to + // the next one. + current_keys_index = + b->CreateMul(current_keys_index, index_typed_constant(2)); + } else if (block_size * 2 < iteration_bound) { + // current_keys_index iterates through the 'left' elements of the element + // pairs to be compared. We first need to compute the comparison block to + // which the element belongs. The block id of that block is index / + // block_size. + auto block_id = + b->CreateUDiv(current_keys_index, index_typed_constant(block_size)); + // The index of the 'left' element within its block is simply the remainder + // when dividing by 'block_size'. + auto index_within_block = + b->CreateURem(current_keys_index, index_typed_constant(block_size)); + // The first element of the 'left' block of elements that is compared + // against elements from the adjacent 'right' block of elements is + // 'block_id' * (2 * 'block_size'). 
+ auto first_element_in_block = + b->CreateMul(block_id, index_typed_constant(2 * block_size)); + current_keys_index = + b->CreateAdd(first_element_in_block, index_within_block); + } + auto compare_keys_index = + b->CreateXor(current_keys_index, index_typed_constant(xor_mask)); + // current_keys_index < compare_keys_index + llvm::Value* is_smaller_index = + b->CreateICmpSLT(current_keys_index, compare_keys_index); + // compare_keys_index < iteration_bound + llvm::Value* index_is_inbounds = b->CreateICmpSLT( + compare_keys_index, index_typed_constant(iteration_bound)); + llvm::Value* do_comparison = + needs_bounds_checks ? b->CreateAnd(is_smaller_index, index_is_inbounds) + : b->getInt1(true); + + // if (is_smaller_index && index_is_inbounds) + KernelSupportLibrary ksl(b); + ksl.IfReturnVoid("smaller_comparison_index", do_comparison, [&]() { + auto key1 = read_element(0, current_keys_index); + auto key2 = read_element(0, compare_keys_index); + auto compare_key1 = key1; + auto compare_key2 = key2; + bool is_signed_comparison = true; + if (primitive_util::IsFloatingPointType(key_type)) { + // We would like a total order of floating point numbers so that the + // sort has a predictable behavior in the presence of NaNs. Rather + // than using floating point comparison, we use the following trick: + // If f is a float, and + // x = bit_cast(f); + // y = x < 0 ? 0x7FFFFFFF - x : x; + // then y is ordered as an int32 such that finite values have the + // obvious order, -0 is ordered before 0, and -NaN and NaN appear at + // the beginning and end of the ordering. 
+ auto k = b->getInt(llvm::APInt::getSignedMaxValue( + key1->getType()->getPrimitiveSizeInBits())); + auto comparison_type = k->getType(); + auto zero = llvm::ConstantInt::get(comparison_type, 0); + auto maybe_flip = [&](llvm::Value* v) { + return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), + b->CreateSub(k, v), v); + }; + compare_key1 = b->CreateBitCast(key1, comparison_type); + compare_key2 = b->CreateBitCast(key2, comparison_type); + compare_key1 = maybe_flip(compare_key1); + compare_key2 = maybe_flip(compare_key2); + } else if (!primitive_util::IsSignedIntegralType(key_type)) { + is_signed_comparison = false; + } + // If key2 < key1 + auto is_smaller_than = + b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT + : llvm::ICmpInst::ICMP_ULT, + compare_key2, compare_key1); + if (iota_values_parameter_index >= 0) { + auto keys_equal = b->CreateICmpEQ(compare_key1, compare_key2); + auto key_index1 = + read_element(iota_values_parameter_index, current_keys_index); + auto key_index2 = + read_element(iota_values_parameter_index, compare_keys_index); + auto index_is_smaller_than = + b->CreateICmp(llvm::ICmpInst::ICMP_ULT, key_index2, key_index1); + is_smaller_than = b->CreateOr( + is_smaller_than, b->CreateAnd(keys_equal, index_is_smaller_than)); + } + ksl.IfReturnVoid("is_smaller_than", is_smaller_than, [&]() { + // Swap key1 with key2. + write_element(0, current_keys_index, key2); + write_element(0, compare_keys_index, key1); + for (int64 i = 1; i <= num_values; ++i) { + // Also swap the values. 
+ auto value1 = read_element(i, current_keys_index); + auto value2 = read_element(i, compare_keys_index); + write_element(i, current_keys_index, value2); + write_element(i, compare_keys_index, value1); + } + }); + }); +} + +void EmitTiledCompareLoop( + const IrArray::Index& tiled_keys_index, int64 dimension_to_sort, + int64 dimension_to_sort_bound, PrimitiveType keys_type, + absl::Span xor_masks, const std::vector& params, + const std::vector& param_shmem_buffers, + int64 iota_values_parameter_index, int64 tile_size, llvm::IRBuilder<>* b) { + KernelSupportLibrary ksl(b); + llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b); + llvm_ir::AddRangeMetadata(0, tile_size / 2, + llvm::cast(thread_id)); + thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(), + /*isSigned=*/true, "thread.id.x"); + + auto copy_loop_body = + [&](std::function + read_or_write) { + auto value_one = tiled_keys_index.GetConstantWithIndexType(1); + auto current_keys_index = + b->CreateShl(tiled_keys_index[dimension_to_sort], value_one); + // We want to copy two adjacent elements. We first check whether the + // first index position is within bounds. + ksl.IfReturnVoid( + "smaller_keys_index", + b->CreateICmpSLT(current_keys_index, + tiled_keys_index.GetConstantWithIndexType( + dimension_to_sort_bound)), + [&]() { + auto cache_index = b->CreateShl(thread_id, value_one); + read_or_write(cache_index, current_keys_index); + // Increment to go the next index position. + current_keys_index = b->CreateAdd(current_keys_index, value_one); + // Here we check whether the next index position is within bounds. 
+ ksl.IfReturnVoid( + "inner_smaller_keys_index", + b->CreateICmpSLT(current_keys_index, + tiled_keys_index.GetConstantWithIndexType( + dimension_to_sort_bound)), + [&]() { + cache_index = b->CreateAdd(cache_index, value_one); + read_or_write(cache_index, current_keys_index); + }); + }); + }; + + // Copy operand tiles from the operand buffers to shared memory. + IrArray::Index keys_index = tiled_keys_index; + for (int64 i = 0; i < params.size(); ++i) { + copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) { + keys_index[dimension_to_sort] = index; + auto value = params[i].EmitReadArrayElement(keys_index, b); + b->CreateStore(value, + b->CreateGEP(param_shmem_buffers[i], + {tiled_keys_index.GetConstantWithIndexType(0), + cache_index})); + }); + } + // Wait until all reads have happened. + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); + + // Now emit the bodies of the comparison loops. + auto read_element = [&](int64 operand, llvm::Value* index) { + return b->CreateLoad( + b->CreateGEP(param_shmem_buffers[operand], + {tiled_keys_index.GetConstantWithIndexType(0), index})); + }; + auto write_element = [&](int64 operand, llvm::Value* index, + llvm::Value* value) { + b->CreateStore( + value, + b->CreateGEP(param_shmem_buffers[operand], + {tiled_keys_index.GetConstantWithIndexType(0), index})); + }; + for (int64 xor_mask : xor_masks) { + // The index of the element pair to be compared within the tile stored in + // shared memory. We order the element pairs by the element with the smaller + // index. + auto element_pair_index = thread_id; + // If 'dimension_to_sort_bound' is evenly divisible by 'tile_size', we don't + // need any bounds checks. + if (dimension_to_sort_bound % tile_size) { + // Otherwise we need a bounds check for the last tile. The last tile has + // size 'dimension_to_sort_bound' % 'tile_size'. 
+ ksl.IfReturnVoid( + "is_last_tile", + b->CreateICmpUGE( + b->CreateMul(tiled_keys_index[dimension_to_sort], + tiled_keys_index.GetConstantWithIndexType(2)), + tiled_keys_index.GetConstantWithIndexType( + RoundDownToNearest(dimension_to_sort_bound, tile_size))), + [&]() { + EmitCompareLoopBody(dimension_to_sort_bound % tile_size, keys_type, + params.size() - 1, iota_values_parameter_index, + element_pair_index, xor_mask, + tiled_keys_index.GetType(), read_element, + write_element, b); + }, + [&]() { + EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, + iota_values_parameter_index, element_pair_index, + xor_mask, tiled_keys_index.GetType(), + read_element, write_element, b, + /*needs_bounds_checks=*/false); + }); + } else { + EmitCompareLoopBody(tile_size, keys_type, params.size() - 1, + iota_values_parameter_index, element_pair_index, + xor_mask, tiled_keys_index.GetType(), read_element, + write_element, b, /*needs_bounds_checks=*/false); + } + // Wait until all comparisons have happened. + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b); } - auto comparison = - b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT - : llvm::ICmpInst::ICMP_ULT, - compare_key2, compare_key1); - // If key2 < key1 - auto if_smaller_data = - EmitIfThenElse(comparison, "is_smaller_than", b, /*emit_else=*/false); - SetToFirstInsertPoint(if_smaller_data.true_block, b); - // Swap key1 with key2. - keys_array.EmitWriteArrayElement(keys_index, key2, b); - keys_array.EmitWriteArrayElement(compare_keys_index, key1, b); - if (values_array.has_value()) { - // Also swap the values. 
- auto value1 = values_array.value().EmitReadArrayElement(keys_index, b); - auto value2 = - values_array.value().EmitReadArrayElement(compare_keys_index, b); - values_array.value().EmitWriteArrayElement(keys_index, value2, b); - values_array.value().EmitWriteArrayElement(compare_keys_index, value1, b); + + // Copy the operand tiles back from shared memory to the operand buffers. + for (int64 i = 0; i < params.size(); ++i) { + copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) { + keys_index[dimension_to_sort] = index; + auto value = b->CreateLoad(b->CreateGEP( + param_shmem_buffers[i], + {tiled_keys_index.GetConstantWithIndexType(0), cache_index})); + params[i].EmitWriteArrayElement(keys_index, value, b); + }); } + // We should normally synchronize here to make sure all writes have happened. + // However the very next thing each thread does is reading 2 elements from the + // operand buffer and writing it into the same location in shared memory from + // which it previously copied it to the operand buffer, and we synchronize + // after this has happened. 
We can be sure that a thread always writes to the + // same location in shared memory because we have exactly tile_size / 2 many + // threads, and the linear index calculated by ParallelLoopEmitter uses + // linear_index = blockIdx.x * blockDim.x + threadIdx.x; } } // namespace Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const absl::optional& values_array, - absl::string_view name, llvm::Value* xor_mask, - llvm::IRBuilder<>* b, - const gpu::LaunchDimensions* launch_dimensions) { - const Shape& keys_shape = keys_array.GetShape(); + const std::vector& values_arrays, + int64 iota_values_parameter_index, + absl::string_view name, + absl::Span xor_masks, llvm::IRBuilder<>* b, + const gpu::LaunchDimensions& launch_dimensions, + int64 num_iterations_in_sort_dim, + const int64 tile_size) { + // Iterate through the keys shape in physical order, but skip the dimension to + // sort and make it the innermost loop which is the loop where the comparisons + // happen. In the dimension to sort, if we use tiling, we iterate through it + // in tiles of 64 elements each, so we use another loop that happens within + // one thread to process this tile worth of data (thereby combining several + // comparison stages of the bitonic sort algorithm because they all happen + // within those 64 elements and are therefore independent of the other + // comparisons). - // Create loop nests which loop through the operand dimensions. The sort - // dimension is handled in the innermost loop which performs the sorting. 
- ForLoopNest loop_nest(name, b); - IrArray::Index keys_index = - loop_nest.EmitOperandArrayLoopNest(keys_array, dimension_to_sort, "keys"); - if (loop_nest.GetInnerLoopBodyBasicBlock() != nullptr) { - SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), b); + const Shape& keys_shape = keys_array.GetShape(); + int64 rank = ShapeUtil::Rank(keys_shape); + int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); + std::vector dimensions_in_iteration_order(rank); + std::vector iteration_order_to_logical_order(rank); + int64 dim = 0; + for (int64 dimension : LayoutUtil::MinorToMajor(keys_shape)) { + if (dimension != dimension_to_sort) { + dimensions_in_iteration_order[dim] = keys_shape.dimensions(dimension); + iteration_order_to_logical_order[dim++] = dimension; + } } + dimensions_in_iteration_order[dim] = num_iterations_in_sort_dim; + iteration_order_to_logical_order[dim] = dimension_to_sort; - // 'compare_keys_index' is the index of the element that 'keys_index' should - // be compared to. - IrArray::Index compare_keys_index(keys_index.GetType()); - for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) { - if (dimension != dimension_to_sort) { - compare_keys_index.push_back(keys_index[dimension]); - } else { - compare_keys_index.push_back(nullptr); + Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(), + dimensions_in_iteration_order); + std::vector params(1, keys_array); + params.insert(params.end(), values_arrays.begin(), values_arrays.end()); + + // Allocate shared memory for the tiled compare loop. 
+ std::vector param_shmem_buffers(params.size(), nullptr); + if (xor_masks.size() > 1) { + llvm::Module* module = b->GetInsertBlock()->getParent()->getParent(); + for (int64 i = 0; i < params.size(); ++i) { + llvm::Type* tile_type = + llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType( + params[i].GetShape().element_type(), module), + tile_size); + param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile( + module, tile_type, absl::StrCat(name, "_tile_param_", i)); } } - // Naive C++ code for the inner compare loop: - // - // for (int64 i = 0; i < dimension_to_sort_bound; ++i) { - // int64 j = i ^ xor_mask; - // if (i < j && j < dimension_to_sort_bound) { - // int64 min_key = std::min(keys[i], keys[j]); - // keys[j] = std::max(keys[i], keys[j]); - // keys[i] = min_key; - // } - // } - // - // This follows the algorithm described on Wikipedia: - // https://en.wikipedia.org/wiki/Bitonic_sorter - - int64 dimension_to_sort_bound = - keys_array.GetShape().dimensions(dimension_to_sort); - Shape compare_shape = ShapeUtil::MakeShape(keys_shape.element_type(), - {dimension_to_sort_bound}); auto compare_loop_body_emitter = - [&](const IrArray::Index& compare_index) -> Status { - keys_index[dimension_to_sort] = compare_index[0]; - compare_keys_index[dimension_to_sort] = - b->CreateXor(compare_index[0], xor_mask); - EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, - keys_array, values_array, b); + [&](const IrArray::Index& tiles_index) -> Status { + // Naive C++ code for the inner compare loop: + // + // for (int64 i = 0; i < dimension_to_sort_bound; ++i) { + // int64 j = i ^ xor_mask; + // /* emitted in EmitCompareLoopBody() */ + // if (i < j && j < dimension_to_sort_bound) { + // int64 min_key = std::min(keys[i], keys[j]); + // keys[j] = std::max(keys[i], keys[j]); + // keys[i] = min_key; + // } + // } + // + // This follows the algorithm described on Wikipedia: + // https://en.wikipedia.org/wiki/Bitonic_sorter + IrArray::Index 
keys_index(tiles_index.GetType(), rank); + for (int64 i = 0; i < rank; ++i) { + keys_index[iteration_order_to_logical_order[i]] = tiles_index[i]; + } + if (xor_masks.size() > 1) { + EmitTiledCompareLoop(keys_index, dimension_to_sort, + dimension_to_sort_bound, keys_shape.element_type(), + xor_masks, params, param_shmem_buffers, + iota_values_parameter_index, tile_size, b); + } else { + auto read_element = [&](int64 operand, llvm::Value* index) { + keys_index[dimension_to_sort] = index; + return params[operand].EmitReadArrayElement(keys_index, b); + }; + auto write_element = [&](int64 operand, llvm::Value* index, + llvm::Value* value) { + keys_index[dimension_to_sort] = index; + params[operand].EmitWriteArrayElement(keys_index, value, b); + }; + EmitCompareLoopBody(dimension_to_sort_bound, keys_shape.element_type(), + values_arrays.size(), iota_values_parameter_index, + tiles_index[rank - 1], xor_masks[0], + tiles_index.GetType(), read_element, write_element, + b); + } return Status::OK(); }; - if (launch_dimensions != nullptr) { - TF_RETURN_IF_ERROR(gpu::ParallelLoopEmitter(compare_loop_body_emitter, - compare_shape, - *launch_dimensions, b) - .EmitLoop(name)); - } else { - TF_RETURN_IF_ERROR(LoopEmitter(compare_loop_body_emitter, compare_shape, b) - .EmitLoop(name)); - } - - // Set the IR builder insert point to the exit basic block of the outer most - // loop. This ensures later instructions are inserted after this loop nest. 
- b->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - - return Status::OK(); + return gpu::ParallelLoopEmitter(compare_loop_body_emitter, iteration_shape, + launch_dimensions, b) + .EmitLoop(name); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 527ed10374ce9482045a8459e38fd041e0e83001..685f9383acba416f51681270e4037d56abb4b6ea 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -16,8 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ +#include + #include "absl/strings/string_view.h" -#include "absl/types/optional.h" +#include "absl/types/span.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -28,13 +30,17 @@ namespace xla { namespace llvm_ir { // Emits llvm IR to do pairwise comparisons/swaps in the 'dimension_to_sort' // dimension of 'keys_array'. All other dimensions are kept as-is. This -// implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr, -// the inner compare loop will not be parallelized. +// implements the inner loop of BitonicSort. It is assumed that 'xor_masks' +// contains only powers of 2, or values 2^k - 1 (k > 0). If +// 'iota_values_parameter_index' is >= 0, it points at a 'values_arrays' operand +// that is an iota and can be used to make the sorting stable. 
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const absl::optional& values_array, - absl::string_view name, llvm::Value* xor_mask, - llvm::IRBuilder<>* b, - const gpu::LaunchDimensions* launch_dimensions); + const std::vector& values_arrays, + int64 iota_values_parameter_index, + absl::string_view name, + absl::Span xor_masks, llvm::IRBuilder<>* b, + const gpu::LaunchDimensions& launch_dimensions, + int64 num_iterations_in_sort_dim, int64 tile_size); } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 0d0fb7946ae6815905491ca55652d7d0ab278a3c..6c89700983363fec46c41b5430c6eab6b366a1b6 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -96,44 +96,18 @@ ExecutionOptions CreateExecutionOptions( const ExecutableBuildOptions& build_options, const ProgramShape* program_shape) { ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (build_options.hlo_profile().has_value()) { - execution_options.mutable_debug_options()->set_xla_hlo_profile( - *build_options.hlo_profile()); - } - if (build_options.generate_hlo_graph().has_value()) { - execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( - build_options.generate_hlo_graph().value()); - } - if (build_options.dump_optimized_hlo_proto_to().has_value()) { - execution_options.mutable_debug_options() - ->set_xla_dump_optimized_hlo_proto_to( - build_options.dump_optimized_hlo_proto_to().value()); - } - if (build_options.dump_unoptimized_hlo_proto_to().has_value()) { - execution_options.mutable_debug_options() - ->set_xla_dump_unoptimized_hlo_proto_to( - build_options.dump_unoptimized_hlo_proto_to().value()); - } - if (build_options.dump_per_pass_hlo_proto_to().has_value()) { - execution_options.mutable_debug_options() - ->set_xla_dump_per_pass_hlo_proto_to( - 
build_options.dump_per_pass_hlo_proto_to().value()); + if (build_options.has_debug_options()) { + *execution_options.mutable_debug_options() = build_options.debug_options(); } if (build_options.result_layout() != nullptr) { *execution_options.mutable_shape_with_output_layout() = - *build_options.result_layout(); + build_options.result_layout()->ToProto(); } else { + Shape result_shape(program_shape->result()); + LayoutUtil::SetToDefaultLayout(&result_shape); *execution_options.mutable_shape_with_output_layout() = - program_shape->result(); - LayoutUtil::SetToDefaultLayout( - execution_options.mutable_shape_with_output_layout()); + result_shape.ToProto(); } - - for (const std::string& disabled_pass : build_options.disabled_hlo_passes()) { - execution_options.mutable_debug_options()->add_xla_disable_hlo_passes( - disabled_pass); - } - return execution_options; } @@ -144,8 +118,8 @@ StatusOr> LocalService::CompileExecutable( const absl::Span argument_layouts, const ExecutableBuildOptions& build_options) { const HloModuleProto& proto = computation.proto(); - TF_RET_CHECK(proto.has_program_shape()); - const ProgramShape& program_shape = proto.program_shape(); + TF_RET_CHECK(proto.has_host_program_shape()); + ProgramShape program_shape(proto.host_program_shape()); // Validate incoming layouts. 
if (argument_layouts.size() != program_shape.parameters_size()) { @@ -220,4 +194,10 @@ StatusOr LocalService::GlobalDataToShapedBuffer( return buffers[replica_number]; } +StatusOr LocalService::RegisterReplicatedBuffers( + std::vector replicated_buffers, const string& tag) { + return allocation_tracker_.RegisterReplicatedBuffers( + std::move(replicated_buffers), tag); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 3b4f0b50832d6d2b64528ffb63eb5c7375396aec..f56ba32b04b9bf3aba75654bdb98887ad22e6791 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -63,6 +63,11 @@ class LocalService : public Service { StatusOr GlobalDataToShapedBuffer( const GlobalDataHandle& data, int replica_number); + // Registers a vector of shaped buffers of device memory, one per replica, and + // returns a corresponding handle that can be used for talking to XLA clients. + StatusOr RegisterReplicatedBuffers( + std::vector replicated_buffers, const string& tag); + private: explicit LocalService(const ServiceOptions& options, std::unique_ptr backend); diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index ec52a24d782a44fda961feab3230886072e755c7..972a5b9ced0d84387ef8308efe2a7aff7317d047 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -113,6 +113,13 @@ Status LogicalBufferAnalysis::HandleGetTupleElement(HloInstruction*) { return Status::OK(); } +Status LogicalBufferAnalysis::HandleAddDependency( + HloInstruction* add_dependency) { + // AddDependency just forwards the value of its zero-th operand and does not + // create buffers. 
+ return Status::OK(); +} + Status LogicalBufferAnalysis::HandleCopy(HloInstruction* copy) { // The top-level buffer (index={}) for kCopy is newly created, but all other // buffers (in the case of a tuple shape) come from the operand diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index 81f524d84a8091e1fff13dc7c55b401143a02753..7ffca943d0f7805ad4420343fcdbf860415c4c40 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -64,6 +64,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; // A map from the buffer ID to the logical buffer std::vector> logical_buffers_; diff --git a/tensorflow/compiler/xla/service/map_inliner_test.cc b/tensorflow/compiler/xla/service/map_inliner_test.cc index 84059dd0f71ee8fc0a25703cbab2268d7dc149a8..fd18bfdc3e7f4b5f94237c554c3e6ca8bd065a35 100644 --- a/tensorflow/compiler/xla/service/map_inliner_test.cc +++ b/tensorflow/compiler/xla/service/map_inliner_test.cc @@ -26,7 +26,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using MapInlinerTest = HloVerifiedTestBase; +using MapInlinerTest = HloTestBase; // Test that `map` with `max` is transformed to `max` TEST_F(MapInlinerTest, MapMax) { @@ -59,12 +59,12 @@ TEST_F(MapInlinerTest, MapMax) { HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); MapInliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Maximum(lhs, rhs)); @@ -93,12 +93,12 @@ TEST_F(MapInlinerTest, MapConstant) { HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get())); auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEmbeddedComputation(std::move(const2_f32)); hlo_module->AddEntryComputation(std::move(computation)); HloInstruction* root = hlo_module->entry_computation()->root_instruction(); MapInliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); root = hlo_module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); @@ -131,12 +131,12 @@ TEST_F(MapInlinerTest, MapSubtractOppositeOrder) { 
HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); auto computation = builder.Build(); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewVerifiedModule(); hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); MapInliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Subtract(rhs, lhs)); diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 2ca527bc4cb8f66a085c1e6a7cbb8ddaedbfc07e..9ccdd7d8d818b9fa3aa77cdd10d37ca18928b448 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/types.h" @@ -257,7 +258,7 @@ bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, } void MultiOutputFusion::RecomputeReachability() { - reachability_ = computation_->ComputeReachability(); + reachability_ = HloReachabilityMap::Build(computation_); } void MultiOutputFusion::UpdateReachability( @@ -317,9 +318,9 @@ bool MultiOutputFusion::Perform() { << instr2->fused_instructions_computation()->ToString( HloPrintOptions().set_indent_amount(1)); } + Update(instr1, instr2); HloInstruction* ret = Fuse(instr1, instr2); set_is_fused(ret == instr1 ? 
instr2 : instr1); - Update(instr1, instr2); changed = true; VLOG(2) << "After fusion, \t this: " << ret->name() << "\n" << ret->fused_instructions_computation()->ToString( diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 9508ab2ed1d38ec40983d8892ec8875b848fb21b..1c7583ece720f9e4d4b71a6279b976fed40e10cb 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 380cde0e6a858c7800445be94bb08dc22f3e776a..c35f72699bfe90f7b8021916c0f81d5e1926ff4c 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/utility/utility.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -44,32 +45,48 @@ namespace xla { // // This pattern will match Add instructions whose first operand is a constant. // -// Each pattern type has the following modifiers: +// Each pattern type has the following modifiers, which are described where +// nontrivial. 
// // Op(): -// - WithName: match operations with the given name -// - WithOpcode: match operations with the given opcode -// - WithShape: match operations whose shape matches the given pattern -// - WithOperand: match operations whose operand matches the given pattern +// - Is: is the given HloInstruction* (i.e. pointer equality) +// - WithName +// - WithOpcode +// - WithoutOpcode: anything other than the given opcode +// - WithShape: instr's shape matches the given pattern +// - WithShapeEqualTo: instr's shape is equal to the given Shape +// - WithShapeCompatibleTo: instr's shape is compatible with the given Shape +// - WithNumOperands +// - WithOperand: operand at the given index matches the given pattern +// - IsConstant +// - IsNonConstant +// - IsConstantScalar/IsEffectiveConstantScalar: Optionally accepts a value, +// e.g. IsConstantScalar() or IsConstantScalar(42). +// - WithFusionKind +// - WithTupleIndex: get-tuple-element operations with the given tuple index +// - WithOneUse: Instruction is used as an operand exactly once. +// - WithOneUser: Instruction is used by exactly one other instruction, but +// is possibly used more than once as an operand (e.g. multiply(x,x)). 
// // Shape(): -// - EqualTo: matches shapes that are equal to the argument -// - CompatibleTo: matches shapes that are compatible to the argument -// - IsScalar/IsArray/IsTuple: matches scalar/array/tuple shapes -// - IsDenseArray/IsSparseArray: matches arrays with dense/sparse format -// - WithLayout: match shapes whose layout matches the given pattern -// - WithLayoutEqualTo: matches shapes whose layouts equal the argument -// - WithSubshape: matches tuple shapes whose subshape matches the given -// pattern -// - WithSubshapeEqualTo: matches shapes with a subshape equal the argument -// - WithElementType: matches array/scalar shapes with the given element -// type -// - WithRank: matches array/scalar types with the given rank +// - EqualTo +// - CompatibleTo +// - IsScalar/IsEffectiveScalar/IsArray/IsTuple +// - IsDenseArray/IsSparseArray +// - WithLayout: layout shape's layout matches the given pattern (e.g. +// Layout().WithDenseFormat()) +// - WithLayoutEqualTo: shape's layout equals the argument (i.e. another +// Layout, but not the result of Layout().foo()) +// - WithSubshape: shape is a tuple whose subshape matches the given pattern +// (e.g. Shape().IsScalar()). +// - WithSubshapeEqualTo: shape is a tuple with a subshape equal to the arg +// (i.e. 
another Shape, but not the result of Shape().foo()) +// - WithElementType: shape is an array/scalar with the given elem type +// - WithRank: shape is an array/scalar with the given rank // // Layout(): -// - EqualTo: matches layouts that are equal to the argument -// - WithDenseFormat/WithSparseFormat: matches layouts with dense/sparse -// format +// - EqualTo +// - WithDenseFormat/WithSparseFormat // // Op(), Shape(), and Layout() may be passed an argument of type // HloInstruction**, Shape**, or Layout**, respectively, or const versions of @@ -82,53 +99,55 @@ namespace xla { // CHECK(Match(foo, // match::Op().WithOperand(0, match::Op(&matched_operand)))); // -// Helpers are provided for common nullary, unary, binary, and ternary -// instructions. These helpers can be called with no arguments, in which case -// they will match any instruction matching the opcode. They may also be called -// with matches for the operands and with an optional capture. (The capture must -// be the first argument.) Some examples of these helpers and their equivalents -// are provided below. -// +// Helpers are provided for most HLO instructions. These helpers can be called +// with no arguments, in which case they will match any instruction matching the +// opcode. They may also be called with matches for the operands and with an +// optional capture. (The capture must be the first argument.) Some examples of +// these helpers and their equivalents are provided below. 
+ // Example nullary instruction: -// Param() == Op().WithOpcode(HloOpcode::kParam) -// Param(&a) == Op(&a).WithOpcode(HloOpcode::kParam) +// Parameter() == Op().WithOpcode(HloOpcode::kParameter) +// Parameter(&a) == Op(&a).WithOpcode(HloOpcode::kParameter) // // Example unary instruction: -// Abs() == Op().WithOpcode(HloOpcode::kAbs) -// Abs(Op(&a)) == Op().WithOpcode(HloOpcode::kAbs) -// .WithOperand(0, Op(&a))) -// Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs) -// .WithOperand(0, Op(&b)) +// Abs() == Op().WithOpcode(HloOpcode::kAbs) +// Abs(Op(&a)) == Op().WithOpcode(HloOpcode::kAbs) +// .WithOperand(0, Op(&a))) +// Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs) +// .WithOperand(0, Op(&b)) +// +// Commutative binary instructions have a special form that accepts either order +// of args, e.g.: +// +// AddAnyOrder(Parameter(1), Abs()) == +// Op().WithOpcode(HloOpcode::kAdd) +// .WithBinaryOperandsAnyOrder(Op().WithParameterNum(1), Abs()); // -// Example binary instruction: -// Add() == Op().WithOpcode(HloOpcode::kAdd) -// Add(Op(&a), Op(&b)) == Op().WithOpcode(HloOpcode::kAdd) -// .WithOperand(0, Op(&a)) -// .WithOperand(1, Op(&b)) -// Add(&a, Op(&b), Op(&c)) == Op(&a).WithOpcode(HloOpcode::kAdd) -// .WithOperand(0, Op(&b)) -// .WithOperand(1, Op(&c)) +// MultiplyAnyOrder(&a, Parameter(), Abs()) // Captures the mul in `a`. // -// Example ternary instruction: -// Clamp() == Op().WithOpcode(HloOpcode::kClamp) -// Clamp(Op(&a), Op(&b), Op(&c)) == Op().WithOpcode(HloOpcode::kClamp) -// .WithOperand(0, Op(&a)) -// .WithOperand(1, Op(&b)) -// .WithOperand(2, Op(&c)) -// Clamp(&a, Op(&b), Op(&c), Op(&d)) == Op(&a).WithOpcode(HloOpcode::kClamp) -// .WithOperand(0, Op(&b)) -// .WithOperand(1, Op(&c)) -// .WithOperand(2, Op(&d)) +// The following additional helpers are provided. In all cases, `&a` is +// optional. 
// +// ConstantScalar(&a) == Op(&a).IsConstantScalar(); +// ConstantScalar(&a, v) == Op(&a).IsConstantScalar(v); +// ConstantEffectiveScalar(&a) == Op(&a).IsConstantEffectiveScalar(); +// ConstantEffectiveScalar(&a, v) == Op(&a).IsConstantEffectiveScalar(&a, v) +// NonConstant(&a) == Op(&a).IsNonConstant() +// GetTupleElement(&a, b, index) == Op(&a).WithTupleIndex(index) +// .WithOperand(0, b); +// Parameter(&a, n) == Op(&a).WithParameterNum(n); struct MatchOption { // If true, actually capture matched item into the user pointer. bool capture; + + // An explanation for why we failed to match is streamed here, if not-null. + std::ostream* explain_os; }; template bool Match(Value* value, const Pattern& pattern, - MatchOption option = {/*.capture=*/true}) { + MatchOption option = {/*.capture=*/true, /*.explain_os=*/nullptr}) { if (option.capture) { auto new_option = option; new_option.capture = false; @@ -143,6 +162,77 @@ namespace match { namespace detail { +// Macro for streaming to option.explain_os if it's not null. +// +// EXPLAIN << "value of foo(): " << foo() +// +#pragma push_macro("EXPLAIN") +#define EXPLAIN \ + if (option.explain_os) *option.explain_os + +// kIndentInc is the additional number of spaces that we indent by when we +// increase the indent "by one". +enum { + kIndentInc = 2, +}; + +// Writes a newline and then `indent` spaces. +// +// We follow an unintuitive convention in this file's pretty-printers: Indents +// are performed by the caller, not the callee. For example, if you want to +// print +// +// foo: +// - bar +// +// you'd do: +// +// Foo::DescribeTo(std::ostream* os, int64 indent) { +// *os << "foo:"; +// Indent(os, indent) // Create a newline at the *current* indent level. +// *os << " - "; +// bar.DescribeTo(os, indent + 3); // + 3 because strlen(" * ") == 3. +// } +// +// Bar::DescribeTo(std::ostream* os, int64 indent) { *os << "bar"; } +// +// Notice that Bar::DescribeTo() does not call Indent; the indenting is +// performed by Foo. 
This convention allows the caller to decide whether a +// matcher is preceded by a newline, which is important e.g. for the AllOf +// matcher. +// +// (Incidentally, indenting in Match's explanations is handled differently. +// Indents are a common case in DescribeTo [we're printing a whole tree], but +// they're a special case in Match [we're printing only a path through the tree +// that encounters a failing node]. Indents in Match only appear when we +// encounter a failing disjunction, so we just handle them as a special case +// there.) +inline void Indent(std::ostream* os, int64 indent) { + *os << "\n"; + for (int64 i = 0; i < indent; ++i) { + *os << " "; + } +} + +// SFINAE template that determines whether T declares a static member +// kIsTrivialMatcher. +// +// Trivial matchers get special treatment. For example, when printing +// a conjunction of matchers, we don't print "and" after a trivial matcher. This +// yields e.g. +// "a shape compatible with f32[1,2]" +// rather than +// "a shape AND compatible with f32[1,2]" +template +struct IsTrivialMatcher { + static constexpr bool value = false; +}; +template +struct IsTrivialMatcher::type> { + static constexpr bool value = true; +}; + template class AllOfPattern { public: @@ -162,10 +252,19 @@ class AllOfPattern { return matched; } + void DescribeTo(std::ostream* os, int64 indent = 0) const { + DescribeToImpl(os, std::integral_constant(), indent); + } + + // Accessor for patterns_. Please don't use this outside of this file. + const std::tuple& patterns() const { return patterns_; } + private: template bool MatchImpl(ItemType* item, MatchOption option, std::integral_constant) const { + // We don't need to do any EXPLAINing here; it's all correctly handled by + // our sub-matchers (if any fail). 
return std::get(patterns_).Match(item, option) && MatchImpl(item, option, std::integral_constant()); } @@ -176,6 +275,73 @@ class AllOfPattern { return true; } + // Pretty-printing a conjunction has some special cases to make it easy to + // read in the simple (common) case. + // + // If sizeof...(Patterns) == 1, prints as e.g. + // + // a shape + // + // If sizeof...(Patterns) == 2 and patterns_[0] is a trivial matcher (e.g. "a + // shape") prints as + // + // a shape compatible with f32[1,2] + // + // If sizeof...(Patterns) > 2 and patterns_[0] is a trivial matcher, prints as + // + // a shape: + // * compatible with f32[1,2] AND + // * that represents a scalar + // + // Otherwise prints as: + // + // all of: + // * foo AND + // * bar + // + template + void DescribeToImpl(std::ostream* os, std::integral_constant, + int64 indent) const { + constexpr bool first_is_trivial = + IsTrivialMatcher(patterns_))>::type>::value; + constexpr bool is_last = index == sizeof...(Patterns) - 1; + const auto& submatcher = std::get(patterns_); + + auto print_bulleted_item = [&] { + *os << " * "; + submatcher.DescribeTo(os, indent + 3); + if (!is_last) { + *os << " AND"; + Indent(os, indent); + } + }; + + if (index == 0) { + if (first_is_trivial || is_last) { + submatcher.DescribeTo(os, indent + kIndentInc); + if (sizeof...(Patterns) > 2) { + *os << ":"; + Indent(os, indent); + } + } else { + *os << "all of:"; + Indent(os, indent); + print_bulleted_item(); + } + } else if (first_is_trivial && index == 1 && sizeof...(Patterns) == 2) { + *os << " "; + submatcher.DescribeTo(os, indent); + } else { + print_bulleted_item(); + } + DescribeToImpl(os, std::integral_constant(), indent); + } + + void DescribeToImpl(std::ostream* os, + std::integral_constant, + int64 indent) const {} + std::tuple patterns_; }; @@ -183,10 +349,6 @@ class AllOfPattern { // Returns a pattern that represents the conjunction of all input patterns. 
All // patterns need to match in order to have the AllOf pattern match. -// -// TODO(timshen): Currently AllOf is still nested, e.g. AllOf, B> is -// not AllOf. We might want to flatten the AllOf type structure if the -// C++ compile error message gets annoying. template detail::AllOfPattern::type, Patterns...> AllOf( const Patterns&... patterns) { @@ -194,6 +356,25 @@ detail::AllOfPattern::type, Patterns...> AllOf( Patterns...>(patterns...); } +// AllOf, X, Y, ...> => AllOf. +// +// This transformation is necessary for good pretty-printing. +template +detail::AllOfPattern::type, InnerPs..., + OuterPs...> +AllOf(const detail::AllOfPattern& inner_p, + const OuterPs&... outer_ps) { + // Invoke constructor of AllOfPattern. + auto make_all_of = [](const InnerPs&... inner_ps, + const OuterPs&... outer_ps) { + return detail::AllOfPattern::type, + InnerPs..., OuterPs...>(inner_ps..., + outer_ps...); + }; + return absl::apply(make_all_of, std::tuple_cat(inner_p.patterns(), + std::make_tuple(outer_ps...))); +} + namespace detail { template @@ -204,8 +385,18 @@ class LayoutPattern; class LayoutPatternBaseImpl { public: bool Match(const ::xla::Layout* layout, MatchOption option) const { - return layout != nullptr; + if (layout == nullptr) { + EXPLAIN << "Layout is null"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "a layout"; } + + static constexpr bool kIsTrivialMatcher = true; }; // A LayoutPattern implementation that matches only if the layout equals a @@ -216,7 +407,17 @@ class LayoutPatternEqualImpl { : layout_(layout) {} bool Match(const ::xla::Layout* layout, MatchOption option) const { - return LayoutUtil::Equal(*layout_, *layout); + if (!LayoutUtil::Equal(*layout_, *layout)) { + EXPLAIN << "Layout " << LayoutUtil::HumanString(*layout) + << " is not equal to expected " + << LayoutUtil::HumanString(*layout_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) 
const { + *os << "equal to " << LayoutUtil::HumanString(*layout_); } private: @@ -230,7 +431,16 @@ class LayoutPatternFormatImpl { explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {} bool Match(const ::xla::Layout* layout, MatchOption option) const { - return layout->format() == format_; + if (layout->format() != format_) { + EXPLAIN << "Layout has format " << Format_Name(layout->format()) + << " but expected " << Format_Name(format_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with format " << Format_Name(format_); } private: @@ -242,11 +452,13 @@ template class LayoutPattern { private: template - LayoutPattern> - AppendImpl(NewImpl new_impl) const { - return LayoutPattern>( - AllOf(impl_, std::move(new_impl)), matched_layout_); + auto AppendImpl(NewImpl new_impl) const + -> LayoutPattern(std::declval(), + std::move(new_impl)))> { + auto new_allof = AllOf(impl_, std::move(new_impl)); + return LayoutPattern(std::move(new_allof), + matched_layout_); } public: @@ -276,6 +488,10 @@ class LayoutPattern { return false; } + void DescribeTo(std::ostream* os, int64 indent = 0) const { + impl_.DescribeTo(os, indent); + } + // Modifies the pattern to match only if the layout equals the given proto. // The layout must outlive the returned pattern. constexpr auto EqualTo(const ::xla::Layout* layout) const @@ -306,19 +522,48 @@ class AnyOfPattern { explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) 
{} bool Match(const Item* item, MatchOption option) const { - return MatchImpl(item, option, std::integral_constant()); + return MatchImpl(item, option); } bool Match(Item* item, MatchOption option) const { - return MatchImpl(item, option, std::integral_constant()); + return MatchImpl(item, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "any of:"; + Indent(os, indent); + DescribeToImpl(os, std::integral_constant(), indent); } private: + template + bool MatchImpl(ItemType* item, MatchOption option) const { + // If we're generating an explanation, buffer it until we know we failed. + absl::optional explanation; + MatchOption new_option = option; + if (option.explain_os) { + new_option.explain_os = &explanation.emplace(); + } + bool rv = MatchRecursiveImpl(item, new_option, + std::integral_constant()); + if (!rv && option.explain_os) { + EXPLAIN << "None of the following matchers succeeded:"; + EXPLAIN << explanation->str(); + } + return rv; + } + template - bool MatchImpl(ItemType* item, MatchOption option, - std::integral_constant) const { + bool MatchRecursiveImpl(ItemType* item, MatchOption option, + std::integral_constant) const { auto new_option = option; new_option.capture = false; + + absl::optional explanation; + if (option.explain_os) { + new_option.explain_os = &explanation.emplace(); + } + // Try to match the sub-pattern without capturing behavior. if (std::get(patterns_).Match(item, new_option)) { // Capture the branch. @@ -337,20 +582,46 @@ class AnyOfPattern { // AnyOf will be a runtime number indicate which sub-pattern is matched. // Then we run another pass to do captures only with the help of the // trace. 
- bool ret = std::get(patterns_).Match(item, option); - DCHECK(ret); + bool matched = std::get(patterns_).Match(item, option); + DCHECK(matched); } return true; } - return MatchImpl(item, option, std::integral_constant()); + if (option.explain_os) { + EXPLAIN << "\nMatcher #" << index + 1; + EXPLAIN << "\n - "; + std::get(patterns_).DescribeTo(option.explain_os, /*indent=*/3); + EXPLAIN << "\nfailed with"; + EXPLAIN << "\n - "; + EXPLAIN << absl::StrReplaceAll(explanation->str(), {{"\n", "\n "}}); + } + return MatchRecursiveImpl(item, option, + std::integral_constant()); } template - bool MatchImpl(ItemType* item, MatchOption option, - std::integral_constant) const { + bool MatchRecursiveImpl( + ItemType* item, MatchOption option, + std::integral_constant) const { return false; } + template + void DescribeToImpl(std::ostream* os, std::integral_constant, + int64 indent) const { + *os << " - "; + std::get(patterns_).DescribeTo(os, indent + 3); + if (index != sizeof...(Patterns) - 1) { + *os << " OR"; + Indent(os, indent); + } + DescribeToImpl(os, std::integral_constant(), indent); + } + + void DescribeToImpl(std::ostream* os, + std::integral_constant, + int64 indent) const {} + std::tuple patterns_; }; @@ -395,8 +666,17 @@ class ShapePattern; class ShapePatternBaseImpl { public: bool Match(const ::xla::Shape* shape, MatchOption option) const { + if (shape == nullptr) { + EXPLAIN << "Shape is null"; + } return shape != nullptr; } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "a shape"; + } + + static constexpr bool kIsTrivialMatcher = true; }; // A ShapePattern implementation that matches only if the shape equals a Shape @@ -407,7 +687,16 @@ class ShapePatternEqualImpl { : shape_(shape) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::Equal(*shape_, *shape); + if (!ShapeUtil::Equal(*shape_, *shape)) { + EXPLAIN << "Shape not equal to " + << ShapeUtil::HumanStringWithLayout(*shape_); + return false; + } 
+ return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "equal to " << ShapeUtil::HumanStringWithLayout(*shape_); } private: @@ -422,7 +711,16 @@ class ShapePatternCompatibleImpl { : shape_(shape) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::Compatible(*shape_, *shape); + if (!ShapeUtil::Compatible(*shape_, *shape)) { + EXPLAIN << "Shape not compatible with " + << ShapeUtil::HumanString(*shape_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "compatible with " << ShapeUtil::HumanString(*shape_); } private: @@ -437,7 +735,16 @@ class ShapePatternElementTypeImpl { : element_type_(element_type) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return shape->element_type() == element_type_; + if (shape->element_type() != element_type_) { + EXPLAIN << "Shape does not have element type " + << PrimitiveType_Name(element_type_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with element type " << PrimitiveType_Name(element_type_); } private: @@ -450,7 +757,15 @@ class ShapePatternIsScalarImpl { explicit constexpr ShapePatternIsScalarImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IsScalar(*shape); + if (!ShapeUtil::IsScalar(*shape)) { + EXPLAIN << "Shape is not a scalar"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "that represents a scalar"; } }; @@ -460,7 +775,15 @@ class ShapePatternIsArrayImpl { explicit constexpr ShapePatternIsArrayImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IsArray(*shape); + if (!ShapeUtil::IsArray(*shape)) { + EXPLAIN << "Shape is not an array"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os 
<< "that represents an array"; } }; @@ -470,7 +793,34 @@ class ShapePatternIsTupleImpl { explicit constexpr ShapePatternIsTupleImpl() {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IsTuple(*shape); + if (!ShapeUtil::IsTuple(*shape)) { + EXPLAIN << "Shape is not a tuple"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "that represents a tuple"; + } +}; + +// A ShapePattern implementation that matches only if the shape is an effective +// scalar. +class ShapePatternEffectiveScalarImpl { + public: + explicit constexpr ShapePatternEffectiveScalarImpl() {} + + bool Match(const ::xla::Shape* shape, MatchOption option) const { + if (!ShapeUtil::IsEffectiveScalar(*shape)) { + EXPLAIN << "Shape is not an effective scalar"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "that is an effective scalar"; } }; @@ -481,7 +831,23 @@ class ShapePatternRankImpl { explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::Rank(*shape) == rank_; + if (ShapeUtil::Rank(*shape) != rank_) { + if (rank_ == 0) { + EXPLAIN << "Shape is not a scalar"; + } else { + EXPLAIN << "Shape does not have rank " << rank_; + } + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + if (rank_ == 0) { + *os << "that is a scalar"; + } else { + *os << "that has " << rank_ << " dimension" << (rank_ != 1 ? 
"s" : ""); + } } private: @@ -503,8 +869,21 @@ class ShapePatternLayoutImpl { } bool Match(Shape* shape, MatchOption option) const { - return LayoutUtil::HasLayout(*shape) && - layout_.Match(shape->mutable_layout(), option); + if (!LayoutUtil::HasLayout(*shape)) { + EXPLAIN << "Shape does not have a layout"; + return false; + } + if (!layout_.Match(shape->mutable_layout(), option)) { + EXPLAIN << "\nin layout"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with"; + Indent(os, indent + kIndentInc); + layout_.DescribeTo(os, indent + kIndentInc); } private: @@ -522,17 +901,40 @@ class ShapePatternSubshapeImpl { : index_(index), subshape_(subshape) {} bool Match(const ::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IndexIsValid(*shape, index_) && - subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_), option); + return MatchImpl(shape, option); } bool Match(::xla::Shape* shape, MatchOption option) const { - return ShapeUtil::IndexIsValid(*shape, index_) && - subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_), - option); + return MatchImpl(shape, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with subshape at index " << index_.ToString() << " which is"; + Indent(os, indent + kIndentInc); + subshape_.DescribeTo(os, indent + kIndentInc); } private: + Shape* GetSubshape(Shape* shape) const { + return ShapeUtil::GetMutableSubshape(shape, index_); + } + const Shape* GetSubshape(const Shape* shape) const { + return &ShapeUtil::GetSubshape(*shape, index_); + } + + template + bool MatchImpl(ShapeType* shape, MatchOption option) const { + if (!ShapeUtil::IndexIsValid(*shape, index_)) { + EXPLAIN << "No subshape at " << index_.ToString(); + return false; + } + if (!subshape_.Match(GetSubshape(shape), option)) { + EXPLAIN << "\nin subshape at " << index_.ToString(); + return false; + } + return true; + } + ShapeIndexView index_; ShapePattern 
subshape_; }; @@ -542,10 +944,12 @@ template class ShapePattern { private: template - ShapePattern> AppendImpl( - NewImpl new_impl) const { - return ShapePattern>( - AllOf(impl_, std::move(new_impl)), matched_shape_); + auto AppendImpl(NewImpl new_impl) const + -> ShapePattern(std::declval(), + std::move(new_impl)))> { + auto new_all_of = AllOf(impl_, std::move(new_impl)); + return ShapePattern(std::move(new_all_of), + matched_shape_); } public: @@ -560,6 +964,11 @@ class ShapePattern { } return true; } + if (shape) { + EXPLAIN << "\nin " + << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape) + : ShapeUtil::HumanString(*shape)); + } return false; } @@ -571,9 +980,16 @@ class ShapePattern { } return true; } + EXPLAIN << "\nin " + << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape) + : ShapeUtil::HumanString(*shape)); return false; } + void DescribeTo(std::ostream* os, int64 indent = 0) const { + return impl_.DescribeTo(os, indent); + } + // Modifies the pattern to match only if the shape equals the given proto. // The layout must outlive the returned pattern. constexpr auto EqualTo(const ::xla::Shape* shape) const @@ -612,6 +1028,11 @@ class ShapePattern { return AppendImpl(ShapePatternIsTupleImpl()); } + constexpr auto IsEffectiveScalar() const + -> decltype(this->AppendImpl(ShapePatternEffectiveScalarImpl())) { + return AppendImpl(ShapePatternEffectiveScalarImpl()); + } + // Modifies the pattern to match only if the shape has the given rank. constexpr auto WithRank(int64 rank) const -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) { @@ -706,6 +1127,22 @@ Shape(::xla::Shape** matched_shape) { namespace detail { +// Overloads to get a const or non-const operand out of an instruction. 
+inline HloInstruction* HloOperand(HloInstruction* instr, int64 idx) { + return instr->mutable_operand(idx); +} +inline const HloInstruction* HloOperand(const HloInstruction* instr, + int64 idx) { + return instr->operand(idx); +} + +// Pretty-printer for HloInstruction. Sort of like ToShortString, but with +// fewer %s and more shapes. +inline string InstToString(const HloInstruction* inst) { + return inst->ToString( + HloPrintOptions().set_print_metadata(false).set_print_percent(false)); +} + template class HloInstructionPattern; @@ -714,8 +1151,18 @@ class HloInstructionPattern; class HloInstructionPatternBaseImpl { public: bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return inst != nullptr; + if (inst == nullptr) { + EXPLAIN << "HloInstruction* is null"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "an HloInstruction"; } + + static constexpr bool kIsTrivialMatcher = true; }; // An HloInstructionPattern implementation that matches only if the instruction @@ -726,13 +1173,44 @@ class HloInstructionPatternNameImpl { : name_(name) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return inst->name() == name_; + if (inst->name() != name_) { + EXPLAIN << "HloInstruction not named \"" << name_ << "\""; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "named \"" << name_ << "\""; } private: absl::string_view name_; }; +// An HloInstructionPattern implementation that matches only if the instruction +// equals a particular pointer. 
+class HloInstructionIsImpl { + public: + explicit HloInstructionIsImpl(const HloInstruction* inst) : inst_(inst) {} + + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + if (inst != inst_) { + EXPLAIN << "HloInstruction " << inst << " is not " << inst_ << " (" + << InstToString(inst_) << ")"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which is " << inst_ << " (" << InstToString(inst_) << ")"; + } + + private: + const HloInstruction* inst_; +}; + // An HloInstructionPattern implementation that matches only if the instruction // has a given opcode. class HloInstructionPatternOpcodeImpl { @@ -742,7 +1220,25 @@ class HloInstructionPatternOpcodeImpl { : opcode_(opcode), invert_(invert) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return (invert_ ^ (inst->opcode() == opcode_)); + if (invert_ && inst->opcode() == opcode_) { + EXPLAIN << "HloInstruction has opcode " << HloOpcodeString(opcode_) + << ", expected anything else"; + return false; + } + if (!invert_ && inst->opcode() != opcode_) { + EXPLAIN << "HloInstruction doesn't have opcode " + << HloOpcodeString(opcode_); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + if (!invert_) { + *os << "with opcode " << HloOpcodeString(opcode_); + } else { + *os << "with any opcode other than " << HloOpcodeString(opcode_); + } } private: @@ -750,6 +1246,30 @@ class HloInstructionPatternOpcodeImpl { bool invert_; }; +// An HloInstructionPattern implementation that matches only if the instruction +// has the given number of operands. 
+class HloInstructionPatternNumOperandsImpl { + public: + explicit constexpr HloInstructionPatternNumOperandsImpl(int64 num_operands) + : num_operands_(num_operands) {} + + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + if (inst->operand_count() != num_operands_) { + EXPLAIN << "HloInstruction doesn't have " << num_operands_ << " operands"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with " << num_operands_ << " operand" + << (num_operands_ != 1 ? "s" : ""); + } + + private: + int64 num_operands_; +}; + // An HloInstructionPattern implementation that matches only if the instruction // has a shape that matches a given pattern. template @@ -760,11 +1280,25 @@ class HloInstructionPatternShapeImpl { : shape_(shape) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return shape_.Match(&inst->shape(), option); + if (!shape_.Match(&inst->shape(), option)) { + EXPLAIN << "\nin output shape"; + return false; + } + return true; } bool Match(::xla::HloInstruction* inst, MatchOption option) const { - return shape_.Match(inst->mutable_shape(), option); + if (!shape_.Match(inst->mutable_shape(), option)) { + EXPLAIN << "\nin output shape"; + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "outputting"; + Indent(os, indent + kIndentInc); + shape_.DescribeTo(os, indent + kIndentInc); } private: @@ -782,20 +1316,197 @@ class HloInstructionPatternOperandImpl { : operand_index_(operand_index), operand_(operand) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return operand_index_ < inst->operand_count() && - operand_.Match(inst->operand(operand_index_), option); + return MatchImpl(inst, option); } bool Match(::xla::HloInstruction* inst, MatchOption option) const { - return operand_index_ < inst->operand_count() && - 
operand_.Match(inst->mutable_operand(operand_index_), option); + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with operand " << operand_index_ << " which is:"; + Indent(os, indent + kIndentInc); + operand_.DescribeTo(os, indent + kIndentInc); } private: + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (operand_index_ >= inst->operand_count()) { + EXPLAIN << "desired operand index " << operand_index_ + << " is out of bounds"; + return false; + } + if (!operand_.Match(HloOperand(inst, operand_index_), option)) { + EXPLAIN << "\nin operand " << operand_index_; + return false; + } + return true; + } + int64 operand_index_; HloInstructionPattern operand_; }; +// Matches a binary instruction whose operands come in any order. +template +class HloInstructionPatternBinaryOperandsAnyOrderImpl { + public: + explicit constexpr HloInstructionPatternBinaryOperandsAnyOrderImpl( + const HloInstructionPattern& op1, + const HloInstructionPattern& op2) + : op1_(op1), op2_(op2) {} + + bool Match(HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + bool Match(const HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with two operands in either order:"; + Indent(os, indent); + *os << " - "; + op1_.DescribeTo(os, indent + 3); + Indent(os, indent); + *os << " - "; + op2_.DescribeTo(os, indent + 3); + } + + private: + HloInstruction* operand(HloInstruction* inst, int64 idx) const { + return inst->mutable_operand(idx); + } + const HloInstruction* operand(const HloInstruction* inst, int64 idx) const { + return inst->operand(idx); + } + + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + // We could implement this using AnyOf and AllOf matchers, but the templates + // get pretty difficult to debug, since any 
compile error herein becomes + // not-an-error via SFINAE. Also this way lets us give better messages on + // failure. + if (inst->operand_count() != 2) { + EXPLAIN << "HloInstruction did not have two operands"; + return false; + } + + // If we're not generating explanations, this is pretty simple. + if (!option.explain_os) { + auto try_match = [&](int64 idx1, int64 idx2) { + MatchOption new_option = option; + new_option.capture = false; + if (op1_.Match(operand(inst, idx1), new_option) && + op2_.Match(operand(inst, idx2), new_option)) { + if (option.capture) { + bool matched = op1_.Match(operand(inst, idx1), option) && + op2_.Match(operand(inst, idx2), option); + DCHECK(matched); + } + return true; + } + return false; + }; + return try_match(0, 1) || try_match(1, 0); + } + + // If we are generating explanations, we have some work to do in order to + // generate a helpful error. + // + // First, try all four operand/matcher combinations, recording the + // failure explanations separately from option.explain_os. matches[i][j] + // tells us if matcher_i matches operand j. + bool matches[/*matcher*/ 2][/*operand*/ 2]; + std::stringstream explanations[/*matcher*/ 2][/*operand*/ 2]; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + MatchOption new_option = option; + new_option.capture = false; + new_option.explain_os = &explanations[i][j]; + matches[i][j] = i == 0 ? op1_.Match(operand(inst, j), new_option) + : op2_.Match(operand(inst, j), new_option); + } + } + + // Check if the match succeeded. + for (int i = 0; i < 2; ++i) { + if (matches[0][i] && matches[1][(i + 1) % 2]) { + // Rerun the matches with capture enabled if necessary. 
+ if (option.capture) { + auto* operand1 = operand(inst, i); + auto* operand2 = operand(inst, (i + 1) % 2); + bool matched = + op1_.Match(operand1, option) && op2_.Match(operand2, option); + DCHECK(matched); + } + return true; + } + } + + auto describe_matcher = [&](int matcher_idx) { + EXPLAIN << "\n - "; + if (matcher_idx == 0) { + op1_.DescribeTo(option.explain_os, /*indent=*/3); + } else { + CHECK_EQ(matcher_idx, 1); + op2_.DescribeTo(option.explain_os, /*indent=*/3); + } + for (int i = 0; i < 2; ++i) { + if (matches[matcher_idx][/*operand*/ i]) { + continue; + } + EXPLAIN << "\ndoes not match " << (i == 0 ? "LHS" : "RHS") << ":\n"; + EXPLAIN << " - "; + EXPLAIN << absl::StrReplaceAll( + explanations[matcher_idx][/*operand*/ i].str(), {{"\n", "\n "}}); + } + }; + + // If we failed to match, one of the following is true: + // 1. op1 (op2) matches neither LHS nor RHS, or + // 2. op1 and op2 both match LHS (RHS), but neither matches RHS (LHS). + // We print different explanations depending on which case we're in. + + // Case 1. + bool wrote_explanation = false; + for (int i = 0; !wrote_explanation && i < 2; ++i) { + if (!matches[i][0] && !matches[i][1]) { + EXPLAIN << "HloInstruction's operands (ignoring order) did not match " + << (i == 0 ? "first" : "second") << " matcher. Specifically,"; + describe_matcher(i); + wrote_explanation = true; + } + } + + // Case 2. + for (int i = 0; !wrote_explanation && i < 2; ++i) { + if (matches[/*matcher*/ 0][/*operand*/ i] && + matches[/*matcher*/ 1][/*operand*/ i]) { + CHECK(!matches[0][(i + 1) % 2]); + CHECK(!matches[1][(i + 1) % 2]); + CHECK(!wrote_explanation); + EXPLAIN << "HloInstruction's " << (i == 1 ? "LHS" : "RHS") + << " operand did not match either of the two matchers. 
" + "Specifically,"; + describe_matcher(0); + EXPLAIN << "\nand"; + describe_matcher(1); + wrote_explanation = true; + } + } + + CHECK(wrote_explanation); + return false; + } + + HloInstructionPattern op1_; + HloInstructionPattern op2_; +}; + // An HloInstructionPattern implementation that matches only if the instruction // is a fusion node with a particular kind. class HloInstructionPatternFusionKindImpl { @@ -805,14 +1516,32 @@ class HloInstructionPatternFusionKindImpl { : kind_(kind) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_; + return MatchImpl(inst, option); } bool Match(::xla::HloInstruction* inst, MatchOption option) const { - return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_; + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "with fusion kind " << ToString(kind_); } private: + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (inst->opcode() != HloOpcode::kFusion) { + EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_) + << "; it's not a fusion"; + return false; + } + if (inst->fusion_kind() != kind_) { + EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_); + return false; + } + return true; + } + ::xla::HloInstruction::FusionKind kind_; }; @@ -824,47 +1553,211 @@ class HloInstructionPatternTupleIndexImpl { : tuple_index_(tuple_index) {} bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { - return inst->opcode() == HloOpcode::kGetTupleElement && - inst->tuple_index() == tuple_index_; + return MatchImpl(inst, option); } bool Match(::xla::HloInstruction* inst, MatchOption option) const { - return inst->opcode() == HloOpcode::kGetTupleElement && - inst->tuple_index() == tuple_index_; + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 
indent = 0) const { + *os << "which is a GTE with index " << tuple_index_; } private: + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (inst->opcode() != HloOpcode::kGetTupleElement) { + EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_ + << "; it's not a GTE at all"; + return false; + } + if (inst->tuple_index() != tuple_index_) { + EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_; + return false; + } + return true; + } + int64 tuple_index_; }; -template -class HloPredicatePatternImpl { +class HloInstructionPatternParameterNumImpl { public: - explicit HloPredicatePatternImpl(Predicate pred) : pred_(std::move(pred)) {} + explicit constexpr HloInstructionPatternParameterNumImpl(int64 parameter_num) + : parameter_num_(parameter_num) {} - bool Match(const ItemType* item, MatchOption option) const { - return pred_(item); + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); } - bool Match(ItemType* item, MatchOption option) const { return pred_(item); } + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which is parameter " << parameter_num_; + } private: - Predicate pred_; + template + bool MatchImpl(HloInstructionType* inst, MatchOption option) const { + if (inst->opcode() != HloOpcode::kParameter || + inst->parameter_number() != parameter_num_) { + EXPLAIN << "HloInstruction is not parameter " << parameter_num_; + return false; + } + return true; + } + + int64 parameter_num_; }; -struct PatternFriend; +// Superclass that contains common code used by Op::WithOneUse() and +// Op::WithOneUser(). 
+class HloInstructionPatternOneUseOrUserImpl { + protected: + bool MatchOneUser(const HloInstruction* inst, MatchOption option) const { + if (inst->user_count() != 1) { + EXPLAIN << "HloInstruction has " << inst->user_count() + << " users, but expected exactly one."; + if (inst->user_count() > 1) { + EXPLAIN << "\nAll users:"; + for (const HloInstruction* user : inst->users()) { + EXPLAIN << "\n - " << InstToString(user); + } + } + return false; + } + return true; + } +}; + +class HloInstructionPatternOneUseImpl + : public HloInstructionPatternOneUseOrUserImpl { + public: + bool Match(const HloInstruction* inst, MatchOption option) const { + if (!MatchOneUser(inst, option)) { + return false; + } + + int64 use_count = absl::c_count_if( + inst->users()[0]->operands(), + [&](const HloInstruction* operand) { return operand == inst; }); + if (use_count != 1) { + EXPLAIN << "HloInstruction is used " << use_count + << " times by its user, but is expected to be used just once: " + << InstToString(inst->users()[0]); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has exactly one use"; + } +}; + +class HloInstructionPatternOneUserImpl + : public HloInstructionPatternOneUseOrUserImpl { + public: + bool Match(const HloInstruction* inst, MatchOption option) const { + return MatchOneUser(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which has exactly one user (but possibly is used multiple times by " + "that instruction)"; + } +}; + +// Matches a constant scalar or effective scalar, optionally with a given value. 
+template +class HloConstantScalarImpl { + public: + explicit constexpr HloConstantScalarImpl(bool match_effective_scalar) + : val_(absl::nullopt), match_effective_scalar_(match_effective_scalar) {} + + constexpr HloConstantScalarImpl(ScalarTy val, bool match_effective_scalar) + : val_(val), match_effective_scalar_(match_effective_scalar) {} + + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return MatchImpl(inst, option); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + *os << "which is a constant " + << (match_effective_scalar_ ? "effective " : "") << "scalar"; + if (val_.has_value()) { + *os << " with value " << *val_; + } + } + + private: + template + bool MatchImpl(InstTy* inst, MatchOption option) const { + const auto* const_inst = DynCast(inst); + if (!const_inst) { + EXPLAIN << "HloInstruction is not a constant"; + return false; + } + if (match_effective_scalar_ && + !ShapeUtil::IsEffectiveScalar(inst->shape())) { + EXPLAIN << "HloInstruction is not an effective scalar"; + return false; + } + if (!match_effective_scalar_ && !ShapeUtil::IsScalar(inst->shape())) { + EXPLAIN << "HloInstruction is not a scalar"; + return false; + } + if (!val_.has_value()) { + return true; + } + + // Check that literal == static_cast(val) and + // val == static_cast(literal). This is sufficient to ensure that + // the two constant scalars are actually "equal". 
+ auto val_literal = LiteralUtil::CreateR0(*val_); + auto literal_r0_or = const_inst->literal().Reshape({}); + auto val_as_literal_ty_or = + val_literal.Convert(const_inst->shape().element_type()); + if (!literal_r0_or.ok() || !val_as_literal_ty_or.ok()) { + EXPLAIN << "could not construct relevant Literals (how did this happen?)"; + return false; + } + auto literal_r0 = std::move(literal_r0_or).ValueOrDie(); + auto val_as_literal_ty = std::move(val_as_literal_ty_or).ValueOrDie(); + auto literal_r0_as_val_ty_or = + literal_r0.Convert(val_literal.shape().element_type()); + bool rv = literal_r0_as_val_ty_or.ok() && // + literal_r0_as_val_ty_or.ValueOrDie() == val_literal && + literal_r0 == val_as_literal_ty; + if (!rv) { + EXPLAIN << "HloInstruction's constant value " << literal_r0.ToString() + << " did not match expected value " << *val_; + } + return rv; + } + + absl::optional val_; + bool match_effective_scalar_; +}; // A pattern that matches HloInstructions. template class HloInstructionPattern { private: template - HloInstructionPattern> - AppendImpl(NewImpl new_impl) const { - return HloInstructionPattern< - HloInstructionType, AllOfPattern<::xla::HloInstruction, Impl, NewImpl>>( - AllOf(impl_, std::move(new_impl)), matched_inst_); + auto AppendImpl(NewImpl new_impl) const -> HloInstructionPattern< + HloInstructionType, decltype(AllOf( + std::declval(), std::move(new_impl)))> { + auto new_allof = AllOf(impl_, std::move(new_impl)); + return HloInstructionPattern( + std::move(new_allof), matched_inst_); } public: @@ -880,6 +1773,9 @@ class HloInstructionPattern { } return true; } + if (inst != nullptr) { + EXPLAIN << "\nin " << InstToString(inst); + } return false; } @@ -891,6 +1787,7 @@ class HloInstructionPattern { } return true; } + EXPLAIN << "\nin " << InstToString(inst); return false; } @@ -907,6 +1804,11 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false)); } + auto WithNumOperands(int64 num_operands) const -> 
decltype( + this->AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands))) { + return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands)); + } + // Modifies the pattern to match only if the instruction does not have the // given opcode. auto WithoutOpcode(HloOpcode opcode) const @@ -915,12 +1817,47 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true)); } + constexpr auto Is(const HloInstruction* instr) const + -> decltype(this->AppendImpl(HloInstructionIsImpl(instr))) { + return AppendImpl(HloInstructionIsImpl(instr)); + } + // Modifies the pattern to match only if the instruction is a constant. constexpr auto IsConstant() const -> decltype(this->WithOpcode(HloOpcode::kConstant)) { return WithOpcode(HloOpcode::kConstant); } + constexpr auto IsConstantScalar() const -> decltype(this->AppendImpl( + HloConstantScalarImpl(/*match_effective_scalar=*/false))) { + return AppendImpl( + HloConstantScalarImpl(/*match_effective_scalar=*/false)); + } + + // This does not check that T has the same type as the instruction, so e.g. + // IsConstantScalar(1.0) may match a constant of shape int32[]. + template + constexpr auto IsConstantScalar(const ScalarTy& val) const + -> decltype(this->AppendImpl(HloConstantScalarImpl( + val, /*match_effective_scalar=*/false))) { + return AppendImpl( + HloConstantScalarImpl(val, /*match_effective_scalar=*/false)); + } + + constexpr auto IsConstantEffectiveScalar() const -> decltype(this->AppendImpl( + HloConstantScalarImpl(/*match_effective_scalar=*/true))) { + return AppendImpl( + HloConstantScalarImpl(/*match_effective_scalar=*/true)); + } + + template + constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const + -> decltype(this->AppendImpl(HloConstantScalarImpl( + val, /*match_effective_scalar=*/true))) { + return AppendImpl( + HloConstantScalarImpl(val, /*match_effective_scalar=*/true)); + } + // Modifies the pattern to match only if the instruction is not a constant. 
constexpr auto IsNonConstant() const -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) { @@ -937,6 +1874,22 @@ class HloInstructionPattern { HloInstructionPatternShapeImpl(shape)); } + // Make this a templated function to work around gcc 4.9.4 template infinite + // recursion bug. + template + constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) + -> decltype(this->WithShape(Shape().EqualTo(shape))) { + return WithShape(Shape().EqualTo(shape)); + } + + // Make this a templated function to work around gcc 4.9.4 template infinite + // recursion bug. + template + constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) + -> decltype(this->WithShape(Shape().CompatibleTo(shape))) { + return WithShape(Shape().CompatibleTo(shape)); + } + // Modifies the pattern to match only if the instruction has an operand that // matches the given pattern. template @@ -951,6 +1904,20 @@ class HloInstructionPattern { operand_index, operand)); } + template + constexpr auto WithBinaryOperandsAnyOrder( + const HloInstructionPattern& op1, + const HloInstructionPattern& op2) const + -> decltype(this->AppendImpl( + HloInstructionPatternBinaryOperandsAnyOrderImpl< + OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, + op2))) { + return AppendImpl( + HloInstructionPatternBinaryOperandsAnyOrderImpl< + OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2)); + } + // Modifies the pattern to match only if the instruction is a fusion node with // the given kind. 
constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const @@ -965,17 +1932,34 @@ class HloInstructionPattern { return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index)); } - private: - template - constexpr auto WithPredicate(Predicate pred) const -> decltype( - this->AppendImpl(HloPredicatePatternImpl( - std::move(pred)))) { - return AppendImpl( - HloPredicatePatternImpl(std::move(pred))); + // Modifies the pattern to match only if the instruction is a parameter + // with the given parameter number. + constexpr auto WithParameterNum(int64 parameter_num) const -> decltype( + this->AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num))) { + return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num)); } - friend struct PatternFriend; + // Modifies the pattern to match if the instruction is used exactly once. + // Does not match if the instruction is used twice by the same user (e.g. + // multiply(x,x)). + constexpr auto WithOneUse() const + -> decltype(this->AppendImpl(HloInstructionPatternOneUseImpl())) { + return AppendImpl(HloInstructionPatternOneUseImpl()); + } + // Modifies the pattern to match if the instruction is used by exactly one + // other instruction. Will match if the instruction is used twice, so long as + // it's by the same user (e.g. multiply(x,x)). + constexpr auto WithOneUser() const + -> decltype(this->AppendImpl(HloInstructionPatternOneUserImpl())) { + return AppendImpl(HloInstructionPatternOneUserImpl()); + } + + void DescribeTo(std::ostream* os, int64 indent = 0) const { + impl_.DescribeTo(os, indent); + } + + private: Impl impl_; HloInstructionType** matched_inst_; }; @@ -1016,6 +2000,7 @@ Op(::xla::HloInstruction** matched_inst) { XLA_NULLOP_PATTERN(Constant) XLA_NULLOP_PATTERN(Parameter) XLA_NULLOP_PATTERN(Iota) +XLA_NULLOP_PATTERN(Rng) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. 
@@ -1047,8 +2032,10 @@ XLA_UNOP_PATTERN(RoundNearestAfz) XLA_UNOP_PATTERN(Bitcast) XLA_UNOP_PATTERN(Broadcast) XLA_UNOP_PATTERN(Ceil) +XLA_UNOP_PATTERN(Convert) XLA_UNOP_PATTERN(Copy) XLA_UNOP_PATTERN(Cos) +XLA_UNOP_PATTERN(CrossReplicaSum) XLA_UNOP_PATTERN(Exp) XLA_UNOP_PATTERN(Fft) XLA_UNOP_PATTERN(Floor) @@ -1062,13 +2049,13 @@ XLA_UNOP_PATTERN(Negate) XLA_UNOP_PATTERN(Real) XLA_UNOP_PATTERN(Recv) XLA_UNOP_PATTERN(RecvDone) -XLA_UNOP_PATTERN(Reduce) XLA_UNOP_PATTERN(ReducePrecision) XLA_UNOP_PATTERN(Reshape) XLA_UNOP_PATTERN(Reverse) XLA_UNOP_PATTERN(SendDone) XLA_UNOP_PATTERN(Sign) XLA_UNOP_PATTERN(Sin) +XLA_UNOP_PATTERN(Slice) XLA_UNOP_PATTERN(Sort) XLA_UNOP_PATTERN(Tanh) XLA_UNOP_PATTERN(Transpose) @@ -1106,25 +2093,32 @@ XLA_UNOP_PATTERN(Transpose) #define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \ XLA_BINOP_PATTERN(NAME) \ \ - template \ - inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ - ->decltype(AnyOf(NAME(lhs, rhs), NAME(rhs, lhs))) { \ - return AnyOf(NAME(lhs, rhs), NAME(rhs, lhs)); \ - } \ - \ template \ inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ Rhs&& rhs) \ - ->decltype(AnyOf(NAME(matched_inst, lhs, rhs), \ - NAME(matched_inst, rhs, lhs))) { \ - return AnyOf(NAME(matched_inst, lhs, rhs), \ - NAME(matched_inst, rhs, lhs)); \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs)); \ + } \ + template \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs))) { \ + return NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs)); \ } XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) +XLA_BINOP_PATTERN(Convolution) XLA_BINOP_PATTERN(Dot) 
+XLA_BINOP_PATTERN(DynamicSlice) XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) XLA_BINOP_PATTERN(Ge) @@ -1136,7 +2130,9 @@ XLA_COMMUTATIVE_BINOP_PATTERN(Minimum) XLA_COMMUTATIVE_BINOP_PATTERN(Multiply) XLA_COMMUTATIVE_BINOP_PATTERN(Ne) XLA_BINOP_PATTERN(Outfeed) +XLA_BINOP_PATTERN(Pad) XLA_BINOP_PATTERN(Power) +XLA_BINOP_PATTERN(ReduceWindow) XLA_BINOP_PATTERN(Remainder) XLA_BINOP_PATTERN(Send) XLA_BINOP_PATTERN(Subtract) @@ -1183,33 +2179,66 @@ XLA_BINOP_PATTERN(ShiftRightLogical) .WithOperand(2, std::forward(arg2)); \ } XLA_TERNOP_PATTERN(Clamp); +XLA_TERNOP_PATTERN(Scatter); XLA_TERNOP_PATTERN(Select); #undef XLA_TERNOP_PATTERN namespace detail { -struct PatternFriend { - template - static auto ConstantScalar(T constant) -> decltype( - Constant() - .WithShape(match::Shape().IsScalar()) - .WithPredicate( - std::declval>())) { - std::function pred = - [constant](const HloInstruction* instr) { - const auto& literal = Cast(instr)->literal(); - auto status_or_const = LiteralUtil::CreateR0(constant).Convert( - literal.shape().element_type()); - return status_or_const.ok() && - literal == status_or_const.ConsumeValueOrDie(); - }; - - return Constant() - .WithShape(match::Shape().IsScalar()) - .WithPredicate(std::move(pred)); - } -}; +template +inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg) + -> decltype(m.WithOperand(operand_num, std::forward(first_arg))) { + return m.WithOperand(operand_num, std::forward(first_arg)); +} + +template +inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg, + Args&&... 
args) + -> decltype(WithOperands(m.WithOperand(operand_num, + std::forward(first_arg)), + operand_num + 1, std::forward(args)...)) { + return WithOperands( + m.WithOperand(operand_num, std::forward(first_arg)), + operand_num + 1, std::forward(args)...); +} } // namespace detail +#define XLA_VARIADIC_OP_PATTERN(NAME) \ + inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ + return Op().WithOpcode(HloOpcode::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Args&&... args) \ + ->decltype(detail::WithOperands(Op().WithOpcode(HloOpcode::k##NAME) \ + .WithNumOperands(sizeof...(Args)), \ + 0, std::forward(args)...)) { \ + return detail::WithOperands( \ + Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \ + /*operand_num=*/0, std::forward(args)...); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Args&&... args) \ + ->decltype(detail::WithOperands(Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithNumOperands(sizeof...(Args)), \ + 0, std::forward(args)...)) { \ + return detail::WithOperands(Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithNumOperands(sizeof...(Args)), \ + /*operand_num=*/0, \ + std::forward(args)...); \ + } + +// We could implement all ops as "variadic" ops, but it would make the +// already-bad compile errors even worse. +XLA_VARIADIC_OP_PATTERN(AfterAll); +XLA_VARIADIC_OP_PATTERN(Concatenate); +XLA_VARIADIC_OP_PATTERN(CustomCall); +XLA_VARIADIC_OP_PATTERN(Map) +XLA_VARIADIC_OP_PATTERN(Reduce); +XLA_VARIADIC_OP_PATTERN(Tuple); + // Helpers for matching non-constant instructions. 
inline auto NonConstant() -> decltype(Op().IsNonConstant()) { return Op().IsNonConstant(); @@ -1247,14 +2276,71 @@ inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, .WithTupleIndex(tuple_index); } -template -inline auto ConstantScalar(T constant) - -> decltype(detail::PatternFriend::ConstantScalar(constant)) { - return detail::PatternFriend::ConstantScalar(constant); +// Add overloads for Parameter which take an int64 specifying the parameter +// number. +inline auto Parameter(int64 parameter_num) -> decltype( + Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num)) { + return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num); +} +template +inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num) + -> decltype(Op(matched_inst) + .WithOpcode(HloOpcode::kParameter) + .WithParameterNum(parameter_num)) { + return Op(matched_inst) + .WithOpcode(HloOpcode::kParameter) + .WithParameterNum(parameter_num); +} + +inline auto ConstantScalar() -> decltype(Op().IsConstantScalar()) { + return Op().IsConstantScalar(); +} + +template +inline auto ConstantScalar(HloInstructionType** matched_inst) + -> decltype(Op(matched_inst).IsConstantScalar()) { + return Op(matched_inst).IsConstantScalar(); +} + +template +inline auto ConstantScalar(ScalarTy val) + -> decltype(Op().IsConstantScalar(val)) { + return Op().IsConstantScalar(val); +} + +template +inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) + -> decltype(Op(matched_inst).IsConstantScalar(val)) { + return Op(matched_inst).IsConstantScalar(val); +} + +inline auto ConstantEffectiveScalar() -> decltype(Op().IsConstantScalar()) { + return Op().IsConstantEffectiveScalar(); +} + +template +inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) + -> decltype(Op(matched_inst).IsConstantScalar()) { + return Op(matched_inst).IsConstantEffectiveScalar(); +} + +template +inline auto ConstantEffectiveScalar(ScalarTy 
val) + -> decltype(Op().IsConstantEffectiveScalar(val)) { + return Op().IsConstantEffectiveScalar(val); +} + +template +inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst, + ScalarTy val) + -> decltype(Op(matched_inst).IsConstantEffectiveScalar(val)) { + return Op(matched_inst).IsConstantEffectiveScalar(val); } } // namespace match } // namespace xla +#undef EXPLAIN +#pragma pop_macro("EXPLAIN") #endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher_gmock.h b/tensorflow/compiler/xla/service/pattern_matcher_gmock.h new file mode 100644 index 0000000000000000000000000000000000000000..8fe2d10a11b5b2d26ee222c63e0db2d55e361d12 --- /dev/null +++ b/tensorflow/compiler/xla/service/pattern_matcher_gmock.h @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ + +#include +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { + +namespace pattern_matcher_gmock_detail { +template +class GmockMatcher { + public: + explicit GmockMatcher(Pattern p) : pattern_(std::move(p)) {} + + // In service of better error messages, list out the overloads explicitly + // rather than just using a template. gMock's polymorphism plus + // pattern_matcher yields some pretty gnarly stuff. + bool MatchAndExplain(const Layout& l, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(&l, listener); + } + bool MatchAndExplain(const Layout* l, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(l, listener); + } + + bool MatchAndExplain(const Shape& s, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(&s, listener); + } + bool MatchAndExplain(const Shape* s, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(s, listener); + } + + bool MatchAndExplain(const HloInstruction& instr, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(&instr, listener); + } + bool MatchAndExplain(const HloInstruction* instr, + ::testing::MatchResultListener* listener) const { + return MatchAndExplainImpl(instr, listener); + } + + void DescribeTo(std::ostream* os) const { pattern_.DescribeTo(os); } + + void DescribeNegationTo(std::ostream* os) const { + *os << "is NOT: "; + DescribeTo(os); + } + + private: + template + bool MatchAndExplainImpl(const T* t, + ::testing::MatchResultListener* listener) const { + MatchOption options{/*.capture=*/true, /*.explain_os=*/listener->stream()}; + return Match(t, 
pattern_, options); + } + + Pattern pattern_; +}; +} // namespace pattern_matcher_gmock_detail + +template +::testing::PolymorphicMatcher< + pattern_matcher_gmock_detail::GmockMatcher> +GmockMatch(Pattern&& p) { + return ::testing::MakePolymorphicMatcher( + pattern_matcher_gmock_detail::GmockMatcher( + std::forward(p))); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_GMOCK_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9ca2fb05c1f7ef093c58237cf21fbc7c813a592a --- /dev/null +++ b/tensorflow/compiler/xla/service/pattern_matcher_gmock_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +namespace m = ::xla::match; +using ::testing::Eq; +using ::testing::Not; + +template +string Describe(const ::testing::Matcher& m) { + std::stringstream ss; + m.DescribeTo(&ss); + return ss.str(); +} + +template +string Explain( + const MatchedTy& val, + const ::testing::Matcher::type>& m) { + ::testing::StringMatchResultListener listener; + EXPECT_THAT(val, ::testing::Not(m)); // For the error message. + EXPECT_FALSE(m.MatchAndExplain(val, &listener)); + return listener.str(); +} + +// This file tests the GmockMatch function. The actual explanation and +// description returned by matchers is tested in pattern_matchers_test. +TEST(PatternMatcherGmock, MatchShape) { + Shape s = ShapeUtil::MakeShape(F32, {10, 100}); + // You can pass const Shape& or a const Shape*. 
+ EXPECT_THAT(s, GmockMatch(m::Shape())); + EXPECT_THAT(&s, Not(GmockMatch(m::Shape().WithElementType(F16)))); + EXPECT_THAT(Describe(GmockMatch(m::Shape().IsArray())), + "a shape that represents an array"); +} + +TEST(PatternMatcherGmock, MatchLayout) { + Layout l = LayoutUtil::MakeLayout({0, 1}); + EXPECT_THAT(l, GmockMatch(m::Layout())); + EXPECT_THAT(&l, Not(GmockMatch(m::Layout().WithSparseFormat()))); + EXPECT_THAT(Describe(GmockMatch(m::Layout().WithSparseFormat())), + "a layout with format SPARSE"); +} + +TEST(PatternMatchGmock, MatchInstruction) { + auto instr = + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {42}), "p"); + EXPECT_THAT(instr.get(), GmockMatch(m::Parameter())); + EXPECT_THAT(*instr, GmockMatch(m::Parameter(0))); + EXPECT_THAT(*instr, Not(GmockMatch(m::Parameter(1)))); + EXPECT_THAT(Describe(GmockMatch(m::Parameter())), + "an HloInstruction with opcode parameter"); +} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 3ab7b7fd7168d7ddd1470fdb03a04ba7b171fddb..186ef0c7911a2724df810780e018f52586e3e6a8 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -14,14 +14,18 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace { +namespace m = match; + TEST(PatternMatcherTest, AddOp) { constexpr char kModuleStr[] = R"(HloModule two_plus_two_module ENTRY %two_plus_two_computation () -> f32[] { @@ -229,23 +233,74 @@ TEST(PatternMatcherTest, AnyOf) { } TEST(PatternMatcherTest, ConstantScalar) { - constexpr char kModuleStr[] = R"( - HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })"; - TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); - auto* root = hlo_module->entry_computation()->root_instruction(); - - EXPECT_TRUE(Match(root, match::ConstantScalar(42))); - EXPECT_FALSE(Match(root, match::ConstantScalar(41))); - EXPECT_FALSE(Match(root, match::ConstantScalar(0))); -} + using match::ConstantEffectiveScalar; + using match::ConstantScalar; + using match::Op; + using match::Tuple; -TEST(PatternMatcherTest, NoMatchConstantScalar) { constexpr char kModuleStr[] = R"( - HloModule test_module ENTRY test { ROOT v = f16[] parameter(0) })"; + HloModule test_module + ENTRY test { + a = s32[] constant(1) + b = s32[1,1] constant(s32[1,1]{{2}}) + c = s32[1,2] constant(s32[1,2]{{2,2}}) + d = f32[] constant(1) + e = f32[] constant(1.25) + ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e) + })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); - EXPECT_FALSE(Match(root, match::ConstantScalar(42))); + const HloInstruction* a = root->operand(0); + const HloInstruction* b = root->operand(1); + const 
HloInstruction* c = root->operand(2); + const HloInstruction* d = root->operand(3); + const HloInstruction* e = root->operand(4); + EXPECT_TRUE(Match(a, ConstantScalar())); + EXPECT_TRUE(Match(a, ConstantScalar(1))); + EXPECT_TRUE(Match(a, ConstantEffectiveScalar())); + EXPECT_TRUE(Match(a, ConstantEffectiveScalar(1))); + EXPECT_FALSE(Match(a, ConstantScalar(2))); + EXPECT_FALSE(Match(a, ConstantScalar(2.01))); + EXPECT_FALSE(Match(a, ConstantEffectiveScalar(2))); + EXPECT_FALSE(Match(a, ConstantEffectiveScalar(1.01))); + + EXPECT_FALSE(Match(b, ConstantScalar())); + EXPECT_FALSE(Match(b, ConstantScalar(2))); + EXPECT_TRUE(Match(b, ConstantEffectiveScalar())); + EXPECT_TRUE(Match(b, ConstantEffectiveScalar(2))); + + EXPECT_FALSE(Match(c, ConstantScalar())); + EXPECT_FALSE(Match(c, ConstantScalar(2))); + EXPECT_FALSE(Match(c, ConstantEffectiveScalar())); + EXPECT_FALSE(Match(c, ConstantEffectiveScalar(2))); + + EXPECT_TRUE(Match(d, ConstantScalar(1))); + EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1))); + EXPECT_TRUE(Match(d, ConstantScalar(1.0))); + EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1.0))); + + EXPECT_TRUE(Match(e, ConstantScalar(1.25f))); + EXPECT_TRUE(Match(e, ConstantScalar(1.25))); + EXPECT_TRUE(Match(e, ConstantEffectiveScalar(1.25))); + EXPECT_FALSE(Match(e, ConstantScalar(1))); + EXPECT_FALSE(Match(e, ConstantEffectiveScalar(1))); + + const HloInstruction* instr = nullptr; + EXPECT_TRUE(Match(a, ConstantScalar(&instr))); + EXPECT_EQ(instr, a); + + instr = nullptr; + EXPECT_TRUE(Match(a, ConstantScalar(&instr, 1))); + EXPECT_EQ(instr, a); + + instr = nullptr; + EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr))); + EXPECT_EQ(instr, a); + + instr = nullptr; + EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr, 1))); + EXPECT_EQ(instr, a); } TEST(PatternMatcherTest, MultiplyAnyOrder) { @@ -267,6 +322,15 @@ TEST(PatternMatcherTest, MultiplyAnyOrder) { root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52)))); EXPECT_TRUE(Match( 
root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42)))); + + // Check that MultiplyAnyOrder exposes the same API as Op(), so we can call + // e.g. IsNonConstant() on it. + EXPECT_TRUE(Match( + root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52)) + .IsNonConstant())); + EXPECT_TRUE( + Match(root, MultiplyAnyOrder(ConstantScalar(42), ConstantScalar(52)) + .IsNonConstant())); } TEST(PatternMatcherTest, AnyOfShortCircuit) { @@ -315,14 +379,22 @@ TEST(PatternMatcherTest, AllOf) { TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); + auto f16_scalar = ShapeUtil::MakeShape(F16, {}); + auto f16_pattern = Constant().WithShapeEqualTo(&f16_scalar); + auto f16_compatible_pattern = Constant().WithShapeCompatibleTo(&f16_scalar); auto scalar_pattern = Constant().WithShape(match::Shape().IsScalar()); - auto f16_pattern = Constant().WithShape(match::Shape().WithElementType(F16)); ASSERT_TRUE(Match(root, scalar_pattern)); ASSERT_TRUE(Match(root, f16_pattern)); - EXPECT_TRUE(Match(root, AllOf(scalar_pattern, f16_pattern))); - EXPECT_TRUE(Match(root, AllOf(f16_pattern, scalar_pattern))); + ASSERT_TRUE(Match(root, f16_compatible_pattern)); + EXPECT_TRUE(Match(root, AllOf(scalar_pattern, f16_pattern, + f16_compatible_pattern))); + EXPECT_TRUE( + Match(root, AllOf(f16_pattern, f16_compatible_pattern, + scalar_pattern))); EXPECT_FALSE( Match(root, AllOf(Broadcast(Op()), f16_pattern))); + EXPECT_FALSE(Match( + root, AllOf(Broadcast(Op()), f16_compatible_pattern))); EXPECT_FALSE( Match(root, AllOf(Broadcast(Op()), scalar_pattern))); } @@ -394,5 +466,470 @@ TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) { EXPECT_EQ(nullptr, addend2); } +TEST(PatternMatcherTest, TestConcat) { + using match::Concatenate; + using match::ConstantScalar; + using match::Op; + using match::Reshape; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + c1 = 
u32[] constant(1) + c2 = u32[] constant(2) + c3 = u32[] constant(3) + c4 = u32[] constant(4) + r1 = u32[1] reshape(c1) + r2 = u32[1] reshape(c2) + r3 = u32[1] reshape(c3) + r4 = u32[1] reshape(c4) + ROOT concat = u32[4] concatenate(r1, r2, r3, r4), dimensions={0} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + ASSERT_TRUE(Match( + root, + Concatenate(Reshape(ConstantScalar(1)), Reshape(ConstantScalar(2)), + Reshape(ConstantScalar(3)), Reshape(ConstantScalar(4))))); + ASSERT_FALSE(Match( + root, + Concatenate(Reshape(ConstantScalar(2)), Reshape(ConstantScalar(1)), + Reshape(ConstantScalar(3)), Reshape(ConstantScalar(4))))); + ASSERT_FALSE(Match( + root, Concatenate(Reshape(ConstantScalar(1)), Reshape(ConstantScalar(2)), + Reshape(ConstantScalar(3))))); + ASSERT_FALSE(Match( + root, Concatenate(Reshape(ConstantScalar(2)), Reshape(ConstantScalar(3)), + Reshape(ConstantScalar(4))))); +} + +template +string Description(const Pattern& pattern) { + std::stringstream ss; + pattern.DescribeTo(&ss); + return ss.str(); +} + +template +string Explanation(Elem* elem, const Pattern& pattern) { + std::stringstream ss; + MatchOption options{/*.capture=*/true, /*.explain_os=*/&ss}; + Match(elem, pattern, options); + return ss.str(); +} +template +string Explanation(const std::unique_ptr& elem, const Pattern& pattern) { + return Explanation(elem.get(), pattern); +} +template +string Explanation(const Elem& elem, const Pattern& pattern) { + return Explanation(&elem, pattern); +} + +// Helper macro for checking a pattern's description and the explanation printed +// when attempting to match (and presumably failing) on a given object. +// +// We use a macro rather than a function because we want good line numbers in +// errors. 
We use this rather than writing a helper that returns a pair of +// (description, explanation) and doing something like +// +// EXPECT_THAT(DescAndExplanation(...), ::testing::Pair(..., ...)); +// +// because EXPECT_EQ prints a unified diff if multiline string comparison fails, +// while EXPECT_THAT does not. This unified diff makes the errors much easier +// to read. +#define EXPECT_DESC_AND_EXPLANATION(elem, pattern, expected_desc, \ + expected_explanation) \ + do { \ + EXPECT_EQ(Description(pattern), (expected_desc)); \ + EXPECT_EQ(Explanation((elem), (pattern)), expected_explanation); \ + } while (0) + +TEST(PatternMatcherTest, LayoutDescribeToAndExplain) { + auto layout = LayoutUtil::MakeLayout({1, 2}); + auto layout2 = LayoutUtil::MakeLayout({2, 2}); + + EXPECT_DESC_AND_EXPLANATION(static_cast(nullptr), m::Layout(), + "a layout", "Layout is null"); + EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().EqualTo(&layout), + "a layout equal to {1,2}", + "Layout {2,2} is not equal to expected {1,2}"); + EXPECT_DESC_AND_EXPLANATION(layout2, m::Layout().WithSparseFormat(), + "a layout with format SPARSE", + "Layout has format DENSE but expected SPARSE"); + EXPECT_DESC_AND_EXPLANATION(layout, + m::Layout().EqualTo(&layout).WithSparseFormat(), + "a layout:\n" + " * equal to {1,2} AND\n" + " * with format SPARSE", + "Layout has format DENSE but expected SPARSE"); +} + +TEST(PatternMatcherTest, ShapeDescribeToAndExplain) { + auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1}); + auto layout = shape.layout(); + + EXPECT_DESC_AND_EXPLANATION(static_cast(nullptr), m::Shape(), + "a shape", "Shape is null"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}), + m::Shape().EqualTo(&shape), "a shape equal to f32[1,2]{0,1}", + "Shape not equal to f32[1,2]{0,1}\n" + "in f32[1,2]{1,0}"); + EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeShape(F32, {2, 2}), + m::Shape().CompatibleTo(&shape), + "a shape compatible with f32[1,2]", + "Shape not 
compatible with f32[1,2]\n" + "in f32[2,2]{1,0}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithElementType(F16), + "a shape with element type F16", + "Shape does not have element type F16\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsScalar(), + "a shape that represents a scalar", + "Shape is not a scalar\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), m::Shape().IsArray(), + "a shape that represents an array", + "Shape is not an array\n" + "in ()"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsTuple(), + "a shape that represents a tuple", + "Shape is not a tuple\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsEffectiveScalar(), + "a shape that is an effective scalar", + "Shape is not an effective scalar\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(42), + "a shape that has 42 dimensions", + "Shape does not have rank 42\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(0), + "a shape that is a scalar", + "Shape is not a scalar\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(1).IsArray(), + "a shape:\n" + " * that has 1 dimension AND\n" + " * that represents an array", + "Shape does not have rank 1\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), + m::Shape().IsArray().WithRank(1), + "a shape:\n" + " * that represents an array AND\n" + " * that has 1 dimension", + "Shape is not an array\n" + "in ()"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}), + m::Shape().WithLayoutEqualTo(&layout), + "a shape with\n a layout equal to {0,1}", + "Layout {1,0} is not equal to expected {0,1}\n" + "in f32[1,2]{1,0}"); + EXPECT_DESC_AND_EXPLANATION( + shape, m::Shape().WithLayout(m::Layout().WithSparseFormat()), + "a shape with\n a layout with format SPARSE", + "Layout has format DENSE but expected SPARSE\n" + "in 
f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION(shape, + m::Shape().WithSubshapeEqualTo({10}, &shape), + "a shape with subshape at index {10} which is\n" + " a shape equal to f32[1,2]{0,1}", + "No subshape at {10}\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}), + m::Shape().WithSubshapeEqualTo({0}, &shape), + "a shape with subshape at index {0} which is\n" + " a shape equal to f32[1,2]{0,1}", + "Shape not equal to f32[1,2]{0,1}\n" + "in f32[2,2]{1,0}\n" + "in subshape at {0}\n" + "in (f32[2,2])"); + EXPECT_DESC_AND_EXPLANATION(shape, + m::Shape().WithSubshapeCompatibleTo({10}, &shape), + "a shape with subshape at index {10} which is\n" + " a shape compatible with f32[1,2]", + "No subshape at {10}\n" + "in f32[1,2]{0,1}"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}), + m::Shape().WithSubshapeCompatibleTo({0}, &shape), + "a shape with subshape at index {0} which is\n" + " a shape compatible with f32[1,2]", + "Shape not compatible with f32[1,2]\n" + "in f32[2,2]{1,0}\n" + "in subshape at {0}\n" + "in (f32[2,2])"); + EXPECT_DESC_AND_EXPLANATION( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({shape})}), + m::Shape().WithSubshape({0, 0}, m::Shape().IsScalar()), + "a shape with subshape at index {0,0} which is\n" + " a shape that represents a scalar", + "Shape is not a scalar\n" + "in f32[1,2]{0,1}\n" + "in subshape at {0,0}\n" + "in ((f32[1,2]))"); +} + +std::unique_ptr SetName(absl::string_view name, + std::unique_ptr instr) { + instr->SetAndSanitizeName(string(name)); + return instr; +} + +TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) { + std::unique_ptr iota = + SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}), + /*iota_dimension=*/0)); + std::unique_ptr constant = + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + + EXPECT_DESC_AND_EXPLANATION(static_cast(nullptr), + 
m::Op(), "an HloInstruction", + "HloInstruction* is null"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithName("foo"), + "an HloInstruction named \"foo\"", + "HloInstruction not named \"foo\"\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithOpcode(HloOpcode::kAdd), + "an HloInstruction with opcode add", + "HloInstruction doesn't have opcode add\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION( + constant, m::Op().IsNonConstant(), + "an HloInstruction with any opcode other than constant", + "HloInstruction has opcode constant, expected anything else\n" + "in c = s32[] constant(0)"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithNumOperands(42), + "an HloInstruction with 42 operands", + "HloInstruction doesn't have 42 operands\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(m::Shape().IsTuple()), + "an HloInstruction outputting\n" + " a shape that represents a tuple", + "Shape is not a tuple\n" + "in s32[42]{0}\n" + "in output shape\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION( + iota, m::Op().WithOperand(2, m::Op().WithOpcode(HloOpcode::kAdd)), + "an HloInstruction with operand 2 which is:\n" + " an HloInstruction with opcode add", + "desired operand index 2 is out of bounds\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + + EXPECT_DESC_AND_EXPLANATION( + SetName("a", HloInstruction::CreateBinary(ShapeUtil::MakeShape(S32, {}), + HloOpcode::kAdd, constant.get(), + constant.get())), + m::Op().WithOperand(1, m::Op().IsNonConstant()), + "an HloInstruction with operand 1 which is:\n" + " an HloInstruction with any opcode other than constant", + "HloInstruction has opcode constant, expected anything else\n" + "in c = s32[] constant(0)\n" + "in operand 1\n" + "in a = s32[] add(s32[] c, s32[] c)"); + EXPECT_DESC_AND_EXPLANATION( + iota, 
m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop), + "an HloInstruction with fusion kind kLoop", + "HloInstruction does not have fusion kind kLoop; it's not a fusion\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION( + iota, m::Op().WithTupleIndex(42), + "an HloInstruction which is a GTE with index 42", + "HloInstruction is not a GTE with index 42; it's not a GTE at all\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION(iota, m::Op().IsConstantScalar(), + "an HloInstruction which is a constant scalar", + "HloInstruction is not a constant\n" + "in i = s32[42]{0} iota(), iota_dimension=0"); + EXPECT_DESC_AND_EXPLANATION( + SetName("c", HloInstruction::CreateConstant( + LiteralUtil::CreateR1({1, 2}))), + m::Op().IsConstantEffectiveScalar(), + "an HloInstruction which is a constant effective scalar", + "HloInstruction is not an effective scalar\n" + "in c = s32[2]{0} constant({1, 2})"); + EXPECT_DESC_AND_EXPLANATION( + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))), + m::Op().IsConstantScalar(42), + "an HloInstruction which is a constant scalar with value 42", + "HloInstruction's constant value 10 did not match expected value 42\n" + "in c = s32[] constant(10)"); + EXPECT_DESC_AND_EXPLANATION( + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.25))), + m::Op().IsConstantEffectiveScalar(1.25), + "an HloInstruction which is a constant effective scalar with value 1.25", + "HloInstruction's constant value 2.25 did not match expected value 1.25\n" + "in c = f64[] constant(2.25)"); + EXPECT_DESC_AND_EXPLANATION( + constant, m::Op().Is(iota.get()), + absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()), + " (i = s32[42]{0} iota(), iota_dimension=0)"), + absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x", + absl::Hex(iota.get()), + " (i = s32[42]{0} iota(), iota_dimension=0)\n" + "in c = s32[] constant(0)")); +} + 
+TEST(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + EXPECT_DESC_AND_EXPLANATION( + SetName("a", HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, + SetName("b", HloInstruction::CreateConstant( + LiteralUtil::CreateR0(0))) + .get(), + SetName("c", HloInstruction::CreateConstant( + LiteralUtil::CreateR0(0))) + .get())), + m::AddAnyOrder(m::Op().WithName("b"), m::Op().WithName("bar")), + "an HloInstruction:\n" + " * with opcode add AND\n" + " * with two operands in either order:\n" + " - an HloInstruction named \"b\"\n" + " - an HloInstruction named \"bar\"", + "HloInstruction's operands (ignoring order) did not match second " + "matcher. Specifically,\n" + " - an HloInstruction named \"bar\"\n" + "does not match LHS:\n" + " - HloInstruction not named \"bar\"\n" + " in b = s32[] constant(0)\n" + "does not match RHS:\n" + " - HloInstruction not named \"bar\"\n" + " in c = s32[] constant(0)\n" + "in a = s32[] add(s32[] b, s32[] c)"); + + EXPECT_DESC_AND_EXPLANATION( + SetName("a", + HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, + HloInstruction::CreateParameter(0, scalar_s32, "p").get(), + SetName("c", HloInstruction::CreateConstant( + LiteralUtil::CreateR0(0))) + .get())), + m::AddAnyOrder(m::Op().IsConstantScalar(), m::Op().IsConstant()), + "an HloInstruction:\n" + " * with opcode add AND\n" + " * with two operands in either order:\n" + " - an HloInstruction which is a constant scalar\n" + " - an HloInstruction with opcode constant", + "HloInstruction's LHS operand did not match either of the two matchers. 
" + "Specifically,\n" + " - an HloInstruction which is a constant scalar\n" + "does not match LHS:\n" + " - HloInstruction is not a constant\n" + " in p = s32[] parameter(0)\n" + "and\n" + " - an HloInstruction with opcode constant\n" + "does not match LHS:\n" + " - HloInstruction doesn't have opcode constant\n" + " in p = s32[] parameter(0)\n" + "in a = s32[] add(s32[] p, s32[] c)"); +} + +TEST(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) { + EXPECT_DESC_AND_EXPLANATION( + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))), + m::AnyOf(m::Op().WithName("foo"), + m::Op().WithName("bar")), + "any of:\n" + " - an HloInstruction named \"foo\" OR\n" + " - an HloInstruction named \"bar\"", + "None of the following matchers succeeded:\n" + "Matcher #1\n" + " - an HloInstruction named \"foo\"\n" + "failed with\n" + " - HloInstruction not named \"foo\"\n" + " in c = s32[] constant(0)\n" + "Matcher #2\n" + " - an HloInstruction named \"bar\"\n" + "failed with\n" + " - HloInstruction not named \"bar\"\n" + " in c = s32[] constant(0)"); +} + +TEST(PatternMatcherTest, Parameter) { + auto param = + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p1"); + auto non_param = + SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + EXPECT_FALSE(Match(param.get(), m::Parameter(0))); + EXPECT_TRUE(Match(param.get(), m::Parameter())); + EXPECT_TRUE(Match(param.get(), m::Parameter(1))); + EXPECT_FALSE(Match(non_param.get(), m::Parameter())); + EXPECT_FALSE(Match(non_param.get(), m::Parameter(1))); + + EXPECT_DESC_AND_EXPLANATION(non_param, m::Parameter(1), + "an HloInstruction:\n" + " * with opcode parameter AND\n" + " * which is parameter 1", + "HloInstruction doesn't have opcode parameter\n" + "in c = s32[] constant(0)"); + EXPECT_EQ(Explanation(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "p0"), + m::Parameter(1)), + "HloInstruction is not parameter 1\n" + "in p0 = f32[] parameter(0)"); +} + 
+TEST(PatternMatcherTest, OneUseAndOneUser) { + auto param = + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_DESC_AND_EXPLANATION( + param, m::Op().WithOneUse(), + "an HloInstruction which has exactly one use", + "HloInstruction has 0 users, but expected exactly one.\n" + "in p0 = f32[] parameter(0)"); + + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser())); + EXPECT_DESC_AND_EXPLANATION( + param, m::Op().WithOneUser(), + "an HloInstruction which has exactly one user (but possibly is used " + "multiple times by that instruction)", + "HloInstruction has 0 users, but expected exactly one.\n" + "in p0 = f32[] parameter(0)"); + + { + auto reshape = + SetName("r", HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1}), param.get())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser())); + + auto reshape1 = + SetName("r1", HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {1}), param.get())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser())); + + const char* kMultipleUserExplanation = + "HloInstruction has 2 users, but expected exactly one.\n" + "All users:\n" + " - r = f32[1]{0} reshape(f32[] p0)\n" + " - r1 = f32[1]{0} reshape(f32[] p0)\n" + "in p0 = f32[] parameter(0)"; + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()), + kMultipleUserExplanation); + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUser()), + kMultipleUserExplanation); + } + + auto add = SetName("add", HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, + param.get(), param.get())); + EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser())); + EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); + EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()), + "HloInstruction is used 2 times by its user, but is expected to be 
" + "used just once: add = f32[] add(f32[] p0, f32[] p0)\n" + "in p0 = f32[] parameter(0)"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index c522e7ae23b734090f85d241bf365fccc37f0adb..c227106511c2c17b44569d3b696cd7d764226e81 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -59,20 +59,15 @@ string CanonicalPlatformName(const string& name) { /* static */ StatusOr> PlatformUtil::GetSupportedPlatforms() { - se::MultiPlatformManager::PlatformMap platform_map; - se::port::Status platforms_status = se::MultiPlatformManager::WithPlatforms( - [&platform_map](se::MultiPlatformManager::PlatformMap* map) { - platform_map = *map; - return se::port::Status::OK(); - }); - if (platform_map.empty()) { + std::vector all_platforms = + se::MultiPlatformManager::AllPlatforms(); + if (all_platforms.empty()) { LOG(WARNING) << "no executor platforms available: platform map is empty"; } // Gather all platforms which have an XLA compiler. std::vector platforms; - for (auto& platform_pair : platform_map) { - auto* platform = platform_pair.second; + for (se::Platform* platform : all_platforms) { auto compiler_status = Compiler::GetForPlatform(platform); if (compiler_status.ok()) { platforms.push_back(platform); @@ -222,8 +217,8 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { // fix the number of devices to one. 
However we do let the user override // this behavior to help run tests on the host that run models in parallel // across multiple devices. - device_count = legacy_flags::GetDebugOptionsFromFlags() - .xla_force_host_platform_device_count(); + device_count = + GetDebugOptionsFromFlags().xla_force_host_platform_device_count(); } std::vector stream_executors(device_count, nullptr); VLOG(1) << "Initializing devices"; diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc index 688cceff0cd10df62a4093f00ad3331ca77652e0..b70cb7057477a338bfb36ebab76237b30d018e41 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -111,7 +111,7 @@ StatusOr ReducePrecisionInsertion::insert_on_inputs( VLOG(2) << "Adding to operand " << i << ": " << operand; if (!is_valid_shape(operand->shape())) { - VLOG(2) << "Skipped: value is not an F32 vector"; + VLOG(2) << "Skipped: value is not of type F32"; continue; } @@ -168,7 +168,7 @@ StatusOr ReducePrecisionInsertion::insert_on_outputs( << instruction->ToString(); if (!is_valid_shape(instruction->shape())) { - VLOG(2) << "Skipped: value is not an F32 nonscalar array"; + VLOG(2) << "Skipped: value is not of type F32"; continue; } diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index 0b4e82e8d606cf2cacfab42d07c2201939d5e10b..76c6a87f176ec9c6f8e49c25278c6dad703e3c7c 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -118,13 +118,7 @@ class ReducePrecisionInsertion : public HloModulePass { // equivalent behavior can be obtained by adding ReducePrecision // instructions after the instructions that pull the F32 arrays out of // the tuples. 
- // - // TODO(b/64093391): Remove the IsScalar check once this won't cause - // failures on the GPU backend if the ReducePrecision instruction ends up - // inserted between a scalar constant and the init_value argument of a - // Reduce operation. - return shape.element_type() == PrimitiveType::F32 && - !ShapeUtil::IsScalar(shape); + return shape.element_type() == PrimitiveType::F32; } // Is this instruction one such that following or preceding it with a new diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc index 69e4b534bd8e3aeab8b729f3e594a10b4368f15f..efeec96571455d8a9e4b7837dd7286392c12f1a3 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc @@ -54,7 +54,34 @@ TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected state before adding ops. + EXPECT_EQ(computation->root_instruction(), b); + EXPECT_EQ(b->operand(0), a); + + EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS, + [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCos; + })); + + // Confirm expected graph after adding ops. + EXPECT_EQ(computation->root_instruction(), b); + EXPECT_THAT(b->operand(0), op::ReducePrecision(a)); +} + +TEST_F(ReducePrecisionInsertionTest, BeforeUnaryScalarInstruction) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {}); + + // Create a simple graph with a parameter feeding a unary cosine function. 
+ HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); + + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -84,7 +111,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) { HloInstruction* c = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -113,7 +140,7 @@ TEST_F(ReducePrecisionInsertionTest, BeforeZeroInputInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -146,7 +173,7 @@ TEST_F(ReducePrecisionInsertionTest, AvoidAddingDuplicateInstructions) { HloInstruction* d = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, c)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -178,7 +205,7 @@ TEST_F(ReducePrecisionInsertionTest, AfterRootInstruction) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. 
@@ -215,7 +242,7 @@ TEST_F(ReducePrecisionInsertionTest, AfterNonRootInstruction) { HloInstruction* c = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_cos, b_cos)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -242,7 +269,7 @@ TEST_F(ReducePrecisionInsertionTest, OutputIsNotFloat) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -268,7 +295,7 @@ TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -294,7 +321,7 @@ TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) { HloInstruction* b = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, a, 8, 23)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected state before adding ops. @@ -321,7 +348,7 @@ TEST_F(ReducePrecisionInsertionTest, SkipRedundantReducePrecisionAfter) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, x, 5, 10)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. 
@@ -349,7 +376,7 @@ TEST_F(ReducePrecisionInsertionTest, AddNonRedundantReducePrecision) { HloInstruction* y = builder.AddInstruction( HloInstruction::CreateReducePrecision(shape, x, 8, 23)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Confirm expected graph before adding ops. @@ -375,7 +402,7 @@ TEST_F(ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. @@ -411,7 +438,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInHeadOfFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. @@ -458,7 +485,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) { builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); HloInstruction* y = builder.AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); // Manually fuse the kCos operation into a fusion operation. 
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index fcf269eee925c2ddb7511d70e71bd815e4b8c24a..341659b15c4c7355d39739ee171a4a749d87e929 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -34,9 +34,10 @@ namespace { namespace op = xla::testing::opcode_matchers; -class ReshapeMoverTest : public HloVerifiedTestBase {}; +class ReshapeMoverTest : public HloTestBase {}; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -50,12 +51,12 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); @@ -74,6 +75,7 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { // Verifies that the reshape is not moved, since rng0 is trivially 
reshapable // and therefore there is no nontrivial reshapes to move. TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto rng0 = builder.AddInstruction(HloInstruction::CreateRng( @@ -92,18 +94,19 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(rng0), const1)); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(rng0), const1)); } TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -117,12 +120,12 @@ TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -130,6 +133,7 @@ TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { } TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = 
ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -143,11 +147,11 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Add(param0, param1))); @@ -177,6 +181,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { // | // reshape4 TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); auto const0 = builder.AddInstruction( @@ -196,12 +201,12 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, const0, reshape1, reshape2)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(const0, reshape1, reshape2)); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Select(op::Reshape(const0), param1, param2))); @@ -221,6 +226,7 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { // Verifies that the reshape0 does not sink below add, because param1 is not // trivially reshapable nor is a Reshape/Transpose. 
TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -232,11 +238,11 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), param1)); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), param1)); @@ -257,6 +263,7 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { // Verifies that we don't unnecessarily sink reshapes, which are in fact // trivial reshapes. 
TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {3, 2}); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -275,12 +282,12 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, pred, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(pred, op::Reshape(const0), op::Reshape(const1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Select(pred, op::Reshape(const0), op::Reshape(const1))); @@ -309,6 +316,7 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { // // (note that reshape1 here is trivial). 
TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -320,12 +328,12 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), const1)); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Add(param0, op::Reshape(const1)))); @@ -348,6 +356,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { // For now we treat it as non-trivial, so we verify that we don't sink the // reshapes in this case. 
TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -362,12 +371,12 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(const1))); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(const1))); @@ -376,6 +385,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { } TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -389,14 +399,14 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Fusion(param0, 
param1))); @@ -405,6 +415,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { } TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); auto pred_shape = ShapeUtil::MakeShape(PRED, {8, 7}); @@ -423,13 +434,13 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, reshape0, reshape1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Select(op::Reshape(pred), op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Select(pred, param0, param1))); @@ -438,6 +449,7 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { } TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { + auto m = CreateNewVerifiedModule(); HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {}); auto pred_shape = ShapeUtil::MakeShape(PRED, {}); @@ -452,11 +464,11 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, param0, param1)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(op::Reshape(pred), param0, param1)); - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Select(op::Reshape(pred), param0, param1)); 
@@ -477,6 +489,7 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { // // We expect reshape{0,1} AND reshape{2,3} to be lifted. TEST_F(ReshapeMoverTest, MultiplePasses) { + auto m = CreateNewVerifiedModule(); auto shape1 = ShapeUtil::MakeShape(F32, {1, 8, 1, 7}); auto shape2 = ShapeUtil::MakeShape(F32, {8, 7, 1}); auto shape3 = ShapeUtil::MakeShape(F32, {8, 7}); @@ -500,14 +513,14 @@ TEST_F(ReshapeMoverTest, MultiplePasses) { builder.AddInstruction(HloInstruction::CreateBinary(shape3, HloOpcode::kAdd, reshape2, reshape3)); - auto computation = module().AddEntryComputation(builder.Build()); + auto computation = m->AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Add(op::Reshape(param2), op::Reshape(op::Add(op::Reshape(param0), op::Reshape(param1))))); - EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(m.get()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -526,11 +539,11 @@ TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_TRUE(changed); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Transpose(op::Multiply())); } @@ -555,8 +568,8 @@ TEST_F(ReshapeMoverTest, ReshapeWithUsersOutsideCandidatesNotSink) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_FALSE(changed); } @@ -580,10 +593,10 @@ TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink1) { } )"; - 
ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_TRUE(changed); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Reshape(), op::Reshape(), op::Reshape())); } @@ -597,10 +610,10 @@ TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink2) { } )"; - ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(m.get())); EXPECT_TRUE(changed); - EXPECT_THAT(module().entry_computation()->root_instruction(), + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Reshape(op::Add())); } diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index de7aee262e61195b37099fc661a95508d0539e18..11c2f8392d285095816dd5d61f7029c1bfd158d4 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -392,7 +392,8 @@ StatusOr ScatterExpander::ExpandScatter( [&](HloInstruction* induction_var, const std::vector& loop_state) { return ScatterLoopBody(scatter, induction_var, loop_state); - }); + }, + scatter->metadata()); TF_ASSIGN_OR_RETURN(std::vector scatter_loop_result, scatter_loop_result_status); return scatter_loop_result.front(); diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h index 559a85dccfef27816e7dbf746fd71c44bbf46f60..533af060bc9f943e5bc2882db626e25c77484029 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.h +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -25,7 +25,7 @@ class 
ScatterExpander : public HloModulePass { absl::string_view name() const override { return "scatter_expander"; } StatusOr Run(HloModule* module) override; - private: + protected: StatusOr ExpandScatter(HloInstruction* scatter); }; diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index b27a92f2a0761a2bccd97eb2c0467ead27565c37..5ec7fe2adedac2fc3d8a7588e853dba90e99006f 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -23,9 +23,9 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -41,6 +41,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -175,7 +176,14 @@ Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, Status Service::Unregister(const UnregisterRequest* arg, UnregisterResponse* result) { - return allocation_tracker_.Unregister(arg->data()); + Status status; + for (auto& data : arg->data()) { + Status unregister_status = allocation_tracker_.Unregister(data); + if (!unregister_status.ok() && status.ok()) { + status = unregister_status; + } + } + return status; } // Deconstructs a previously-allocated global handle. @@ -207,7 +215,7 @@ Status Service::ValidateResultShape(const Shape& client_shape, StatusOr>> Service::ResolveAndValidateArguments( absl::Span arguments, - absl::Span stream_executors) { + absl::Span stream_executors) const { CHECK_EQ(options_.number_of_replicas(), stream_executors.size()); std::vector> replicated_arguments; replicated_arguments.resize(options_.number_of_replicas()); @@ -268,8 +276,8 @@ StatusOr> Service::CreateModuleConfig( } if (execution_options != nullptr && execution_options->has_shape_with_output_layout()) { - const auto& shape_with_output_layout = - execution_options->shape_with_output_layout(); + const Shape shape_with_output_layout( + execution_options->shape_with_output_layout()); TF_RETURN_IF_ERROR( ValidateResultShape(shape_with_output_layout, program_shape.result())); TF_RETURN_IF_ERROR( @@ -285,7 +293,7 @@ StatusOr> Service::CreateModuleConfig( config->set_seed(execution_options->seed()); config->set_debug_options(execution_options->debug_options()); } else { - config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + 
config->set_debug_options(GetDebugOptionsFromFlags()); } if (execute_backend_ != nullptr && @@ -341,19 +349,19 @@ StatusOr>> Service::BuildExecutables( } CHECK_EQ(module_protos.size(), module_configs.size()); - std::vector> modules; + auto module_group = + absl::make_unique(module_protos[0]->name()); for (int64 i = 0; i < module_protos.size(); ++i) { const HloModuleProto* proto = module_protos[i]; const HloModuleConfig& config = *module_configs[i]; - TF_ASSIGN_OR_RETURN(auto module, - HloModule::CreateFromProto(*proto, config)); - modules.push_back(std::move(module)); + TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config)); + module_group->push_back(std::move(module)); } TF_ASSIGN_OR_RETURN( std::vector> executables, - backend->compiler()->Compile(std::move(modules), std::move(executors), - device_allocator)); + backend->compiler()->Compile(std::move(module_group), + std::move(executors), device_allocator)); for (size_t i = 0; i < module_protos.size(); ++i) { if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { @@ -590,7 +598,7 @@ StatusOr> Service::GetExecutors( StatusOr>> Service::GetArguments( const ExecutionOptions& execution_options, - absl::Span arguments) { + absl::Span arguments) const { // Resolve the allocations for the arguments of the computation, and create // a vector of device memory offsets for the arguments from the allocations. // In the case of partitioned computations, assume all arguments go on the @@ -634,7 +642,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, arg->requests(i).execution_options(); const ExecuteGraphRequest& request = arg->requests(i); TF_RET_CHECK(request.has_computation()) << "computations may not be empty"; - TF_RET_CHECK(request.computation().has_program_shape()) + TF_RET_CHECK(request.computation().has_host_program_shape()) << "programe shape may not be empty"; // Get the executors. 
@@ -651,9 +659,9 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, // replica 0. TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(request.computation().program_shape(), - replicated_arguments.front(), - request.execution_options())); + CreateModuleConfig( + ProgramShape{request.computation().host_program_shape()}, + replicated_arguments.front(), request.execution_options())); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -738,9 +746,9 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, } if (available_device_count < arg->device_count() * replica_count) { return ResourceExhausted( - "Requested device count (%d) exceeds the number of available devices " - "on the target (%d)", - arg->device_count(), available_device_count); + "Requested logical device count (%d) with replica count (%d) exceeds " + "the number of available physical devices on the target (%d)", + arg->device_count(), replica_count, available_device_count); } for (int64 i = 0; i < arg->device_count(); ++i) { @@ -753,38 +761,6 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, return Status::OK(); } -Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { - ExecuteGraphParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteGraphParallel(¶llel_arg, ¶llel_result)); - return PickParallelResponse(parallel_result, result); -} - -Status Service::PickParallelResponse( - const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) { - // The "result device" selection is a bit hacky, but better than assuming it - // is device 0. We have b/76035356 for restructuring the client API to clean - // up the current asymmetries and support more functionalities. 
- for (int64 i = 0; i < parallel_result.responses_size(); ++i) { - TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, - allocation_tracker_.ResolveForReplica( - parallel_result.responses(i).output(), 0)); - const Shape& shape = buffer->on_host_shape(); - if (!ShapeUtil::IsEmptyTuple(shape)) { - *result = parallel_result.responses(i); - VLOG(3) << "Fetching result from device " << i << ": " - << ShapeUtil::HumanString(shape); - return Status::OK(); - } - } - TF_RET_CHECK(parallel_result.responses_size() > 0); - *result = parallel_result.responses(0); - VLOG(1) << "Defaulting to device 0 result"; - return Status::OK(); -} - StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -810,7 +786,7 @@ StatusOr> Service::BuildExecutable( } TF_ASSIGN_OR_RETURN(std::unique_ptr module, - HloModule::CreateFromProto(module_proto, *module_config)); + CreateModuleFromProto(module_proto, *module_config)); TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module)); @@ -829,32 +805,33 @@ StatusOr> Service::BuildExecutable( return std::move(executable); } -Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { - VLOG(1) << "running execute-graph request"; - +Status Service::Compile(const CompileRequest* arg, CompileResponse* result) { + VLOG(1) << "running compile request"; if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } - if (!arg->computation().has_program_shape()) { + if (!arg->computation().has_host_program_shape()) { return InvalidArgument("programe shape may not be empty"); } - // If we received multiple device handles, we must partition the module. 
if (arg->execution_options().device_handles_size() > 1) { - return ExecuteOneToN(arg, result); + return InvalidArgument( + "The compile request does not support multiple device handles."); } - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); + std::vector argument_shapes; + argument_shapes.reserve(arg->input_shape_with_layout_size()); + std::vector argument_shape_ptrs; + for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) { + argument_shapes.push_back(Shape(shape_proto)); + argument_shape_ptrs.push_back(&argument_shapes.back()); + } TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(arg->computation().program_shape(), - replicated_arguments.front(), - arg->execution_options())); + std::unique_ptr module_config, + CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()}, + argument_shape_ptrs, &arg->execution_options())); + VLOG(3) << "Compile created HloModuleConfig computation layout: " + << module_config->entry_computation_layout().ToString(); TF_ASSIGN_OR_RETURN( std::unique_ptr executable, @@ -863,6 +840,48 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, execute_backend_->default_stream_executor(), /*device_allocator=*/nullptr)); + *result->mutable_handle() = compilation_cache_.Insert(std::move(executable)); + + VLOG(1) << "successfully completed 'compile' request"; + return Status::OK(); +} + +Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { + VLOG(1) << "running execute request"; + if (!arg->has_handle()) { + return InvalidArgument("execution handle should not be empty"); + } + TF_ASSIGN_OR_RETURN(auto executable, + compilation_cache_.LookUp(arg->handle())); + + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); + TF_ASSIGN_OR_RETURN( + 
std::vector> replicated_arguments, + ResolveAndValidateArguments(arg->arguments(), replicas)); + + // Check that the replicated_arguments has the same shape and layout as the + // module config used when creating the exectuable. + const int64 num_module_args = + executable->module_config().entry_computation_layout().parameter_count(); + if (num_module_args != arg->arguments_size()) { + return InvalidArgument( + "The executable expects %lld arguments, but sees %lld.", + num_module_args, arg->arguments_size()); + } + for (int64 i = 0; i < num_module_args; i++) { + const Shape& shape_module = + executable->module_config().entry_computation_layout().parameter_shape( + i); + const Shape& shape_arg = replicated_arguments.front()[i]->on_host_shape(); + if (!ShapeUtil::Equal(shape_module, shape_arg)) { + return InvalidArgumentStrCat( + "The executable exepcts the ", i, "th argument in shape ", + ShapeUtil::HumanStringWithLayout(shape_module), " but sees ", + ShapeUtil::HumanStringWithLayout(shape_arg)); + } + } + TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( execute_backend_->default_stream_executor())); @@ -876,9 +895,10 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, TF_ASSIGN_OR_RETURN( *result->mutable_output(), - ExecuteAndRegisterResult( - executable.get(), replicated_arguments, execute_backend_.get(), - "result of " + arg->computation().name(), result->mutable_profile())); + ExecuteAndRegisterResult(executable.get(), replicated_arguments, + execute_backend_.get(), + "result of " + executable->module().name(), + result->mutable_profile())); if (executable->dumping_snapshot()) { TF_ASSIGN_OR_RETURN( @@ -890,7 +910,7 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, TF_RETURN_IF_ERROR(executable->DumpHloSnapshot()); } - VLOG(1) << "successfully completed 'execute-graph' request"; + VLOG(1) << "successfully completed 'execute' request"; return Status::OK(); } @@ -914,14 +934,14 @@ Status Service::TransferToClient(const 
TransferToClientRequest* arg, TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); - const Shape* return_shape; + Shape return_shape; if (arg->has_shape_with_layout()) { - if (!LayoutUtil::HasLayout(arg->shape_with_layout())) { + return_shape = Shape(arg->shape_with_layout()); + if (!LayoutUtil::HasLayout(return_shape)) { return InvalidArgument("shape_with_layout must have layout if present."); } - return_shape = &arg->shape_with_layout(); } else { - return_shape = &shaped_buffer->on_host_shape(); + return_shape = Shape(shaped_buffer->on_host_shape()); } TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( @@ -932,30 +952,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, execute_backend_->transfer_manager()->TransferLiteralFromDevice( stream.get(), *shaped_buffer)); - if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) { + if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) { *result->mutable_literal() = result_literal.ToProto(); } else { *result->mutable_literal() = - result_literal.Relayout(*return_shape).ToProto(); + result_literal.Relayout(return_shape).ToProto(); } return Status::OK(); } -namespace { - -// Creates a clone of the given shaped buffer with the given device ordinal. The -// shape and DeviceMemoryBase values of the clone are identical to the original. 
-std::unique_ptr CloneShapedBufferOnDevice( - const ShapedBuffer& shaped_buffer, int device_ordinal) { - auto clone = absl::make_unique( - shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(), - shaped_buffer.platform(), device_ordinal); - clone->buffers() = shaped_buffer.buffers(); - return clone; -} - -} // namespace - Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { TF_ASSIGN_OR_RETURN(Literal literal, @@ -1044,11 +1049,11 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, executor = replicas[arg->replica_id()]; } - auto literal = Literal::CreateFromShape(arg->shape_with_layout()); + auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout())); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), literal)); + executor, Shape(arg->shape_with_layout()), literal)); *result->mutable_literal() = literal.ToProto(); return Status::OK(); } @@ -1063,15 +1068,15 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } - if (!arg->computation().has_program_shape()) { + if (!arg->computation().has_host_program_shape()) { return InvalidArgument("program shape may not be empty"); } - if (arg->computation().program_shape().parameters_size() != 0) { + if (arg->computation().host_program_shape().parameters_size() != 0) { return InvalidArgument( "constant computation may not depend on any parameters."); } - ProgramShape program_shape = arg->computation().program_shape(); + ProgramShape program_shape(arg->computation().host_program_shape()); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); if (arg->has_output_layout()) { TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( @@ -1081,7 +1086,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, 
HloModuleConfig config(program_shape); TF_ASSIGN_OR_RETURN(std::unique_ptr module, - HloModule::CreateFromProto(arg->computation(), config)); + CreateModuleFromProto(arg->computation(), config)); HloEvaluator evaluator; TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate( @@ -1102,7 +1107,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); - *result->mutable_shape() = buffer->on_host_shape(); + *result->mutable_shape() = buffer->on_host_shape().ToProto(); return Status::OK(); } @@ -1111,14 +1116,14 @@ Status Service::GetComputationGraphStats( if (!arg->has_computation()) { return InvalidArgument("Computations may not be empty."); } - if (!arg->computation().has_program_shape()) { + if (!arg->computation().has_host_program_shape()) { return InvalidArgument("Program shape may not be empty."); } - HloModuleConfig config(arg->computation().program_shape()); + HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()}); config.set_debug_options(arg->debug_options()); TF_ASSIGN_OR_RETURN(std::unique_ptr module, - HloModule::CreateFromProto(arg->computation(), config)); + CreateModuleFromProto(arg->computation(), config)); hlo_graph_dumper::MaybeDumpHloModule(*module, "computation statistics subject"); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 1f62fad4c8079eba7013b3f647fe19bbc031fc77..11e1a79552fbd944ab28da129b08cfe676fb08e9 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -22,11 +22,12 @@ limitations under the License. 
#include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/allocation_tracker.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" +#include "tensorflow/compiler/xla/service/compilation_cache.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/execution_tracker.h" @@ -90,11 +91,14 @@ class Service : public ServiceInterface { Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - // Executes a computation with the provided global data passed as - // immutable arguments. The request contains the whole computation graph. - // Returns global data output and execution timing. - Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) override; + // Compiles a computation into an executable. The request contains the whole + // computation graph. Returns the handle to the executable. + Status Compile(const CompileRequest* arg, CompileResponse* result) override; + + // Executes an executable with the provided global data passes as immutable + // arguments. The request contains the handle to the executable. Returns + // global data output and execution timing. + Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each @@ -179,10 +183,6 @@ class Service : public ServiceInterface { absl::Span arguments, const ExecutionOptions& execution_options); - // Picks a parallel response and fills the result. 
- Status PickParallelResponse(const ExecuteParallelResponse& parallel_result, - ExecuteResponse* result); - // Prepare the executors for executing parallel. StatusOr> GetExecutors( const ExecutionOptions& execution_options, int64 requests_size, @@ -191,7 +191,7 @@ class Service : public ServiceInterface { // Prepare the arguments for executing parallel. StatusOr>> GetArguments( const ExecutionOptions& execution_options, - absl::Span arguments); + absl::Span arguments) const; protected: friend class LocalExecutable; @@ -208,7 +208,7 @@ class Service : public ServiceInterface { StatusOr>> ResolveAndValidateArguments( absl::Span arguments, - absl::Span stream_executors); + absl::Span stream_executors) const; // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. @@ -254,11 +254,6 @@ class Service : public ServiceInterface { Backend* backend, absl::Span device_handles, absl::Span result_tags, ExecutionProfile* profile); - // Executes a single computation which has more than one target device. - // The N devices are expected to all return an empty tuple, but one, which - // will be the result of this computation. - Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); - // Convenience function which checks whether the given client_shape // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. @@ -281,6 +276,9 @@ class Service : public ServiceInterface { ServiceOptions options_; + // Cache containing previously built Executables. + CompilationCache compilation_cache_; + // Tracks channels created via the API. 
ChannelTracker channel_tracker_; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index e379911462f1d2caa53f708a6ebf8b7363dc2fc3..7e7282a737041458aed39b0054f901c23aa87d7a 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -207,7 +207,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, padded_dilated_base, dilated_window, dim.stride()); } - return ShapeUtil::MakeShape(element_type, output_dimensions); + return ShapeUtil::MakeValidatedShape(element_type, output_dimensions); } } // namespace @@ -391,17 +391,6 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(element_type, new_dimensions); } -/* static */ StatusOr ShapeInference::InferAfterAllShape( - absl::Span arg_shapes) { - for (const Shape* arg_shape : arg_shapes) { - if (arg_shape->element_type() != TOKEN) { - return InvalidArgument( - "Operands of token instructions must be TOKEN types."); - } - } - return ShapeUtil::MakeTokenShape(); -} - /* static */ StatusOr ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { auto old_element_type = operand_shape.element_type(); @@ -919,6 +908,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, switch (opcode) { case HloOpcode::kMaximum: case HloOpcode::kMinimum: + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, + broadcast_dimensions); + case HloOpcode::kSubtract: case HloOpcode::kAdd: case HloOpcode::kAtan2: @@ -929,6 +921,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: + if (lhs.element_type() == PRED || rhs.element_type() == PRED) { + return InvalidArgument( + "Expected element type in shape to be arithmetic type for " + "operation %s; got PRED.", + 
HloOpcodeString(opcode)); + } return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); @@ -1020,7 +1018,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, switch (opcode) { case HloOpcode::kTuple: { Shape result = ShapeUtil::MakeTupleShape({}); - result.mutable_tuple_shapes()->Reserve(operand_shapes.size()); + result.mutable_tuple_shapes()->reserve(operand_shapes.size()); for (const Shape* shape : operand_shapes) { ShapeUtil::AppendShapeToTuple(*shape, &result); } @@ -1029,17 +1027,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case HloOpcode::kSort: { if (operand_shapes.size() == 1) { return *operand_shapes[0]; - } else if (operand_shapes.size() == 2) { - if (!ShapeUtil::SameDimensions(*operand_shapes[0], - *operand_shapes[1])) { - return InvalidArgument( - "Sort keys and values dimensions must match. " - "Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(*operand_shapes[0]), - ShapeUtil::HumanString(*operand_shapes[1])); + } else { + for (int64 operand = 1; operand < operand_shapes.size(); ++operand) { + if (!ShapeUtil::SameDimensions(*operand_shapes[0], + *operand_shapes[operand])) { + return InvalidArgument( + "Sort keys and values dimensions must match. 
" + "Keys shape is: %s\n, Values shape (operand index %lld) is: %s", + ShapeUtil::HumanString(*operand_shapes[0]), operand, + ShapeUtil::HumanString(*operand_shapes[operand])); + } } - return ShapeUtil::MakeTupleShape( - {*operand_shapes[0], *operand_shapes[1]}); + std::vector operand_shape_values; + for (const Shape* operand_shape : operand_shapes) { + operand_shape_values.push_back(*operand_shape); + } + return ShapeUtil::MakeTupleShape(operand_shape_values); } return InvalidArgument("Unexpected number of operands for sort"); } @@ -1557,6 +1560,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); + if (feature_group_count <= 0) { + return InvalidArgument( + "feature_group_count must be a positive number, got %d", + feature_group_count); + } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s.", @@ -1566,8 +1574,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, dnums.kernel_spatial_dimensions_size()) { return InvalidArgument( "Both arguments to convolution must have same number of dimensions.\n" - "Window: %s", - window.DebugString()); + "Numbers: %s", + dnums.DebugString()); + } + + if (dnums.input_spatial_dimensions_size() != + dnums.output_spatial_dimensions_size()) { + return InvalidArgument( + "Both input and output of convolution must have same number of " + "dimensions.\nNumbers: %s", + dnums.DebugString()); } const int num_spatial_dims = dnums.input_spatial_dimensions_size(); @@ -1586,8 +1602,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } if (ShapeUtil::Rank(rhs) != num_dims) { return InvalidArgument( - "The RHS argument to a convolution should have rank %d; lhs: %s.", - num_dims, ShapeUtil::HumanString(lhs)); + "The RHS argument to a convolution should 
have rank %d; rhs: %s.", + num_dims, ShapeUtil::HumanString(rhs)); } TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); @@ -1662,14 +1678,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 kernel_output_features = rhs.dimensions(dnums.kernel_output_feature_dimension()); - if (input_features != kernel_input_features * feature_group_count) { + if (input_features % feature_group_count != 0 || + input_features / feature_group_count != kernel_input_features) { return InvalidArgument( - "Expected LHS feature dimension (value %d) to match RHS " - "input feature dimension * feature_group_count (value %d * %d = %d); " + "Expected LHS feature dimension (value %d) to be a multiple of " + "feature_group_count (value %d), and LHS feature dimension / " + "feature_group_count = RHS feature dimension (value %d); " "got (%s, %s)\n" "Dimension numbers: {%s}.", - input_features, kernel_input_features, feature_group_count, - kernel_input_features * feature_group_count, + input_features, feature_group_count, kernel_input_features, ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), dnums.DebugString()); } @@ -2003,6 +2020,25 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return operand_shape; } +/* static */ StatusOr ShapeInference::InferGetDimensionSizeShape( + const Shape& shape, int64 dimension) { + if (dimension < 0 || dimension >= ShapeUtil::Rank(shape)) { + return InvalidArgument("GetDimensionSize dimension out of bounds: %d.", + dimension); + } + + // TODO(b/119580730): Remove this restriction when very large dimension size + // is needed. 
+ if (shape.dimensions(dimension) > std::numeric_limits::max()) { + return InvalidArgument( + "GetDimensionSize's input shape is %s, the %dth dimension exceeds the " + "UINT_MAX limit.", + ShapeUtil::HumanString(shape), dimension); + } + + return ShapeUtil::MakeShape(U32, {}); +} + /* static */ StatusOr ShapeInference::InferSliceShape( const Shape& arg, absl::Span starts, absl::Span limits, absl::Span strides) { @@ -2337,6 +2373,52 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeShape(operand.element_type(), dimensions); } +/* static */ StatusOr ShapeInference::InferBroadcastShape( + const Shape& operand_shape, const Shape& output_shape, + absl::Span broadcast_dimensions) { + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of broadcast")); + TF_RETURN_IF_ERROR(ExpectArray(output_shape, "operand of broadcast")); + const int64 operand_rank = ShapeUtil::Rank(operand_shape); + const int64 output_rank = ShapeUtil::Rank(output_shape); + if (operand_rank > output_rank) { + return InvalidArgument( + "InDim style broadcast must be to an equal or higher ranked shape; " + "operand rank: %lld; output rank: %lld", + operand_rank, output_rank); + } + if (operand_rank != broadcast_dimensions.size()) { + return InvalidArgument( + "Size of broadcast_dimensions has to match operand's rank; operand " + "rank: %lld, size of broadcast_dimensions %u.", + operand_rank, broadcast_dimensions.size()); + } + for (int64 i = 0; i < operand_rank; i++) { + if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] >= output_rank) { + return InvalidArgument("Broadcast dimension %lld is out of bound", + broadcast_dimensions[i]); + } + if (operand_shape.dimensions(i) != + output_shape.dimensions(broadcast_dimensions[i]) && + operand_shape.dimensions(i) != 1) { + return InvalidArgument( + "Input dimension should be either 1 or equal to the output dimension " + "it's broadcasting into; the %lldth operand dimension is %lld, the " + "%lldth 
output dimension is %lld.", + i, operand_shape.dimensions(i), broadcast_dimensions[i], + output_shape.dimensions(broadcast_dimensions[i])); + } + // Make sure the broadcast dimensions are listed in a strictly increasing + // order. + if (i > 0 && broadcast_dimensions[i - 1] >= broadcast_dimensions[i]) { + return InvalidArgument( + "Broadcast dimensions order is wrong: %d comes after %d.", + broadcast_dimensions[i], broadcast_dimensions.at(i - 1)); + } + } + + return output_shape; +} + /* static */ StatusOr ShapeInference::InferReshapeShape( const Shape& operand, absl::Span dimensions, absl::Span new_sizes) { @@ -2759,6 +2841,15 @@ Status ValidateScatterDimensionNumbers( } } + // Validate window size. + auto window_size = dim_numbers.update_window_dims_size() + + dim_numbers.inserted_window_dims_size(); + if (window_size != ShapeUtil::Rank(operand_shape)) { + return InvalidArgument( + "Scatter op has window of size %d; doesn't match operand of rank %d.", + window_size, ShapeUtil::Rank(operand_shape)); + } + // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. if (dim_numbers.scatter_dims_to_operand_dims_size() != scatter_indices_shape[dim_numbers.index_vector_dim()]) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 96a0ee165d46753da4fef119e7072f66637bf2c4..d94385a04d50baff8156570a09620fd458547936 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -210,6 +210,12 @@ class ShapeInference { static StatusOr InferBroadcastShape( const Shape& operand, absl::Span broadcast_sizes); + // Checks whether the given parameters can form a broadcast. Returns the same + // output_shape if it's legal. 
+ static StatusOr InferBroadcastShape( + const Shape& operand_shape, const Shape& output_shape, + absl::Span broadcast_dimensions); + // Infers the shape produced by a reshape operation from the element type of // its operand and the new dimension sizes specified. static StatusOr InferReshapeShape(const Shape& operand, @@ -226,13 +232,6 @@ class ShapeInference { static StatusOr InferConcatOpShape( absl::Span arg_shapes, int64 dimension); - // Infers the shape produced by a kAfterAll. Trivially this shape is always a - // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes - // and checking operand shapes. This method verifies that the operand shapes - // are all TOKENs. - static StatusOr InferAfterAllShape( - absl::Span arg_shapes); - // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that // the shape is identical except for the element type. @@ -285,6 +284,9 @@ class ShapeInference { const Shape& updates_shape, const ProgramShape& to_apply_shape, const ScatterDimensionNumbers& scatter_dim_numbers); + static StatusOr InferGetDimensionSizeShape(const Shape& shape, + int64 dimension); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. 
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 864ed43118cd066f6ce14cd808b873f137b8414a..4639e32db4d59080a9e85e46983fac61d9e76be9 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1618,13 +1618,37 @@ TEST_F(ShapeInferenceTest, BadSort) { auto values = ShapeUtil::MakeShape(F32, {5}); StatusOr statusor = ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values}); - ASSERT_FALSE(statusor.ok()); + EXPECT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("dimensions must match")) + << statusor.status(); +} +TEST_F(ShapeInferenceTest, BadSortValuesMismatch) { + auto keys = ShapeUtil::MakeShape(F32, {4}); + auto values_good = ShapeUtil::MakeShape(F32, {4}); + auto values_bad = ShapeUtil::MakeShape(F32, {5}); + StatusOr statusor = ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, {&keys, &values_good, &values_bad}); + EXPECT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("dimensions must match")) << statusor.status(); } +TEST_F(ShapeInferenceTest, SortManyValues) { + auto keys = ShapeUtil::MakeShape(F32, {4}); + auto values_s32 = ShapeUtil::MakeShape(S32, {4}); + auto values_u32 = ShapeUtil::MakeShape(U32, {4}); + StatusOr statusor = ShapeInference::InferVariadicOpShape( + HloOpcode::kSort, {&keys, &values_s32, &values_u32}); + EXPECT_IS_OK(statusor); + Shape inferred_shape = statusor.ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Compatible( + inferred_shape, + ShapeUtil::MakeTupleShape({keys, values_s32, values_u32}))); +} + class ScatterGatherShapeInferenceTest : public ShapeInferenceTest { protected: const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); @@ -2649,5 +2673,23 @@ TEST_F(ScatterGatherShapeInferenceTest, << statusor.status(); } +TEST_F(ScatterGatherShapeInferenceTest, + 
InvalidScatterDimNumbers_InsufficientWindowDims) { + StatusOr statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_scalar_, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0, 1, 2, 3}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Scatter op has window of size 4; doesn't match operand of rank 5.")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 56952e3adae59656605a12fd499162504a2a3379..28a30b5ee2dbcb5012804578d4d037c241045309 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -157,4 +157,23 @@ void ScopedShapedBuffer::Deallocate() { } } +ScopedShapedBuffer ScopedShapedBuffer::TakeSubTree(ShapeIndexView index) { + const xla::Shape& sub_on_host_shape = + xla::ShapeUtil::GetSubshape(on_host_shape(), {index}); + const xla::Shape& sub_on_device_shape = + xla::ShapeUtil::GetSubshape(on_device_shape(), {index}); + + ScopedShapedBuffer output(sub_on_host_shape, sub_on_device_shape, + memory_allocator(), device_ordinal()); + auto src_it = buffers().find(index); + auto dst_it = output.buffers().begin(); + while (dst_it != output.buffers().end()) { + dst_it->second = src_it->second; + src_it->second = tensorflow::se::DeviceMemoryBase(nullptr, 0); + ++src_it; + ++dst_it; + } + return output; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index e1d26da4a20c0105be304b1a34c81515fcdc6b7f..f5210c9cfa6b29853bcd0f5bfd581ee3e116a509 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ 
b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -176,6 +176,11 @@ class ScopedShapedBuffer : public ShapedBuffer { // It's the caller's job to ensure that the memory contained therein is freed. TF_MUST_USE_RESULT ShapedBuffer release(); + // Extracts the sub-tree rooted at 'index' and returns a ScopedShapedBuffer + // that holds ownership of the subtree. Sets the buffers corresponding to the + // subtree to null in 'this'. + ScopedShapedBuffer TakeSubTree(ShapeIndexView index); + protected: void Deallocate(); diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc index d69e6362e91e4696dab3c46d99a981c67b593a1c..ca64bd3c8dd2baa686db2b85c937a034b37ab22b 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/util/ptr_util.h" namespace xla { @@ -107,5 +109,79 @@ TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) { // TestAllocator's destructor checks that all memory was freed. 
} +TEST(ScopedShapedBufferTest, TestTakeSubTree) { + TestAllocator allocator; + + Shape s = ShapeUtil::MakeShape(F32, {1}); + s = xla::ShapeUtil::MakeTupleShape(std::vector(2, s)); + s = xla::ShapeUtil::MakeTupleShape(std::vector(3, s)); + + ScopedShapedBuffer sb(s, s, &allocator, /*device_ordinal=*/0); + sb.buffers().ForEachMutableElement( + [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { + TF_ASSERT_OK_AND_ASSIGN( + OwningDeviceMemory m, + allocator.Allocate(/*device_ordinal=*/0, /*size=*/77)); + *buffer = m.Forget(); + }); + ShapeTree buffers = sb.buffers(); + + // Takes a subtree out of 'sb', and verifies the buffers are as expected. + xla::ShapeIndex subtree_index = {1}; + ScopedShapedBuffer output = sb.TakeSubTree(subtree_index); + + output.buffers().ForEachElement([&](const xla::ShapeIndex& sub_index, + const se::DeviceMemoryBase& buffer) { + xla::ShapeIndex orig_index = subtree_index; + for (int i : sub_index) { + orig_index.push_back(i); + } + EXPECT_TRUE(buffers.find(orig_index)->second.IsSameAs(buffer)); + }); + sb.buffers().ForEachElement( + [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) { + if (ShapeIndexView(index).StartsWith(subtree_index)) { + EXPECT_TRUE(buffer.is_null()); + } else { + EXPECT_TRUE(buffers.find(index)->second.IsSameAs(buffer)); + } + }); +} + +// Test TakeSubTree with different depths (depth of ShapeTree) and fan-outs +// (cardinality of each non-leaf node's children). 
+void BM_TakeSubTree(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + TestAllocator allocator; + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = xla::ShapeUtil::MakeTupleShape(shapes); + } + xla::ScopedShapedBuffer shaped_buffer(shape, shape, /*allocator=*/&allocator, + /*device_ordinal=*/0); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + // Extract a buffer from approximately the middle of the first level of the + // tree. + (void)shaped_buffer.TakeSubTree(/*index=*/{fan_out / 2}).release(); + } + tensorflow::testing::StopTiming(); +} + +BENCHMARK(BM_TakeSubTree) + ->ArgPair(1, 4) + ->ArgPair(1, 8) + ->ArgPair(1, 32) + ->ArgPair(1, 64) + ->ArgPair(1, 128) + ->ArgPair(1, 256) + ->ArgPair(1, 512) + ->ArgPair(2, 4) + ->ArgPair(2, 8) + ->ArgPair(2, 32) + ->ArgPair(2, 64) + ->ArgPair(2, 128); + } // anonymous namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index f952e64af2b675b9c0f8a30e9a2bc3c855e34efa..49f0b8f8b72001f07200d3e94828f60fcb0fa8fb 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -95,7 +95,13 @@ class TransferManager { // but need not have the same layout. // // This operation is performed asynchronously on the given stream. It returns - // once the transfer is enqueued. + // once the transfer is enqueued, and may return before the transfer has + // completed. + // + // The caller may free the data structures 'literal' and 'device_buffer' + // immediately after this function returns, however their constituent buffers + // on both host and device must remain valid until the enqueued transfer has + // completed on 'stream'. 
virtual Status TransferLiteralToDeviceAsync( se::Stream* stream, const LiteralSlice& literal, const ShapedBuffer& device_buffer) = 0; diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 79b5c09abb355cd067a4891af558c8c44d80d88e..17cdaa74fc328d156292f5af828d4222a9a01f1f 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -172,7 +172,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary( add->shape(), HloOpcode::kMultiply, add, sub)); - auto module = CreateNewModule("fuse_with_constant_operands"); + auto module = CreateNewVerifiedModule("fuse_with_constant_operands"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(mul)); HloInstruction* call = module->OutlineExpressionFromComputation( @@ -247,7 +247,7 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { conv_shape.ValueOrDie(), x, transpose_y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -302,7 +302,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { conv_shape.ValueOrDie(), x, transpose_y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -362,7 +362,7 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { conv_shape.ValueOrDie(), transpose_x, y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module 
= CreateNewModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); @@ -428,7 +428,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { conv_shape.ValueOrDie(), transpose_x, y, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); - auto module = CreateNewModule("test_module"); + auto module = CreateNewVerifiedModule("test_module"); HloComputation* entry_computation = module->AddEntryComputation(builder.Build(conv)); FoldTranspose(module.get()); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 811ac55e2dc2939293e62f1ebcd2bce266a12133..50d51eaeb762e208004c1dae3dcc27503f3f94e9 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -148,7 +148,7 @@ TuplePointsToAnalysis::Run(const HloModule* module) { Status TuplePointsToAnalysis::Analyze() { per_instruction_.clear(); - per_instruction_.resize(module_->NumUniqueInstructionIds()); + per_instruction_.reserve(module_->instruction_count()); logical_buffer_aliases_.clear(); logical_buffer_aliases_.resize( @@ -280,6 +280,13 @@ Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleAddDependency( + HloInstruction* add_dependency) { + // AddDependency just forwards the value of its zero-th operand. + CreateCopiedPointsToSet(add_dependency, add_dependency->operand(0)); + return Status::OK(); +} + Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its // output. The other indices ({} and {1}) define their own buffers. 
@@ -756,6 +763,7 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( } } if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kScatter || user->opcode() == HloOpcode::kWhile) { // We eliminated other users in BufferLiveness::live_range_strictly_before, // so here we just need to check that the use is at operand index 0. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 30c365053c5dac5af3c559f7c92b11d389d7fff8..0a1d5649d6d69fea12263e6986ce76af62615ec7 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -251,6 +252,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; + Status HandleAddDependency(HloInstruction* add_dependency) override; string ToString() const; @@ -315,14 +317,23 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { const PerInstruction* PerInst(const HloInstruction* inst) const { int id = inst->unique_id(); DCHECK_GE(id, 0); - DCHECK_LT(id, per_instruction_.size()); - return &per_instruction_[id]; + auto iter = per_instruction_.find(id); + if (iter == per_instruction_.end()) { + LOG(FATAL) << "Expected per-instruction information to already exist"; + } else { + return iter->second.get(); + } } PerInstruction* PerInst(const HloInstruction* inst) { int id = inst->unique_id(); DCHECK_GE(id, 0); - DCHECK_LT(id, per_instruction_.size()); - return &per_instruction_[id]; + auto iter = 
per_instruction_.find(id); + if (iter == per_instruction_.end()) { + return per_instruction_.emplace(id, absl::make_unique()) + .first->second.get(); + } else { + return iter->second.get(); + } } std::vector> GetAllUsesOfInstructionAtIndex( @@ -339,7 +350,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { const std::unique_ptr logical_buffer_analysis_; // A map from instruction->unique_id() to - std::vector per_instruction_; + absl::flat_hash_map> per_instruction_; // A map from LogicalBuffer->id() to alias information about that logical // buffer diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index e9a07b14ed685fa4388aca583395370a60176cca..561762b5d424ed5f537665be9d67a81dc8bdd56e 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -48,7 +48,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { } void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); module_->AddEntryComputation(std::move(computation)); } @@ -264,6 +264,22 @@ TEST_F(TuplePointsToAnalysisTest, GetTupleElement) { UnorderedElementsAre(inner_tuple)); } +TEST_F(TuplePointsToAnalysisTest, AddDependency) { + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + auto add_dependency = builder.AddInstruction( + HloInstruction::CreateAddDependency(constant, token)); + BuildModuleAndRunAnalysis(builder.Build()); + + auto& points_to_set = points_to_analysis_->GetPointsToSet(add_dependency); + EXPECT_EQ(1, points_to_set.size()); + EXPECT_FALSE(points_to_set.IsAmbiguous()); + EXPECT_TRUE(points_to_set.IsDistinct()); + 
ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(), {constant}); +} + TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) { // Create a tuple which contains duplicate elements. auto builder = HloComputation::Builder(TestName()); @@ -809,7 +825,7 @@ TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) { class PointsToAnalysisTestBase : public HloTestBase { protected: void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); computation_ = module_->AddEntryComputation(std::move(computation)); } @@ -1010,6 +1026,44 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); } +TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) { + const char* hlo_text = R"( + HloModule TensorFlowScatterV1 + + update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) + } + + ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); + computation_ = module_->entry_computation(); + RunAnalysis(); + + HloInstruction* operand_param = computation_->parameter_instruction(0); + HloInstruction* indices_param = computation_->parameter_instruction(1); + HloInstruction* updates_param = computation_->parameter_instruction(2); + HloInstruction* scatter = computation_->root_instruction(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser( + operand_param, {}, scatter, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser( + indices_param, {}, scatter, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser( + 
updates_param, {}, scatter, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) { auto builder = HloComputation::Builder(TestName()); @@ -1035,7 +1089,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) { auto values = builder.AddInstruction( HloInstruction::CreateParameter(1, values_shape, "values")); auto sort = builder.AddInstruction(HloInstruction::CreateSort( - ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values)); + ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, + {values})); BuildModuleAndRunAnalysis(builder.Build()); @@ -1137,7 +1192,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { return builder.Build(); }; - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); HloComputation* cond_computation = module_->AddEmbeddedComputation(make_cond()); HloComputation* body_computation = @@ -1172,7 +1227,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { auto add = sub_builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); - module_ = CreateNewModule(); + module_ = CreateNewUnverifiedModule(); auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); sub_computation->CreateFusionInstruction({add, ones}, HloInstruction::FusionKind::kLoop); diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index 516754e2110ee50a597818c4a8bcfbfbb76c5cec..65b0f8c804475d8f22fff9798e79c9881a51f1f1 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class TupleSimplifierTest : public HloVerifiedTestBase { +class TupleSimplifierTest : public HloTestBase { protected: void Run(HloModule* module, bool change_expected) { TupleSimplifier simplifier; @@ -65,10 +65,10 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) { HloInstruction* param2 = builder.AddInstruction( HloInstruction::CreateParameter(2, scalar_shape_, "param2")); builder.AddInstruction(HloInstruction::CreateTuple({param0, param1, param2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - Run(module, /*change_expected=*/false); + Run(module.get(), /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { @@ -78,10 +78,10 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { HloInstruction::CreateParameter(0, tuple_shape_, "param")); builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); - Run(module, /*change_expected=*/false); + Run(module.get(), /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTuple) { @@ -98,12 +98,12 @@ TEST_F(TupleSimplifierTest, GteOfTuple) { HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); 
EXPECT_THAT(computation->root_instruction(), gte); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param1); } @@ -125,13 +125,13 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) { builder.AddInstruction( HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element)); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Negate(op::GetTupleElement(op::Tuple()))); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter())); } @@ -157,12 +157,12 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) { ShapeUtil::GetTupleElementShape(element->shape(), 0), element, 0)); } - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), element); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param); } @@ -182,12 +182,12 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), tuple); - Run(module, /*change_expected=*/true); + Run(module.get(), /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), tuple_param); } @@ -207,19 +207,19 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); - auto module = CreateNewModule(); + auto module = 
CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), tuple); - Run(module, /*change_expected=*/false); + Run(module.get(), /*change_expected=*/false); EXPECT_THAT(computation->root_instruction(), tuple); } TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { // Verify that the root computation can be excluded - auto module = CreateNewModule(); + auto module = CreateNewVerifiedModule(); HloInstruction* p0; HloInstruction* p1; @@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { entry = module->AddEntryComputation(builder.Build()); } - Run(module, /*change_expected=*/true, /*exclude_entry=*/true); + Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/true); EXPECT_THAT(c0->root_instruction(), p0); EXPECT_THAT(c1->root_instruction(), p1); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index 541b117e0299c94de330604ec5c16e20f07c425f..68e2569f66bea9ec1223e454d1ead0efc7b9498e 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" namespace xla { @@ -229,4 +232,96 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, return nullopt; } +// If the only user of this instruction is a get-tuple-element, return that +// get-tuple-element, otherwise return null. If this runs before CSE/DCE, we may +// get a false negative if there are several copies of the same GTE, or there +// are unused GTEs, but we can live with this. 
+static HloInstruction* GetOnlyGTE(HloInstruction* inst) { + if (inst->user_count() != 1) { + return nullptr; + } + + HloInstruction* user = inst->users().back(); + if (user->opcode() != HloOpcode::kGetTupleElement) { + return nullptr; + } + return user; +} + +optional ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) { + // If we know the exact trip count, it's also the upper bound. + auto exact_trip_count = ComputeWhileLoopTripCount(while_op); + if (exact_trip_count) { + VLOG(2) << "Loop has exact trip count."; + return exact_trip_count; + } + + // There is one more case we know how to handle. If the loop condition only + // looks at one element of the tuple, and the loop body sets this element to a + // constant, there are two options: + // 1) Evaluating the condition on this constant returns true. In this case, + // the loop either executes 0 times, or is an infinite loop, depending on the + // init value. + // 2) Evaluating the condition on this constant returns false. In this case, + // the loop executes 0 or 1 times, depending on the init value. This means + // that, regardless of the init value, the upper bound on the trip count is 1. + + // Check whether the condition depends on a single parameter, and find out + // which. + auto* while_cond = while_op->while_condition(); + auto* while_cond_param = while_cond->parameter_instruction(0); + auto* cond_gte = GetOnlyGTE(while_cond_param); + if (!cond_gte) { + VLOG(2) << "Induction variable not found in loop condition: " + << while_cond->root_instruction()->ToString(); + return nullopt; + } + + // Now check whether this gets set to a constant by the while body. 
+ auto* while_body = while_op->while_body(); + auto* while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(3) << "While body's root is not a tuple instruction: " + << while_body_root->ToString(); + return nullopt; + } + + int64 indvar_index = cond_gte->tuple_index(); + auto* while_body_indvar = while_body_root->operand(indvar_index); + if (while_body_indvar->opcode() != HloOpcode::kConstant) { + VLOG(3) << "While body does not set the IV to a constant: " + << while_body_indvar->ToString(); + return nullopt; + } + + // We have a constant. Evaluate the condition on this constant. + HloEvaluator evaluator(/*max_loop_iterations=*/0); + Literal fake_input = Literal::CreateFromShape(while_cond_param->shape()); + TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(), + /*dest_shape_index=*/{indvar_index}, + /*src_shape_index=*/{})); + StatusOr eval_result = + evaluator.Evaluate(*while_cond, {std::move(fake_input)}); + + if (!eval_result.ok()) { + VLOG(2) << "Couldn't evaluate while loop condition."; + return nullopt; + } + + Literal cond_result_pred = std::move(eval_result.ValueOrDie()); + CHECK(ShapeUtil::Equal(cond_result_pred.shape(), + ShapeUtil::MakeShape(PRED, {}))); + + // Per the explanation above, if the evaluated condition returns false, the + // loop executes at most once. 
+ bool cond_returns_true = cond_result_pred.GetFirstElement(); + if (!cond_returns_true) { + VLOG(2) << "Upper bound on the trip count is 1"; + return 1; + } + + VLOG(2) << "Loop has no known upper bound on the trip count."; + return nullopt; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h index bf497f4892b95c927379411468a66d8961465413..ac69a727bd6b403672a676400993fb7d8afc0a55 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.h +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -28,6 +28,10 @@ namespace xla { absl::optional ComputeWhileLoopTripCount(HloInstruction *while_op, int64 max_value_returned = 128); +// Returns an upper bound on the trip count of the loop if it's statically +// known, nullopt otherwise. +absl::optional ComputeWhileLoopTripCountUpperBound( + HloInstruction *while_op); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_analysis_test.cc b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1da0fbeac89a93eaaef893e5f25dd3b87cc1d5d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" + +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class WhileLoopAnalysisTest : public HloTestBase {}; + +TEST_F(WhileLoopAnalysisTest, SingleIterationUpperBound) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val = f32[2] get-tuple-element(p_body), index=0 + const = s32[] constant(-1) + ROOT root = (f32[2], s32[]) tuple(val, const) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] equal-to(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 1); +} + +TEST_F(WhileLoopAnalysisTest, NoUpperBound) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val = f32[2] get-tuple-element(p_body), index=0 + const = s32[] constant(42) + ROOT root = (f32[2], s32[]) tuple(val, const) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] equal-to(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], 
s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + EXPECT_EQ(ComputeWhileLoopTripCountUpperBound(while_op), absl::nullopt); +} + +TEST_F(WhileLoopAnalysisTest, ExactBound) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val = f32[2] get-tuple-element(p_body), index=0 + index = s32[] get-tuple-element(p_body), index=1 + one = s32[] constant(1) + inc = s32[] add(index, one) + ROOT root = (f32[2], s32[]) tuple(val, inc) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] less-than(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] constant(0) + while_init = (f32[2], s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 42); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index 067cfcc17d65860a249de4d9e31703df12091d3a..8b381dec07397c1427e98bc30511ac21dc577610 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -46,8 +46,9 @@ static Status ReplaceUsesWhileKeepingLoopInvariance( return Status::OK(); } -StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody( +StatusOr 
WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( HloInstruction* while_instr) { + HloComputation* while_cond = while_instr->while_condition(); HloComputation* while_body = while_instr->while_body(); const HloInstruction& init_value = *while_instr->operand(0); @@ -57,24 +58,48 @@ StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody( bool changed = false; - for (HloInstruction* invariant_gte : - WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { - int64 index = invariant_gte->tuple_index(); + absl::flat_hash_map> + conditional_gte_index_to_insts = + WhileUtil::GetGTEsMapForWhileConditional(*while_cond); + std::vector invariant_body_gtes = + WhileUtil::GetInvariantGTEsForWhileBody(*while_body); + + for (HloInstruction* invariant_body_gte : invariant_body_gtes) { + int64 index = invariant_body_gte->tuple_index(); const HloInstruction& invariant_value = *init_value.operand(index); - // Should have at least one user that's not while_body_root. - if (invariant_gte->user_count() <= 1) { + // Original value should be a constant. + if (invariant_value.opcode() != HloOpcode::kConstant) { continue; } - if (invariant_value.opcode() == HloOpcode::kConstant) { - auto* constant_instr = + // Sink into the while_body. + // Should have at least one user that's not while_body_root. + if (invariant_body_gte->user_count() > 1) { + HloInstruction* constant_instr = while_body->AddInstruction(invariant_value.Clone(/*suffix=*/".sunk")); TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance( - invariant_gte, constant_instr, while_body->root_instruction(), + invariant_body_gte, constant_instr, while_body->root_instruction(), index)); changed = true; } + + // Check if there is a corresponding GTE in while_conditional. + auto it = conditional_gte_index_to_insts.find(index); + if (it == conditional_gte_index_to_insts.end()) { + continue; + } + + for (HloInstruction* invariant_cond_gte : it->second) { + // Should have at least one user. 
+ if (invariant_cond_gte->user_count() > 0) { + HloInstruction* constant_instr = while_cond->AddInstruction( + invariant_value.Clone(/*suffix=*/".sunk")); + TF_RETURN_IF_ERROR( + invariant_cond_gte->ReplaceAllUsesWith(constant_instr)); + changed = true; + } + } } return changed; @@ -115,10 +140,8 @@ StatusOr WhileLoopConstantSinking::Run(HloModule* module) { } for (HloInstruction* while_instr : while_instrs) { - // We only sink into while loop bodies, but this can be extended to - // transform conditions as well. TF_ASSIGN_OR_RETURN(bool result, - TrySinkingConstantsIntoWhileBody(while_instr)); + TrySinkingConstantsIntoWhileLoop(while_instr)); changed |= result; } diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h index 577bad6c7062d2ee40271e407e8eed7655fa13bf..a866bc1264b4013bb7530b5e02b546e6f78d676b 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -23,8 +23,8 @@ limitations under the License. namespace xla { // Sinks while loop invariant values that happen to be constants into the while -// loop body. This is probably not a win in isolation but may unlock further -// optimizations like constant folding. +// loop body and conditional. This is probably not a win in isolation but may +// unlock further optimizations like constant folding. // // state = (..., const, ...) // while (pred(state)) { @@ -46,22 +46,19 @@ namespace xla { // tuple trivially loop invariant. WhileLoopSimplifier will later get rid of // `v`. // -// We only sink into while loop bodies, but this can be extended to transform -// conditions as well. -// // TODO(b/79121449): We should also sink broadcasts of constants. 
class WhileLoopConstantSinking : public HloModulePass { public: ~WhileLoopConstantSinking() override = default; absl::string_view name() const override { - return "while-loop-invariant-code-motion"; + return "while-loop-constant-sinking"; } StatusOr Run(HloModule* module) override; private: - StatusOr TrySinkingConstantsIntoWhileBody(HloInstruction* while_instr); + StatusOr TrySinkingConstantsIntoWhileLoop(HloInstruction* while_instr); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 0e7667de832c54f647d071e3c9563091d0f994aa..75d406435b6f58faecc86b82c33e9e2dd6bccbea 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -114,7 +114,7 @@ HloModule ModuleWithWhile body { p_b = (f32[2],(f32[2],f32[2])) parameter(0) - p_b.0 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=0 + p_b.0 = f32[2] get-tuple-element((f32[2],(f32[2],f32[2])) p_b), index=0 p_b.1 = (f32[2],f32[2]) get-tuple-element((f32[2],(f32[2],f32[2])) p_b), index=1 p_b.1.1 = f32[2] get-tuple-element(p_b.1), index=0 @@ -242,5 +242,178 @@ ENTRY entry { } } } + +TEST_F(WhileLoopConstantSinkingTest, ConditionalSinkConstant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[],f32[]) parameter(0) + p_body.0 = f32[] get-tuple-element((f32[],f32[]) p_body), index=0 + const = f32[] constant(1) + add = f32[] add(p_body.0, const) + p_body.1 = f32[] get-tuple-element((f32[],f32[]) p_body), index=1 + ROOT root = (f32[],f32[]) tuple(add, p_body.1) +} + +condition { + p_cond = (f32[],f32[]) parameter(0) + p_cond.0 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=0 + p_cond.1 = f32[] get-tuple-element((f32[],f32[]) p_cond), index=1 + ROOT result = pred[] less-than(p_cond.0, p_cond.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + 
const_1 = f32[] constant(10) + while_init = (f32[],f32[]) tuple(const_0, const_1) + ROOT while = (f32[],f32[]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant())); +} + +TEST_F(WhileLoopConstantSinkingTest, ConditionalTupleShapedConstants) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_b = (f32[],(f32[],f32[])) parameter(0) + p_b.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_b), index=0 + p_b.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_b), index=1 + p_b.1.0 = f32[] get-tuple-element((f32[],f32[]) p_b.1), index=0 + add = f32[] add(p_b.0, p_b.1.0) + ROOT root = (f32[],(f32[],f32[])) tuple(add, p_b.1) +} + +condition { + p_c = (f32[],(f32[],f32[])) parameter(0) + p_c.0 = f32[] get-tuple-element((f32[],(f32[],f32[])) p_c), index=0 + p_c.1 = (f32[],f32[]) get-tuple-element((f32[],(f32[],f32[])) p_c), index=1 + p_c.1.1 = f32[] get-tuple-element((f32[],f32[]) p_c.1), index=1 + ROOT result = pred[] less-than(p_c.0, p_c.1.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + const_1 = (f32[], f32[]) constant((f32[], f32[]) (1, 10)) + while_init = (f32[],(f32[],f32[])) tuple(const_0, const_1) + ROOT while = (f32[],(f32[],f32[])) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), + op::Lt(_, op::GetTupleElement(op::Constant()))); +} + 
+TEST_F(WhileLoopConstantSinkingTest, ConditionalDontCreateDeadConstant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[],f32[],f32[]) parameter(0) + p_body.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=0 + const = f32[] constant(1) + add = f32[] add(p_body.0, const) + p_body.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=1 + p_body.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=2 + ROOT root = (f32[],f32[],f32[]) tuple(add, p_body.1, p_body.2) +} + +condition { + p_cond = (f32[],f32[],f32[]) parameter(0) + p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0 + p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1 + p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 + ROOT result = pred[] less-than(p_cond.0, p_cond.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + const_1 = f32[] constant(10) + const_2 = f32[] constant(12) + while_init = (f32[],f32[],f32[]) tuple(const_0, const_1, const_2) + ROOT while = (f32[],f32[],f32[]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant())); + for (const HloInstruction* inst : while_condition->instructions()) { + if (inst->opcode() == HloOpcode::kConstant) { + EXPECT_GT(inst->user_count(), 0); + } + } +} + +TEST_F(WhileLoopConstantSinkingTest, ConditionalMultipleSameIndexGTEs) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[],f32[],f32[]) parameter(0) + p_body.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=0 + const = f32[] constant(1) + add.0 = f32[] add(p_body.0, 
const) + p_body.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=1 + add.1 = f32[] add(p_body.1, const) + p_body.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_body), index=2 + ROOT root = (f32[],f32[],f32[]) tuple(add.0, add.1, p_body.2) +} + +condition { + p_cond = (f32[],f32[],f32[]) parameter(0) + p_cond.0 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=0 + p_cond.2 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 + lt.0 = pred[] less-than(p_cond.0, p_cond.2) + p_cond.1 = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=1 + p_cond.2.c = f32[] get-tuple-element((f32[],f32[],f32[]) p_cond), index=2 + lt.1 = pred[] less-than(p_cond.1, p_cond.2.c) + ROOT result = pred[] and(lt.0, lt.1) +} + +ENTRY entry { + const_0 = f32[] constant(0) + const_1 = f32[] constant(0) + const_2 = f32[] constant(12) + while_init = (f32[],f32[],f32[]) tuple(const_0, const_1, const_2) + ROOT while = (f32[],f32[],f32[]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_condition = module->GetComputationWithName("condition"); + EXPECT_THAT(while_condition->root_instruction(), + op::And(op::Lt(_, op::Constant()), op::Lt(_, op::Constant()))); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 9795b2830b6d9add82b89ac76b5438ddc3d2bfe8..41011176ffa91e885bc58364d1fb19617d3518ad 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -19,7 +19,9 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/tuple_util.h" +#include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -143,6 +145,12 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( string while_instr_name = while_instr->ToString(print_no_metadata); VLOG(2) << "Trying to hoist from " << while_instr_name; + auto maybe_upper_bound = ComputeWhileLoopTripCountUpperBound(while_instr); + if (maybe_upper_bound && *maybe_upper_bound <= 1) { + VLOG(2) << "Loop has a trip count of at most 1, skipping."; + return false; + } + HloComputation* while_body = while_instr->while_body(); // Maps instructions in the while body to instructions hoisted outside the @@ -180,6 +188,13 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( return false; } + // LICM in the presence of domain instructions is complex, bail. + for (auto* instruction : while_body->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kDomain) { + return false; + } + } + // instructions_to_replace[i] is hoisted into a loop invariant instruction // replacement_instructions[i]. std::vector instructions_to_replace; @@ -193,6 +208,37 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( continue; } + if (!hoist_size_inflating_ops_) { + // Check that hoisting the instruction doesn't cause a significant memory + // blow-up. LICM extends the live-range of the output of the hoisted + // instruction to be the entire while loop, which may be problematic on + // platforms where memory is limited. This can be especially harmful if + // the instruction has a significantly larger output than its input, e.g. + // kIota, kBroadcast or kConstant. 
+ int64 input_size = 0, output_size = 0; + + for (auto* operand : instruction->operands()) { + ShapeUtil::ForEachSubshape( + operand->shape(), + [&input_size](const Shape& subshape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(subshape)) { + input_size += ShapeUtil::ByteSizeOfElements(subshape); + } + }); + } + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&output_size](const Shape& subshape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(subshape)) { + output_size += ShapeUtil::ByteSizeOfElements(subshape); + } + }); + + if (output_size > input_size) { + continue; + } + } + auto is_invariant = [&](HloInstruction* op) { return hoisted_instructions.find(op) != hoisted_instructions.end() || unhoisted_invariant_instructions.count(op) || diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 3031899f71e0fd77f20448d9d7489798af01615c..bd6232dc0a988775a0490abbf6125daad8476295 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -34,8 +34,14 @@ class WhileLoopInvariantCodeMotion : public HloModulePass { // Setting `hoist_constants` to false can be help if LICM is run in the mid // level HLO pipeline because hoisting constants out of while loop bodies can // break optimizations like constant folding. - explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false) - : hoist_constants_(hoist_constants) {} + // Setting `hoist_size_inflating_ops` to false will forbid hoisting + // instructions where the size of the output(s) is larger than the size of the + // input(s). This is useful on platforms on which it's important to prevent + // blow-ups in memory size. 
+ explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false, + bool hoist_size_inflating_ops = true) + : hoist_constants_(hoist_constants), + hoist_size_inflating_ops_(hoist_size_inflating_ops) {} ~WhileLoopInvariantCodeMotion() override = default; absl::string_view name() const override { @@ -49,6 +55,7 @@ class WhileLoopInvariantCodeMotion : public HloModulePass { HloInstruction* while_instr); bool hoist_constants_; + bool hoist_size_inflating_ops_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 32e69c335b713c438bd7fcb2053709b0624f58ed..8e7c4bc8828552e197b41f874c070d496b85a382 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -26,7 +26,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase { +class WhileLoopInvariantCodeMotionTest : public HloTestBase { public: // Makes a computation which has one parameter, of the given shape, and always // returns PRED[]{true}. This is useful as a dummy loop condition. 
@@ -58,6 +58,7 @@ HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation( } TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -76,19 +77,18 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_result})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - HloComputation* entry_computation = - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + HloComputation* entry_computation = m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_TRUE(simplified_loop); HloInstruction* transformed_while; @@ -100,6 +100,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { } TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -135,19 +136,18 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, divide_result})); - return module().AddEmbeddedComputation(builder.Build()); + return 
m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - HloComputation* entry_computation = - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + HloComputation* entry_computation = m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_TRUE(simplified_loop); HloInstruction* transformed_while; @@ -173,6 +173,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistTriviallyLoopVaryingComputation) { // Basic negative test: the add expression is not loop invariant. 
+ auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); @@ -189,20 +190,20 @@ TEST_F(WhileLoopInvariantCodeMotionTest, scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); builder.AddInstruction(HloInstruction::CreateTuple({gte_0, add_result})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); @@ -210,6 +211,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistLoopVaryingComputationWithAlternatingTuples) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -228,25 +230,26 @@ TEST_F(WhileLoopInvariantCodeMotionTest, builder.AddInstruction( HloInstruction::CreateTuple({gte_1, gte_0, add_result})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); auto* while_inst = 
builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); } TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto token_shape = ShapeUtil::MakeTokenShape(); Shape while_shape = @@ -267,7 +270,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, out_token})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); @@ -277,14 +280,14 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { auto* init_value = builder.AddInstruction( HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); ASSERT_FALSE(simplified_loop); 
EXPECT_THAT(while_inst->while_body()->instructions(), @@ -294,6 +297,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { // The bitcast's user, an outfeed, can't be hoisted, so don't hoist the // bitcast either. + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); auto token_shape = ShapeUtil::MakeTokenShape(); @@ -317,7 +321,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, out_token})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); @@ -327,15 +331,15 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { auto* init_value = builder.AddInstruction( HloInstruction::CreateTuple({scalar_param, scalar_param, token})); auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0)); - module().AddEntryComputation(builder.Build()); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); EXPECT_THAT(while_inst->while_body()->instructions(), @@ -346,6 +350,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { // The bitcast's user can be hoisted, so hoist the bitcast too. 
+ auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); Shape while_shape = @@ -367,21 +372,20 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_inst})); - return module().AddEmbeddedComputation(builder.Build()); + return m->AddEmbeddedComputation(builder.Build()); }(); HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); - HloComputation* entry_computation = - module().AddEntryComputation(builder.Build()); + HloComputation* entry_computation = m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_TRUE(simplified_loop); HloInstruction* transformed_while; @@ -396,6 +400,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { } TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); @@ -416,22 +421,23 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) { builder.AddInstruction( HloInstruction::CreateTuple({gte_0, gte_1, add_result})); - while_body = module().AddEmbeddedComputation(builder.Build()); + while_body = m->AddEmbeddedComputation(builder.Build()); } HloComputation::Builder builder(TestName()); auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); 
builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); } TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { + auto m = CreateNewVerifiedModule(); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); @@ -439,7 +445,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { HloComputation::Builder builder(TestName() + ".passthrough"); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "param")); - HloComputation* result = module().AddEmbeddedComputation(builder.Build()); + HloComputation* result = m->AddEmbeddedComputation(builder.Build()); result->AddInstruction( HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); @@ -450,11 +456,11 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { auto* init_value = builder.AddInstruction( HloInstruction::CreateParameter(0, while_shape, "init_value")); builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - module().AddEntryComputation(builder.Build()); + while_shape, MakeAlwaysTrueComputation(while_shape, m.get()), while_body, + init_value)); + m->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); EXPECT_FALSE(simplified_loop); } @@ -482,14 +488,14 @@ ENTRY entry { )"; TEST_F(WhileLoopInvariantCodeMotionTest, 
HoistsConstantWhenAsked) { - ParseAndVerifyModule(kConstantHoistingTestCase); + auto m = ParseAndReturnVerifiedModule(kConstantHoistingTestCase).ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN( bool simplified_loop, - WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(&module())); + WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(m.get())); EXPECT_TRUE(simplified_loop); - HloComputation* while_body = module().GetComputationWithName("wide.body"); + HloComputation* while_body = m->GetComputationWithName("wide.body"); ASSERT_NE(while_body, nullptr); // We expect the while body to be the equivalent of: @@ -523,10 +529,98 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) { } TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) { - ParseAndVerifyModule(kConstantHoistingTestCase); + auto m = ParseAndReturnVerifiedModule(kConstantHoistingTestCase).ValueOrDie(); TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopInvariantCodeMotion{}.Run(&module())); + WhileLoopInvariantCodeMotion{}.Run(m.get())); + EXPECT_FALSE(simplified_loop); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DoNotHoistOutOfSingleIteration) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], f32[2], f32[2], s32[]) parameter(0) + val.0 = f32[2] get-tuple-element(p_body), index=0 + val.1 = f32[2] get-tuple-element(p_body), index=1 + add = f32[2] add(val.0, val.1) + const = s32[] constant(-1) + ROOT root = (f32[2], f32[2], f32[2], s32[]) tuple(val.0, val.1, add, const) + } + + condition { + p_cond = (f32[2], f32[2], f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=3 + const = s32[] constant(42) + ROOT result = pred[] equal-to(gte, const) + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], f32[2], f32[2], s32[]) tuple(param.0, param.0, param.0, param.1) + ROOT while = (f32[2], f32[2], f32[2], s32[]) while(while_init), 
condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(module.get())); + EXPECT_FALSE(simplified_loop); +} + +const char* const kInflatingTestCase = R"( +HloModule ModuleWithWhile + +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} + +body { + p_body = (f32[]) parameter(0) + iota = f32[1024, 1024] iota(), iota_dimension=0 + add = f32[1024, 1024] add(iota, iota) + constant = f32[] constant(1.0) + reduce = f32[] reduce(f32[1024, 1024] add, f32[] constant), dimensions={0,1}, to_apply=mul + ROOT root = (f32[]) tuple(reduce) +} + +condition { + p_cond = (f32[]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + param = f32[] parameter(0) + while_init = (f32[]) tuple(param) + ROOT while = (f32[]) while(while_init), condition=condition, body=body +} +)"; + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistsInflatingByDefault) { + auto m = ParseAndReturnVerifiedModule(kInflatingTestCase).ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + WhileLoopInvariantCodeMotion(/*hoist_constants=*/true).Run(m.get())); + EXPECT_TRUE(simplified_loop); + + HloComputation* while_body = m->GetComputationWithName("wide.body"); + ASSERT_NE(while_body, nullptr); + EXPECT_THAT(while_body->instructions(), Not(Contains(op::Iota()))); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, NoHoistInflating) { + auto m = ParseAndReturnVerifiedModule(kInflatingTestCase).ValueOrDie(); + + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + WhileLoopInvariantCodeMotion(/*hoist_constants=*/true, + /*hoist_size_inflating_ops=*/false) + .Run(m.get())); EXPECT_FALSE(simplified_loop); } diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 
630d71e5ca25e9d282ce6283284a32d6f725a193..d30f67dd8110b88166fe807762fb653190ec00bc 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -19,41 +19,19 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" namespace xla { +namespace m = match; using absl::optional; - -// Determines whether the given instruction is a send/recv node, or has a -// subcomputation which contains a send/recv node. -static bool IsOrContainsSendOrRecv(const HloInstruction* instr); - -// Determines whether the given computation contains a send or recv node. -static bool ContainsSendOrRecv(const HloComputation* comp) { - for (const auto* instr : comp->instructions()) { - if (IsOrContainsSendOrRecv(instr)) { - return true; - } - } - return false; -} - -static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { - if (instr->opcode() == HloOpcode::kSend || - instr->opcode() == HloOpcode::kSendDone || - instr->opcode() == HloOpcode::kRecv || - instr->opcode() == HloOpcode::kRecvDone) { - return true; - } - for (const auto& subcomp : instr->called_computations()) { - if (ContainsSendOrRecv(subcomp)) { - return true; - } - } - return false; -} +using hlo_query::ContainsInstrWithOpcode; // Tries to remove elements in a while loop's tuple that aren't used within the // loop. @@ -253,7 +231,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Create the new while condition, body, and init value. 
std::unique_ptr new_while_cond = while_cond->CloneWithReplacements( - make_while_computation_replacements(while_cond), /*extras=*/{}); + make_while_computation_replacements(while_cond)); std::unordered_map> while_body_replacements = make_while_computation_replacements(while_body); @@ -266,8 +244,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { while_body_replacements.emplace( while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems)); std::unique_ptr new_while_body = - while_body->CloneWithReplacements(std::move(while_body_replacements), - /*extras=*/{}); + while_body->CloneWithReplacements(std::move(while_body_replacements)); // Add a new while_init instruction that repackages the old while_init // instruction's elements. We rely on the AlgebraicSimplifier and DCE to @@ -329,6 +306,147 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { return true; } +// Removes each loop parameter (i.e. member of the while loop tuple) that is a +// constant and is the same in the while loop body and the while loop init. 
+static StatusOr TryRemoveConstantParams(HloInstruction* while_op) { + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + auto* while_init = while_op->mutable_operand(0); + auto* while_body = while_op->while_body(); + auto* while_cond = while_op->while_condition(); + auto* while_body_root = while_body->root_instruction(); + if (while_init->opcode() != HloOpcode::kTuple || + while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + + TF_RET_CHECK(while_cond->num_parameters() == 1); + TF_RET_CHECK(while_body->num_parameters() == 1); + TF_RET_CHECK( + ShapeUtil::Compatible(while_init->shape(), while_body_root->shape())); + + absl::flat_hash_set constant_tuple_indices; + const auto& while_shape = while_init->shape(); + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + auto* init_elem = while_init->operand(i); + auto* body_elem = while_body_root->operand(i); + if (init_elem->opcode() == HloOpcode::kConstant && + body_elem->opcode() == HloOpcode::kConstant && + init_elem->literal() == body_elem->literal()) { + constant_tuple_indices.insert(i); + } + } + + if (constant_tuple_indices.empty()) { + return false; + } + + // OK, we found some constant elements of the while parameter! Eliminate + // them. + std::vector new_while_shape_elems; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + if (!constant_tuple_indices.count(i)) { + new_while_shape_elems.push_back(while_shape.tuple_shapes(i)); + } + } + Shape new_while_shape = ShapeUtil::MakeTupleShape(new_while_shape_elems); + + // `new_instrs` holds instructions created outside of a computation for + // cloning. Elements added here just need to live until the end of the + // relevant CloneWithReplacement call. 
+ std::vector> new_instrs; + auto add_new_instr = [&](std::unique_ptr instr) { + new_instrs.push_back(std::move(instr)); + return new_instrs.back().get(); + }; + + // Returns a new tuple without the elements of constant_tuple_indices. + auto remove_constant_elems = [&](HloInstruction* instr) { + CHECK(ShapeUtil::Compatible(instr->shape(), while_shape)); + + std::vector tuple_elems; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + if (!constant_tuple_indices.count(i)) { + tuple_elems.push_back( + add_new_instr(HloInstruction::CreateGetTupleElement( + while_shape.tuple_shapes(i), instr, i))); + } + } + return HloInstruction::CreateTuple(tuple_elems); + }; + + auto add_constant_elems = [&](HloInstruction* instr) { + CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape)); + + std::vector tuple_elems; + int64 j = 0; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + if (constant_tuple_indices.count(i)) { + tuple_elems.push_back(while_init->mutable_operand(i)); + } else { + tuple_elems.push_back( + add_new_instr(HloInstruction::CreateGetTupleElement( + while_shape.tuple_shapes(i), instr, j))); + ++j; + } + } + return HloInstruction::CreateTuple(tuple_elems); + }; + + // Special case: constant_tuple_indices covers the whole while parameter, so + // the new while shape is the empty tuple. In this case, the value of the + // while loop is simply equal to the value of `init`. + // + // It's unfortunate to special-case this, but it's simpler than the + // alternative. The problem is that if our while parameter has no + // non-constant elems, the tuple returned by `add_constant_elems` won't depend + // on instr (the loop body/cond parameter), and therefore + // CloneWithReplacementPairs will *leave the parameter out entirely*, creating + // invalid HLO. 
+ if (ShapeUtil::IsEmptyTuple(new_while_shape)) { + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, while_init)); + return true; + } + + std::unique_ptr new_while_cond = + while_cond->CloneWithReplacementPairs({ + while_cond->parameter_instruction(0), + add_constant_elems(add_new_instr(HloInstruction::CreateParameter( + 0, new_while_shape, + while_cond->parameter_instruction(0)->name()))), + }); + + std::unique_ptr new_while_body = + while_body->CloneWithReplacementPairs( + { + while_body->parameter_instruction(0), + add_constant_elems(add_new_instr(HloInstruction::CreateParameter( + 0, new_while_shape, + while_cond->parameter_instruction(0)->name()))), + }, + { + while_body->root_instruction(), + remove_constant_elems( + add_new_instr(while_body->root_instruction()->Clone())), + }); + + // Create the final while loop, and add any new instructions created to + // `computation`. + new_instrs.clear(); + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + while_op, + add_constant_elems( + computation->AddInstruction(HloInstruction::CreateWhile( + new_while_shape, + module->AddEmbeddedComputation(std::move(new_while_cond)), + module->AddEmbeddedComputation(std::move(new_while_body)), + add_new_instr(remove_constant_elems(while_init))))))); + for (auto& instr : new_instrs) { + computation->AddInstruction(std::move(instr)); + } + return true; +} + // Tries to remove a while loop from the graph. // // - Loops with trip count of 0 can be replaced by the loop's "init" value. @@ -408,16 +526,14 @@ static StatusOr TryPropagateConstant(HloInstruction* while_op) { // performance by forcing us to copy constants. 
absl::flat_hash_map index_to_constant; for (int i = 0; i < root_operands.size(); i++) { - HloInstruction* instr = root_operands[i]; - if (instr->opcode() == HloOpcode::kGetTupleElement && - instr->tuple_index() == i && instr->operand(0) == while_body_param && - ShapeUtil::IsScalar(instr->shape())) { - auto tuple_element = while_init->operand(i); - if (tuple_element->IsConstant()) { - VLOG(3) << "Found loop invariant tuple element " << i << " " - << tuple_element->ToString(); - index_to_constant[i] = tuple_element; - } + const HloInstruction* init_tuple_elem = nullptr; + if (Match(root_operands[i], + m::GetTupleElement(m::Op().Is(while_body_param), i) + .WithShape(m::Shape().IsScalar())) && + Match(while_init->operand(i), m::Constant(&init_tuple_elem))) { + VLOG(3) << "Found loop invariant tuple element " << i << " " + << init_tuple_elem->ToString(); + index_to_constant[i] = init_tuple_elem; } } @@ -458,6 +574,409 @@ static StatusOr TryPropagateConstant(HloInstruction* while_op) { return changed_cond || changed_body; } +// Converts a flat list of instructions into a tuple of the desired shape. For +// example, given a tuple shape ((x, x), x) and instructions {A, B, C}, returns +// a tuple of value ((A, B), C). +// +// desired_shape must be a tuple. (This precondition allows us to return a +// unique_ptr rather than a raw ptr.) +static std::unique_ptr UnflattenTupleInstr( + absl::Span instrs, const Shape& desired_shape, + std::vector>* new_instrs) { + CHECK(ShapeUtil::IsTuple(desired_shape)) + << ShapeUtil::HumanString(desired_shape); + + // For each child shape in `desired_shape`, slice out the correct number of + // `instrs` and call UnflattenTupleInstr recursively. At each step we remove + // elements from `instrs` so that it only contains instructions we have not + // yet processed. 
+ std::vector elems; + for (int64 i = 0; i < desired_shape.tuple_shapes_size(); ++i) { + const Shape& subshape = desired_shape.tuple_shapes(i); + if (!ShapeUtil::IsTuple(subshape)) { + elems.push_back(instrs[0]); + instrs.remove_prefix(1); + continue; + } + + // Count the number of leaf nodes underneath desired_shape[i]. + int64 num_leaves = 0; + ShapeUtil::ForEachSubshape( + subshape, [&](const Shape& s, const ShapeIndex& /*index*/) { + if (!ShapeUtil::IsTuple(s)) { + ++num_leaves; + } + }); + + std::unique_ptr subinstr = + UnflattenTupleInstr(instrs.subspan(0, num_leaves), + desired_shape.tuple_shapes(i), new_instrs); + elems.push_back(subinstr.get()); + new_instrs->push_back(std::move(subinstr)); + instrs.remove_prefix(num_leaves); + } + return HloInstruction::CreateTuple(elems); +} + +// Builds a vector whose elements are the values in the flattened tuple for +// `instr`. For example, if `instr` is a tuple of form ((A, B), C), returns the +// vector {A, B, C} (or kGetTupleElement ops which point to A, B, and C). 
+static std::vector GetFlatTupleElems( + HloInstruction* instr, + std::vector>* new_instrs) { + const auto& shape = instr->shape(); + if (!ShapeUtil::IsTuple(shape)) { + return {instr}; + } + std::vector elems; + for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + new_instrs->push_back( + HloInstruction::CreateGetTupleElement(subshape, instr, i)); + auto* gte = new_instrs->back().get(); + auto flattened_subshape = GetFlatTupleElems(gte, new_instrs); + elems.insert(elems.end(), flattened_subshape.begin(), + flattened_subshape.end()); + } + return elems; +} + +static StatusOr TryFlattenNestedTuples(HloInstruction* while_op) { + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + auto* while_init = while_op->mutable_operand(0); + auto* while_body = while_op->while_body(); + auto* while_cond = while_op->while_condition(); + auto* while_body_root = while_body->root_instruction(); + if (while_init->opcode() != HloOpcode::kTuple || + while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + + TF_RET_CHECK(while_cond->num_parameters() == 1); + TF_RET_CHECK(while_body->num_parameters() == 1); + TF_RET_CHECK( + ShapeUtil::Compatible(while_init->shape(), while_body_root->shape())); + Shape while_shape = while_init->shape(); + if (!ShapeUtil::IsNestedTuple(while_shape)) { + return false; + } + + std::vector flattened_shape_elems; + ShapeUtil::ForEachSubshape(while_shape, + [&](const Shape& s, const ShapeIndex& /*index*/) { + if (!ShapeUtil::IsTuple(s)) { + flattened_shape_elems.push_back(s); + } + }); + Shape flattened_shape = ShapeUtil::MakeTupleShape(flattened_shape_elems); + + // `new_instrs` holds instructions created outside of a computation for + // cloning. Elements added here just need to live until the end of the + // relevant CloneWithReplacement call. 
+ std::vector> new_instrs; + auto add_new_instr = [&](std::unique_ptr instr) { + new_instrs.push_back(std::move(instr)); + return new_instrs.back().get(); + }; + + auto nested = [&](HloInstruction* instr) { + std::vector gtes; + const Shape& flat_shape = instr->shape(); + for (int64 i = 0; i < flat_shape.tuple_shapes_size(); ++i) { + gtes.push_back(add_new_instr(HloInstruction::CreateGetTupleElement( + flat_shape.tuple_shapes(i), instr, i))); + } + auto nested_instr = + UnflattenTupleInstr(absl::MakeSpan(gtes), while_shape, &new_instrs); + CHECK(ShapeUtil::Compatible(nested_instr->shape(), while_shape)) + << ShapeUtil::HumanString(nested_instr->shape()) << " vs " + << ShapeUtil::HumanString(while_shape); + return nested_instr; + }; + + auto flattened = [&](HloInstruction* instr) { + return HloInstruction::CreateTuple(GetFlatTupleElems(instr, &new_instrs)); + }; + + // Create a new while-condition computation, where parameter 0 has flat shape + // but all uses of it go through the nested shape. + std::unique_ptr new_while_cond = + while_cond->CloneWithReplacementPairs({ + while_cond->parameter_instruction(0), + nested(add_new_instr(HloInstruction::CreateParameter( + 0, flattened_shape, + while_cond->parameter_instruction(0)->name()))), + }); + + // Create a new while-body computation, where parameter 0 has a flat shape and + // all uses of it go through the nested shape, and where the root has a flat + // shape constructed from the old nested root. + std::unique_ptr new_while_body = + while_body->CloneWithReplacementPairs( + { + while_body->parameter_instruction(0), + nested(add_new_instr(HloInstruction::CreateParameter( + 0, flattened_shape, + while_body->parameter_instruction(0)->name()))), + }, + { + while_body->root_instruction(), + flattened(add_new_instr(while_body->root_instruction()->Clone())), + }); + + // Create the final while loop, and add any new instructions created to + // `computation`. 
+ new_instrs.clear(); + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + while_op, nested(computation->AddInstruction(HloInstruction::CreateWhile( + flattened_shape, + module->AddEmbeddedComputation(std::move(new_while_cond)), + module->AddEmbeddedComputation(std::move(new_while_body)), + computation->AddInstruction(flattened(while_init))))))); + for (auto& instr : new_instrs) { + computation->AddInstruction(std::move(instr)); + } + return true; +} + +// Tries to merge loop induction variables of a given type. +// +// In this pass we're only concerned with elements of the loop's tuple that +// are effective-scalars of type `elem_ty`. Some terminology: +// +// - The trip counter is the first element of the loop's tuple that starts at +// 0 and does x++ on each iteration. +// +// - An induction variable is an element of the loop's tuple that is not the +// trip counter and does `x += <constant>` on each iteration of the loop. +// Negative constants are OK. +// +// This pass adds a trip counter if one isn't already present, then replaces +// each induction variable with +// // <initial value> + <trip count> * <constant>. +// +// This reduces the number of scalar operations in the loop, which is important +// e.g. on GPUs, where each scalar operation is nontrivially expensive because +// it's a separate kernel launch. +// +// Returns the new loop if a change was made, or null if no change was made. +// Note that the new loop is not a valid replacement for the old loop; it may +// need to be wrapped in a tuple that changes its shape. We return the loop +// itself so that you can call TryMergeInductionVariables in a loop, once for +// each integral type elem_ty. 
+static StatusOr TryMergeInductionVariables( + HloInstruction* while_op, PrimitiveType elem_ty) { + CHECK(primitive_util::IsIntegralType(elem_ty)) << PrimitiveType_Name(elem_ty); + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + auto* while_init = while_op->mutable_operand(0); + auto* while_body = while_op->while_body(); + auto* while_cond = while_op->while_condition(); + auto* while_body_root = while_body->root_instruction(); + if (while_init->opcode() != HloOpcode::kTuple || + while_body_root->opcode() != HloOpcode::kTuple) { + return nullptr; + } + + TF_RET_CHECK(while_cond->num_parameters() == 1); + TF_RET_CHECK(while_body->num_parameters() == 1); + TF_RET_CHECK( + ShapeUtil::Compatible(while_init->shape(), while_body_root->shape())); + Shape while_shape = while_init->shape(); + + // The tuple index of the trip counter, if one is present. + absl::optional trip_counter; + // Maps the tuple index of each induction variable to its constant increment. + absl::flat_hash_map induction_vars; + for (int64 i = 0; i < while_body_root->operand_count(); ++i) { + HloInstruction* constant; + if (!Match(while_body_root->mutable_operand(i), + m::AddAnyOrder(m::GetTupleElement(m::Parameter(), i), + m::ConstantScalar(&constant)) + .WithShape(m::Shape().WithElementType(elem_ty)))) { + continue; + } + if (!trip_counter && constant->literal().IsAll(1) && + while_init->operand(i)->IsConstant() && + while_init->operand(i)->literal().IsAll(0)) { + VLOG(10) << "Found existing trip counter at index " << i; + trip_counter = i; + } else { + VLOG(10) << "Found induction variable at index " << i; + induction_vars.emplace(i, Cast(constant)); + } + } + + // There's only something to simplify if we can either: + // + // - combine one or more induction vars with an existing trip counter, or + // - replace two or more induction variables with a new trip counter. 
+ // + // Put another way, there's only something to simplify if the number of + // induction vars plus the number of existing trip counters (0 or 1) is >= 2. + if (induction_vars.size() + (trip_counter.has_value() ? 1 : 0) < 2) { + return nullptr; + } + + // OK, we're going to do the transformation! Set up some helpers. + + // `new_instrs` holds instructions created outside of a computation for + // cloning. Elements added here just need to live until the end of the + // relevant CloneWithReplacement call. + std::vector> new_instrs; + auto add_new_instr = [&](std::unique_ptr instr) { + new_instrs.push_back(std::move(instr)); + return new_instrs.back().get(); + }; + + auto add_binary_op = [&](const Shape& shape, HloOpcode opcode, + HloInstruction* lhs, HloInstruction* rhs) { + // Reshape lhs/rhs to the output shape if necessary. This deals with the + // fact that induction variables need only be effective scalars, not true + // scalars. + if (!ShapeUtil::Compatible(shape, lhs->shape())) { + lhs = add_new_instr(HloInstruction::CreateReshape(shape, lhs)); + } + if (!ShapeUtil::Compatible(shape, rhs->shape())) { + rhs = add_new_instr(HloInstruction::CreateReshape(shape, rhs)); + } + return add_new_instr(HloInstruction::CreateBinary(shape, opcode, lhs, rhs)); + }; + + auto add_gte = [&](HloInstruction* src, int64 idx) { + return add_new_instr(HloInstruction::CreateGetTupleElement( + src->shape().tuple_shapes(idx), src, idx)); + }; + + // Our new while loop will have the same shape as the old while loop, except + // we'll add a trip counter to the end if it wasn't originally present. 
+ Shape new_while_shape = while_shape; + bool added_trip_counter = false; + if (!trip_counter) { + VLOG(10) << "Adding new trip counter to end of loop's tuple."; + trip_counter = new_while_shape.tuple_shapes_size(); + *new_while_shape.add_tuple_shapes() = + ShapeUtil::MakeShape(elem_ty, /*dimensions=*/{}); + added_trip_counter = true; + } + + // Converts `instr` into a tuple of the "old" form -- that is, to a tuple with + // shape `while_body->shape()` and where the induction variables are "reified" + // (i.e. they have value + * ). + auto convert_to_old_form = [&](HloInstruction* instr) { + CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape)); + std::vector tuple_elems; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + const auto& elem_shape = while_shape.tuple_shapes(i); + if (!induction_vars.count(i)) { + tuple_elems.push_back(add_gte(instr, i)); + continue; + } + tuple_elems.push_back(add_binary_op( + elem_shape, HloOpcode::kAdd, add_gte(instr, i), + add_binary_op(elem_shape, HloOpcode::kMultiply, + add_gte(instr, *trip_counter), + add_new_instr(induction_vars.at(i)->Clone())))); + } + return HloInstruction::CreateTuple(tuple_elems); + }; + + // Converts `root` into a tuple of the "new" form -- that is, to a tuple with + // shape `new_while_shape` and where the induction variables (but not trip + // counters) are replaced with their unchanging values. + auto convert_to_new_form = [&](HloInstruction* old_root, + HloParameterInstruction* loop_body_param) { + CHECK(ShapeUtil::Compatible(old_root->shape(), while_shape)); + std::vector tuple_elems; + + // In the new form, induction variables come from `init`, everything else + // (including the trip counter if it's not one we created ourselves) comes + // from the `root` tuple unmodified. + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + tuple_elems.push_back( + add_gte((induction_vars.count(i) ? 
loop_body_param : old_root), i)); + } + // If we created a trip counter ourselves, add 1 to it in the next + // iteration. + if (added_trip_counter) { + tuple_elems.push_back(add_binary_op( + new_while_shape.tuple_shapes(*trip_counter), HloOpcode::kAdd, + add_gte(loop_body_param, *trip_counter), + add_new_instr( + HloInstruction::CreateConstant(LiteralUtil::One(elem_ty))))); + } + + return HloInstruction::CreateTuple(tuple_elems); + }; + + // Creates a new init tuple, which is the same as the old init tuple except if + // we added a trip counter, it's set to 0. + auto get_new_while_init = [&](HloInstruction* init) { + CHECK(ShapeUtil::Compatible(init->shape(), while_shape)); + if (!added_trip_counter) { + return init; + } + std::vector tuple_elems; + for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) { + tuple_elems.push_back(add_gte(init, i)); + } + tuple_elems.push_back(add_new_instr( + HloInstruction::CreateConstant(LiteralUtil::Zero(elem_ty)))); + return add_new_instr(HloInstruction::CreateTuple(tuple_elems)); + }; + + std::unique_ptr new_while_cond = + while_cond->CloneWithReplacementPairs({ + while_cond->parameter_instruction(0), + convert_to_old_form(add_new_instr(HloInstruction::CreateParameter( + 0, new_while_shape, + while_cond->parameter_instruction(0)->name()))), + }); + + // Creating the new while body proceeds in two steps. First we convert the + // users of the parameter to the old form. Then as a second + // CloneWithReplacement operation we convert the root to the new form. We + // have to do this in two steps because the new root needs to use the new + // param0, and during the first clone operation, only the *old-form* param0 is + // accessible. + // + // We have to add temp_new_while_body to the module because cloning a + // computation touches the module (to get its NameUniquer). 
+ HloComputation* temp_new_while_body = + module->AddEmbeddedComputation(while_body->CloneWithReplacementPairs({ + while_body->parameter_instruction(0), + convert_to_old_form(add_new_instr(HloInstruction::CreateParameter( + 0, new_while_shape, + while_body->parameter_instruction(0)->name()))), + })); + std::unique_ptr new_while_body = + temp_new_while_body->CloneWithReplacementPairs({ + temp_new_while_body->root_instruction(), + convert_to_new_form( + add_new_instr(temp_new_while_body->root_instruction()->Clone()), + Cast( + temp_new_while_body->parameter_instruction(0))), + }); + TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(temp_new_while_body)); + + // Create the final while loop, and add any new instructions created to + // `computation`. + new_instrs.clear(); + auto* new_while = computation->AddInstruction(HloInstruction::CreateWhile( + new_while_shape, + module->AddEmbeddedComputation(std::move(new_while_cond)), + module->AddEmbeddedComputation(std::move(new_while_body)), + get_new_while_init(while_init))); + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + while_op, convert_to_old_form(new_while))); + for (auto& instr : new_instrs) { + computation->AddInstruction(std::move(instr)); + } + return new_while; +} + StatusOr WhileLoopSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(3, "WhileLoopSimplifier::Run(), before:\n" + module->ToString()); @@ -478,32 +997,77 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { for (HloInstruction* while_op : while_ops) { // We can't remove while loops that contain send/recv nodes, because we rely // on the particular loop structure around the node matching on the send and - // recv sides. Removing dead while params requires us to remove the loop + // recv sides. Other while simplifications require us to remove the loop // and replace it with a new one, so we can't do that either. 
- if (ContainsSendOrRecv(while_op->while_body()) || - ContainsSendOrRecv(while_op->while_condition())) { + if (ContainsInstrWithOpcode(while_op->while_body(), + {HloOpcode::kSend, HloOpcode::kSendDone, + HloOpcode::kRecv, HloOpcode::kRecvDone}) || + ContainsInstrWithOpcode(while_op->while_condition(), + {HloOpcode::kSend, HloOpcode::kSendDone, + HloOpcode::kRecv, HloOpcode::kRecvDone})) { VLOG(2) << "Not attempting to simplify while loop because it contains a " "send/recv node: " << while_op->ToShortString(); continue; } - StatusOr result = TryPropagateConstant(while_op); - TF_RETURN_IF_ERROR(result.status()); - changed |= result.ValueOrDie(); + TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op)); + changed |= result; + + TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op)); + changed |= result; + if (result) { + // Don't continue simplifying after successfully removing the while loop + // -- that would result in use-after-free nastiness. + continue; + } + + // TODO(b/119281462): Cowardly refuse to perform any of the following + // optimizations in the presence of kDomain instructions. It seems that + // modifying a while loop's tuple doesn't work when kDomain is present. + if (ContainsInstrWithOpcode(while_op->while_body(), {HloOpcode::kDomain}) || + ContainsInstrWithOpcode(while_op->while_condition(), + {HloOpcode::kDomain})) { + continue; + } + + // Each of the optimizations below modifies the while loop itself if it's + // successful, meaning that `while_op` is no longer valid after one of these + // transformations returns true. - result = TryRemoveWhileLoop(while_op); - TF_RETURN_IF_ERROR(result.status()); - if (result.ValueOrDie()) { - changed = true; - // Don't try to remove dead while params after successfully removing the - // while loop -- that would result in use-after-free nastiness. 
+ TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op)); + changed |= result; + if (result) { continue; } - result = TryRemoveDeadWhileParams(while_op); - TF_RETURN_IF_ERROR(result.status()); - changed |= result.ValueOrDie(); + TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op)); + changed |= result; + if (result) { + continue; + } + + TF_ASSIGN_OR_RETURN(result, TryRemoveConstantParams(while_op)); + changed |= result; + if (result) { + continue; + } + + bool merged_induction_vars = false; + // Notably missing from this list are S16 and U16. These don't currently + // work because S/U16 literals are not implemented. + for (auto elem_ty : {S8, U8, S32, U32, S64, U64}) { + TF_ASSIGN_OR_RETURN(auto* new_while_op, + TryMergeInductionVariables(while_op, elem_ty)); + if (new_while_op) { + while_op = new_while_op; + changed = true; + merged_induction_vars = true; + } + } + if (merged_induction_vars) { + continue; + } } XLA_VLOG_LINES(3, diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index 0bc5a0107bbcfb3b29a01d593fb79b89a863e49b..a378f179c63c788cd205ddbb784dee0e6b2106d7 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -25,11 +25,22 @@ namespace xla { // HLO pass that makes the following transformations on while loops: // // - A while loop with static trip count of 0 is deleted. +// // - A while loop with static trip count of 1 is replaced by its body (sans // loop). +// // - Elements of a while loop's tuple that the loop doesn't use are removed // from the tuple. // +// - If the while loop's parameter is a nested tuple, it's flattened to a +// single-level tuple. This is good because it usually reduces the number of +// kTuple instructions, but also because it unlocks additional optimizations +// (e.g. removing unused loop parameters). 
+// +// Flattening nested while loop tuples adds a whole mess of likely unnecessary +// kGetTupleElement and kTuple operations to the graph. We expect that tuple +// simplifier will be run afterwards. +// class WhileLoopSimplifier : public HloModulePass { public: ~WhileLoopSimplifier() override {} diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 1c892ba179ec67ccc9dbfe93d925551d6977ba15..4950e8269e9cf0723d717bd1734518d104c0c9f2 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -17,28 +17,45 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { +using ::testing::_; namespace op = xla::testing::opcode_matchers; -class WhileLoopSimplifierTest : public HloVerifiedTestBase { +// Returns the first kWhile instruction within m's entry computation. +HloInstruction* FindFirstWhile(HloModule* m) { + const auto& instrs = m->entry_computation()->instructions(); + return *absl::c_find_if(instrs, [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); +} + +class WhileLoopSimplifierTest : public HloTestBase { protected: // Makes an HloModule that contains a loop with `num_iters` iteration. 
- void MakeModuleWithSimpleLoop(int num_iters); + TF_MUST_USE_RESULT std::unique_ptr + MakeModuleWithSimpleLoop(int num_iters); // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to // the loop-condition through an element of a tuple which is the // loop-condition parameter. - void MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters); + TF_MUST_USE_RESULT std::unique_ptr + MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters); }; -void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { +std::unique_ptr +WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { string hlo_string_template = R"( HloModule SimpleLoop SimpleLoop.body { @@ -67,10 +84,11 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { string hlo_string = absl::StrReplaceAll( hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); - ParseAndVerifyModule(hlo_string); + return ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); } -void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( +std::unique_ptr +WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( int num_iters) { string hlo_string_template = R"( HloModule SimpleLoopWithIndirectLoopBound @@ -104,60 +122,55 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( string hlo_string = absl::StrReplaceAll( hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); - ParseAndVerifyModule(hlo_string); + return ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); } TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/0); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/0); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + 
EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Constant(), op::Constant())); } TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationTupleElementLoopBoundSimplified) { - MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Constant(), op::Constant(), op::Constant())); } TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Add(), op::Multiply())); } TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationTupleELementLoopBoundSimplified) { - MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1); - HloModule* the_module = &module(); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); - EXPECT_THAT(the_module->entry_computation()->root_instruction(), + auto m = MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), op::Tuple(op::Add(), op::Multiply(), op::Constant())); } TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/2); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/2); 
+ EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, LoopWithControlDependencySimplifiedDependencyPreserved) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* true_op = while_op->while_body()->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); TF_ASSERT_OK(true_op->AddControlDependencyTo( while_op->while_body()->root_instruction())); - ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + ASSERT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction()->control_predecessors(), ElementsAre(op::Constant())) << computation->ToString(); @@ -166,9 +179,8 @@ TEST_F(WhileLoopSimplifierTest, // Loops that contain send/recv nodes can't be simplified; the loop structure // around send/recv nodes must be preserved. 
TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); @@ -179,13 +191,12 @@ TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateSendDone(send)); - EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); @@ -194,7 +205,7 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); - EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // The limitation on not being able to simplify loops that contain infeeds (and @@ -202,16 +213,15 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { // fact that our infrastructure sees simplifying such a loop as tantamount to // removing the non-removable instruction. 
TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) { - MakeModuleWithSimpleLoop(/*num_iters=*/1); - HloModule* the_module = &module(); - HloComputation* computation = the_module->entry_computation(); + auto m = MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloComputation* computation = m->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); auto token = while_body->AddInstruction(HloInstruction::CreateToken()); while_body->AddInstruction(HloInstruction::CreateInfeed( ShapeUtil::MakeShape(F32, {1}), token, "config")); - EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // A non-tuple shaped loop shouldn't be simplified or crash the compiler. @@ -236,8 +246,8 @@ TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // A while loop that does nothing else besides swapping tuple elements @@ -268,8 +278,8 @@ TEST_F(WhileLoopSimplifierTest, LoopSwappingTupleElementsNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // Construct a loop where we assign a constant to tuple element 0 in each @@ -297,8 +307,8 @@ TEST_F(WhileLoopSimplifierTest, } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // Nothing to simplify in a while loop 
whose tuple has 0 elements. @@ -320,8 +330,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithEmptyTupleNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // While loop where one tuple element is used twice in the body, and thus can't @@ -348,8 +358,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithElemUsedTwiceNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } // This while loop has three tuple elements. Element 0 is unused and should be @@ -390,16 +400,15 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { } )"; - ParseAndVerifyModule(hlo_string); - HloModule* the_module = &module(); - EXPECT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); // The original while instruction is still left in the module as a dead // instruction, find a while instruction with a different name as the new // while instruction. 
HloInstruction* new_while_op = - *std::find_if(the_module->entry_computation()->instructions().begin(), - the_module->entry_computation()->instructions().end(), + *std::find_if(m->entry_computation()->instructions().begin(), + m->entry_computation()->instructions().end(), [&](const HloInstruction* instr) { return (instr->opcode() == HloOpcode::kWhile && instr->name() != "while"); @@ -440,8 +449,8 @@ TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, @@ -473,8 +482,8 @@ TEST_F(WhileLoopSimplifierTest, } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); } TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { @@ -505,8 +514,233 @@ TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { } )"; - ParseAndVerifyModule(hlo_string); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_FALSE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); +} + +TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) { + const string hlo_string = R"( + HloModule Test + Body { + param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0) + ta = (s32[1]) get-tuple-element(param), index=0 + a = s32[1] get-tuple-element(ta), index=0 + a.1 = s32[1] add(a, a) + tbcd = (s32[2], s32[3], (s32[4])) get-tuple-element(param), index=1 + ROOT tuple = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd) + } + Cond { + param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0) + ROOT cond = pred[] constant(true) + 
} + ENTRY Loop { + a = s32[1] constant({0}) + b = s32[2] constant({0,1}) + c = s32[3] constant({0,1,2}) + d = s32[4] constant({0,1,2,3}) + ta = (s32[1]) tuple(a) + td = (s32[4]) tuple(d) + tbcd = (s32[2], s32[3], (s32[4])) tuple(b, c, td) + init = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd) + ROOT while = ((s32[1]), (s32[2], s32[3], (s32[4]))) while(init), + condition=Cond, body=Body + })"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + // DCE away the old loop so there's just one while loop in the module, making + // it easy to find. + EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + + HloInstruction* new_while = FindFirstWhile(m.get()); + Shape flat_tuple = + ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3], s32[4])") + .ValueOrDie(); + SCOPED_TRACE(m->ToString()); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + flat_tuple)); + EXPECT_TRUE(ShapeUtil::Equal( + m->entry_computation()->root_instruction()->shape(), + ShapeUtil::ParseShapeString("((s32[1]), (s32[2], s32[3], (s32[4])))") + .ValueOrDie())); +} + +// Edge-case: All elements of the loop carry are constants which can be removed, +// leaving us with a nullary loop. This is a special case, we just replace the +// loop with its init. 
+TEST_F(WhileLoopSimplifierTest, OnlyConstantsInLoopCarry) { + const string hlo_string = R"( + HloModule Test + Body { + param = (s32[1]) parameter(0) + a = s32[1] constant({0}) + ROOT tuple = (s32[1]) tuple(a) + } + Cond { + param = (s32[1]) parameter(0) + ROOT cond = pred[] constant(true) + } + ENTRY Loop { + a = s32[1] constant({0}) + init = (s32[1]) tuple(a) + ROOT while = (s32[1]) while(init), condition=Cond, body=Body + })"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + op::Tuple(op::Constant())); +} + +TEST_F(WhileLoopSimplifierTest, RemoveConstantFromLoopCarry) { + const string hlo_string = R"( + HloModule Test + Body { + param = (s32[1], s32[2], s32[3]) parameter(0) + a = s32[1] get-tuple-element(param), index=0 + a.1 = s32[1] add(a, a) + b = s32[2] constant({1,1}) + c = s32[3] constant({10,10,10}) + ROOT tuple = (s32[1], s32[2], s32[3]) tuple(a.1, b, c) + } + Cond { + param = (s32[1], s32[2], s32[3]) parameter(0) + /* Use each tuple element. The verifier will then ensure that if any of + * these get modified, they're replaced with values of the correct shape. */ + a = s32[1] get-tuple-element(param), index=0 + b = s32[2] get-tuple-element(param), index=1 + c = s32[3] get-tuple-element(param), index=2 + ROOT cond = pred[] constant(true) + } + ENTRY Loop { + /* Only `b` should be simplified away. `a` is not a constant within the + * loop, and `c`'s value changes depending on whether we run 0 or 1 + * iterations of the loop. 
*/ + a = s32[1] constant({0}) + b = s32[2] constant({1,1}) + c = s32[3] constant({2,2,2}) + init = (s32[1], s32[2], s32[3]) tuple(a,b,c) + ROOT while = (s32[1], s32[2], s32[3]) while(init), + condition=Cond, body=Body + })"; + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + // DCE away the old loop so there's just one while loop in the module, making + // it easy to find. + EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + // Run the tuple simplifier to make the resulting HLO a bit easier to check. + EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); + + HloInstruction* new_while = FindFirstWhile(m.get()); + Shape new_while_shape = + ShapeUtil::ParseShapeString("(s32[1], s32[3])").ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + m->entry_computation()->root_instruction()->shape(), + ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3])").ValueOrDie())); + EXPECT_THAT(m->entry_computation()->root_instruction(), + op::Tuple(_, op::Constant(), _)); +} + +const char* const kSimpleMergeInductionVariablesModule = R"( + HloModule Test + Body { + param = (TYPE[], TYPE[], TYPE[]) parameter(0) + + a = TYPE[] get-tuple-element(param), index=0 + one = TYPE[] constant(1) + a1 = TYPE[] add(a, one) + + b = TYPE[] get-tuple-element(param), index=1 + negone = TYPE[] constant(-1) + b1 = TYPE[] add(b, negone) + + c = TYPE[] add(a, b) + + ROOT tuple = (TYPE[], TYPE[], TYPE[]) tuple(a1,b1,c) + } + Cond { + param = (TYPE[], TYPE[], TYPE[]) parameter(0) + a = TYPE[] get-tuple-element(param), index=0 + b = 
TYPE[] get-tuple-element(param), index=1 + sum = TYPE[] power(a, b) + ten = TYPE[] constant(10) + ROOT cond = pred[] less-than(sum, ten) + } + ENTRY Loop { + a = TYPE[] constant(10) + b = TYPE[] constant(100) + c = TYPE[] constant(0) + init = (TYPE[], TYPE[], TYPE[]) tuple(a,b,c) + while = (TYPE[], TYPE[], TYPE[]) while(init), condition=Cond, body=Body + + a1 = TYPE[] get-tuple-element(while), index=0 + b1 = TYPE[] get-tuple-element(while), index=1 + ROOT sum = TYPE[] add(a1, b1) + })"; + +TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_Simple) { + string hlo_string = absl::StrReplaceAll(kSimpleMergeInductionVariablesModule, + {{"TYPE", "s32"}}); + + auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie()); + // DCE away the old loop so there's just one while loop in the module, making + // it easy to find, and run the tuple simplifier to make the resulting HLO + // easier to check. + EXPECT_TRUE(HloDCE().Run(m.get()).ok()); + EXPECT_TRUE(TupleSimplifier().Run(m.get()).ok()); + + HloInstruction* new_while = FindFirstWhile(m.get()); + // We should have added a new loop counter for s32[] to the end of the tuple. 
+ SCOPED_TRACE(m->ToString()); + Shape new_while_shape = + ShapeUtil::ParseShapeString("(s32[], s32[], s32[], s32[])").ValueOrDie(); + EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->root_instruction()->shape(), new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_body()->parameter_instruction(0)->shape(), + new_while_shape)); + EXPECT_TRUE(ShapeUtil::Equal( + new_while->while_condition()->parameter_instruction(0)->shape(), + new_while_shape)); + + EXPECT_THAT(new_while->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(op::Parameter(), 0), + op::GetTupleElement(op::Parameter(), 1), op::Add(), + op::Add(op::GetTupleElement(op::Parameter(), 3), + op::Constant()))); + EXPECT_THAT(new_while->while_condition()->root_instruction(), + op::Lt(op::Power(op::Add(), op::Add()), op::Constant())); +} + +// We shouldn't merge S16 induction variables; we can't create constants of this +// type because S16 literals are not implemented. +TEST_F(WhileLoopSimplifierTest, MergeInductionVariables_SkipS16) { + string hlo_string = absl::StrReplaceAll(kSimpleMergeInductionVariablesModule, + {{"TYPE", "s16"}}); + EXPECT_FALSE( + WhileLoopSimplifier() + .Run(ParseAndReturnVerifiedModule(hlo_string).ValueOrDie().get()) + .ValueOrDie()); } } // namespace diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index f90ac91f9d07aded8cafccf82dae894c9a149bd1..039ccda7322f5efda6a827efbeda1225c3596cc0 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/while_util.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -225,7 +227,8 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { /*static*/ StatusOr WhileUtil::MakeCountedLoop( HloComputation* computation, int32 trip_count, const WhileUtil::LoopStateTy& init_values, - const WhileUtil::LoopBodyGeneratorTy& loop_body_generator) { + const WhileUtil::LoopBodyGeneratorTy& loop_body_generator, + const OpMetadata& metadata) { CHECK_GE(trip_count, 0); Shape loop_state_shape = MakeLoopStateShape(init_values); @@ -242,6 +245,7 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { computation->AddInstruction(HloInstruction::CreateWhile( loop_state_shape, module->AddEmbeddedComputation(std::move(cond)), module->AddEmbeddedComputation(std::move(body)), init_tuple)); + while_instr->set_metadata(metadata); std::vector result; for (int64 i = 0, e = init_values.size(); i < e; i++) { @@ -268,4 +272,17 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { return result; } +/*static*/ absl::flat_hash_map> +WhileUtil::GetGTEsMapForWhileConditional( + const HloComputation& while_conditional) { + absl::flat_hash_map> result; + for (HloInstruction* user : + while_conditional.parameter_instruction(0)->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement) { + result[user->tuple_index()].push_back(user); + } + } + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index b1c4486887ae0ddbe2ba4e79f45a265689111017..cba41ccd8b184ba3d867bc170724aee71e777788 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ 
b/tensorflow/compiler/xla/service/while_util.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -77,13 +79,21 @@ class WhileUtil { static StatusOr MakeCountedLoop( HloComputation* computation, int32 trip_count, const LoopStateTy& init_values, - const LoopBodyGeneratorTy& loop_body_generator); + const LoopBodyGeneratorTy& loop_body_generator, + const OpMetadata& metadata); // Returns the GetTupleElement instructions in `while_body` that access // elements in the parameter tuple that don't change across iterations. // Assumes `while_body` is the body computation of the while loop in question. static std::vector GetInvariantGTEsForWhileBody( const HloComputation& while_body); + + // Returns a map of index to GetTupleElement instructions in + // `while_conditional` that access elements in the parameter tuple. Assumes + // `while_conditional` is the conditional computation of the while loop in + // question. 
+ static absl::flat_hash_map> + GetGTEsMapForWhileConditional(const HloComputation& while_conditional); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index b9ef18892d7aa859f6b0b505db4c004e4f5c5066..a546a6d39cc55d1f327b8449c7d26cd4c95dbf98 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -45,7 +45,8 @@ class ZeroSizedHloEliminationTest : public HloTestBase { 0, ShapeUtil::MakeShape(F32, {3, 0}), "zero sized param"))) {} StatusOr RunZeroSizedElimination() { - auto module = CreateNewModule("zero_sized_elimination_test_module"); + auto module = + CreateNewUnverifiedModule("zero_sized_elimination_test_module"); module->AddEntryComputation(builder_.Build()); return ZeroSizedHloElimination{}.Run(module.get()); } diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 14c35e7b84f07bebac33a9753ac26a8ee1418f1e..33edbd1b20d01bf132f2a152625d5f49a45f26f9 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -47,8 +47,11 @@ class ServiceInterface { virtual Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) = 0; - virtual Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) = 0; + virtual Status Compile(const CompileRequest* arg, + CompileResponse* result) = 0; + + virtual Status Execute(const ExecuteRequest* arg, + ExecuteResponse* result) = 0; virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) = 0; diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc new file mode 100644 index 0000000000000000000000000000000000000000..746ab9e9977b1b10cdb0cb57197027d65bd50f55 --- /dev/null +++ 
b/tensorflow/compiler/xla/shape.cc @@ -0,0 +1,107 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/shape.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +Shape::Shape(const ShapeProto& shape_proto) { + set_element_type(shape_proto.element_type()); + dimensions_.reserve(shape_proto.dimensions_size()); + for (const int64 dimension : shape_proto.dimensions()) { + add_dimensions(dimension); + } + tuple_shapes_.reserve(shape_proto.tuple_shapes_size()); + for (const ShapeProto& element_shape : shape_proto.tuple_shapes()) { + *add_tuple_shapes() = Shape(element_shape); + } + if (shape_proto.has_layout()) { + *mutable_layout() = shape_proto.layout(); + } +} + +ShapeProto Shape::ToProto() const { + ShapeProto proto; + proto.set_element_type(element_type_); + proto.mutable_dimensions()->Reserve(dimensions_size()); + for (const int64 dimension : dimensions()) { + proto.add_dimensions(dimension); + } + proto.mutable_tuple_shapes()->Reserve(tuple_shapes_size()); + for (const Shape& shape : tuple_shapes()) { + *proto.add_tuple_shapes() = shape.ToProto(); + } + if (has_layout()) { + *proto.mutable_layout() = layout(); + } + return proto; +} + +string Shape::ToString(bool print_layout) const { + if (print_layout) { + return 
ShapeUtil::HumanStringWithLayout(*this); + } else { + return ShapeUtil::HumanString(*this); + } +} + +std::ostream& operator<<(std::ostream& out, const Shape& shape) { + out << shape.ToString(/*print_layout=*/true); + return out; +} + +ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) { + for (const ShapeProto& shape_proto : program_shape_proto.parameters()) { + *add_parameters() = Shape(shape_proto); + } + *mutable_result() = Shape(program_shape_proto.result()); + for (const string& name : program_shape_proto.parameter_names()) { + add_parameter_names(name); + } +} + +ProgramShapeProto ProgramShape::ToProto() const { + ProgramShapeProto proto; + for (const Shape& shape : parameters()) { + *proto.add_parameters() = shape.ToProto(); + } + *proto.mutable_result() = result().ToProto(); + for (const string& name : parameter_names()) { + proto.add_parameter_names(name); + } + return proto; +} + +string ProgramShape::ToString() const { + std::vector parameter_strings(parameters_size()); + for (int i = 0; i < parameters_size(); ++i) { + parameter_strings[i] = absl::StrCat( + i < parameter_names_size() ? parameter_names(i) : "(unknown)", ": ", + ShapeUtil::HumanString(parameters(i))); + } + return absl::StrCat("(", absl::StrJoin(parameter_strings, ", "), ") -> ", + ShapeUtil::HumanString(result())); +} + +std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape) { + out << program_shape.ToString() << "\n"; + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h new file mode 100644 index 0000000000000000000000000000000000000000..7f6b14ab4286c696dce64d2250a3fe8a57e4865b --- /dev/null +++ b/tensorflow/compiler/xla/shape.h @@ -0,0 +1,204 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_ +#define TENSORFLOW_COMPILER_XLA_SHAPE_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// A shape describes the number of dimensions in a array, the bounds of each +// dimension, and the primitive component type. For tuples, shape describes the +// structure (number of elements and nesting). +class Shape { + public: + Shape() = default; + + // Construct a shape from a ShapeProto. + explicit Shape(const ShapeProto& shape_proto); + + // Returns a ShapeProto representation of the Shape. + ShapeProto ToProto() const; + + // Returns a human-readable string that represents the given shape, with or + // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]". + string ToString(bool print_layout = false) const; + + // The following methods mirror the protobuf generated code interface for the + // message ShapeProto. This enabled easy migration of this data structure + // from a proto to a proper C++ class. + // TODO(b/29771030): Replace or augment these methods with a more ergonomic + // interface. + + // Methods for accessing the primitive type. + PrimitiveType element_type() const { return element_type_; } + void set_element_type(PrimitiveType value) { element_type_ = value; } + + // Methods for accessing the dimensions array. 
+ int dimensions_size() const { return dimensions_.size(); } + int64 dimensions(int index) const { return dimensions_.at(index); } + void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; } + void add_dimensions(int64 value) { dimensions_.push_back(value); } + void clear_dimensions() { dimensions_.clear(); } + const std::vector& dimensions() const { return dimensions_; } + std::vector* mutable_dimensions() { return &dimensions_; } + + // Methods for accessing the tuple subshapes. This field only non-empty for + // tuple shapes. + int tuple_shapes_size() const { return tuple_shapes_.size(); } + const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); } + Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); } + Shape* add_tuple_shapes() { + tuple_shapes_.push_back(Shape()); + return &tuple_shapes_.back(); + } + void clear_tuple_shapes() { tuple_shapes_.clear(); } + const std::vector& tuple_shapes() const { return tuple_shapes_; } + std::vector* mutable_tuple_shapes() { return &tuple_shapes_; } + + // Methods for accessing the layout field. + bool has_layout() const { return layout_.has_value(); } + const Layout& layout() const { + if (layout_.has_value()) { + return *layout_; + } else { + return Layout::default_instance(); + } + } + Layout* mutable_layout() { + if (!layout_.has_value()) { + layout_ = Layout(); + } + return &layout_.value(); + } + void clear_layout() { layout_.reset(); } + + void Swap(Shape* other) { + using std::swap; + swap(*this, *other); + } + + void Clear() { + element_type_ = PRIMITIVE_TYPE_INVALID; + dimensions_.clear(); + tuple_shapes_.clear(); + layout_.reset(); + } + + string SerializeAsString() const { return ToProto().SerializeAsString(); } + string ShortDebugString() const { return ToProto().ShortDebugString(); } + string DebugString() const { return ToProto().DebugString(); } + + public: + // The element type of this shape (tuple, array, etc). 
+ PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID; + + // The array bounds of the dimensions. This is nonempty only for array shapes. + std::vector dimensions_; + + // The tuple element subshapes. This is nonempty only for tuple shapes. + std::vector tuple_shapes_; + + // The array layout of the shape. This is present only for array shapes. + absl::optional layout_; +}; + +// Shape of the parameters and output of an XLA computation. This is analogous +// to a traditional function signature. +class ProgramShape { + public: + ProgramShape() = default; + + // Creates a ProgramShape from a ProgramShapeProto protobuf. + explicit ProgramShape(const ProgramShapeProto& program_shape_proto); + + // Returns a proto representation of the object. + ProgramShapeProto ToProto() const; + + string ToString() const; + + // The following methods mirror the protobuf generated code interface for the + // message ProgramShapeProto. This enabled easy migration of this data + // structure from a proto to a proper C++ class. + // TODO(b/29771030): Replace or augment these methods with a more ergonomic + // interface. + + // Methods for accessing and manipulating the Shape of the parameters. + int parameters_size() const { return parameters_.size(); } + const Shape& parameters(int index) const { return parameters_.at(index); } + Shape* mutable_parameters(int index) { return ¶meters_.at(index); } + Shape* add_parameters() { + parameters_.emplace_back(); + return ¶meters_.back(); + } + void clear_parameters() { parameters_.clear(); } + const std::vector& parameters() const { return parameters_; } + std::vector* mutable_parameters() { return ¶meters_; } + + // Methods for accessing and manipulating the Shape of the result. + const Shape& result() const { return result_; } + Shape* mutable_result() { return &result_; } + + // Methods for accessing and manipulating the names of the parameters. 
+ int parameter_names_size() const { return parameter_names_.size(); } + const string& parameter_names(int index) const { + return parameter_names_.at(index); + } + void set_parameter_names(int index, const string& value) { + parameter_names_.at(index) = value; + } + string* mutable_parameter_names(int index) { + return ¶meter_names_.at(index); + } + void add_parameter_names(const string& value) { + parameter_names_.push_back(value); + } + string* add_parameter_names() { + parameter_names_.push_back(""); + return ¶meter_names_.back(); + } + void clear_parameter_names() { parameter_names_.clear(); } + const std::vector& parameter_names() const { + return parameter_names_; + } + std::vector* mutable_parameter_names() { return ¶meter_names_; } + + string ShortDebugString() const { return ToProto().ShortDebugString(); } + string DebugString() const { return ToProto().DebugString(); } + + private: + // The shapes of the parameters of the computation represented by this object. + std::vector parameters_; + + // The names of the parameters of the computation represented by this object. + std::vector parameter_names_; + + // The shape of the result of the computation represented by this object. + Shape result_; +}; + +std::ostream& operator<<(std::ostream& out, const Shape& shape); +std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SHAPE_H_ diff --git a/tensorflow/compiler/xla/shape_test.cc b/tensorflow/compiler/xla/shape_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e396897eeebc2e7bdc2dc49300c8906710608b05 --- /dev/null +++ b/tensorflow/compiler/xla/shape_test.cc @@ -0,0 +1,149 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/shape.h" + +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class ShapeTest : public ::testing::Test { + protected: + const Shape opaque_ = ShapeUtil::MakeOpaqueShape(); + const Shape token_ = ShapeUtil::MakeTokenShape(); + const Shape scalar_ = ShapeUtil::MakeShape(F32, {}); + const Shape matrix_ = ShapeUtil::MakeShape(U32, {1, 2}); + const Shape matrix2_ = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); + const Shape tuple_ = + ShapeUtil::MakeTupleShape({opaque_, scalar_, matrix_, matrix2_}); + const Shape nested_tuple_ = + ShapeUtil::MakeTupleShape({tuple_, matrix_, token_}); +}; + +TEST_F(ShapeTest, ShapeToFromProto) { + for (const Shape& shape : + {opaque_, token_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_}) { + Shape shape_copy(shape.ToProto()); + EXPECT_TRUE(ShapeUtil::Equal(shape, shape_copy)) + << shape << " != " << shape_copy; + } +} + +TEST_F(ShapeTest, ShapeToString) { + EXPECT_EQ("opaque[]", opaque_.ToString()); + EXPECT_EQ("token[]", token_.ToString()); + EXPECT_EQ("f32[]", 
scalar_.ToString()); + EXPECT_EQ("u32[1,2]", matrix_.ToString()); + EXPECT_EQ("s32[3,4]", matrix2_.ToString()); + EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", tuple_.ToString()); + EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", + nested_tuple_.ToString()); + + EXPECT_EQ("opaque[]", opaque_.ToString(/*print_layout=*/true)); + EXPECT_EQ("f32[]", scalar_.ToString(/*print_layout=*/true)); + EXPECT_EQ("u32[1,2]{1,0}", matrix_.ToString(/*print_layout=*/true)); + EXPECT_EQ("s32[3,4]{0,1}", matrix2_.ToString(/*print_layout=*/true)); + EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", + tuple_.ToString(/*print_layout=*/true)); + EXPECT_EQ( + "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, " + "token[])", + nested_tuple_.ToString(/*print_layout=*/true)); +} + +TEST_F(ShapeTest, ProgramShapeToFromProto) { + ProgramShape program_shape; + *program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3}); + *program_shape.add_parameters() = ShapeUtil::MakeTokenShape(); + *program_shape.add_parameters() = ShapeUtil::MakeShape(S64, {}); + *program_shape.add_parameters() = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeShape(F32, {42, 42})}); + + *program_shape.mutable_result() = ShapeUtil::MakeShape(F32, {7}); + + program_shape.add_parameter_names("foo"); + program_shape.add_parameter_names("bar"); + program_shape.add_parameter_names("baz"); + program_shape.add_parameter_names("qux qux"); + + // Create a copy of the program shape by round-tripping through a proto. 
+ ProgramShape program_shape_copy(program_shape.ToProto()); + ASSERT_EQ(program_shape.parameters_size(), + program_shape_copy.parameters_size()); + for (int i = 0; i < program_shape.parameters_size(); ++i) { + EXPECT_TRUE(ShapeUtil::Equal(program_shape.parameters(i), + program_shape_copy.parameters(i))); + } + + EXPECT_TRUE( + ShapeUtil::Equal(program_shape.result(), program_shape_copy.result())); + + ASSERT_EQ(program_shape.parameter_names_size(), + program_shape_copy.parameter_names_size()); + for (int i = 0; i < program_shape.parameter_names_size(); ++i) { + EXPECT_EQ(program_shape.parameter_names(i), + program_shape_copy.parameter_names(i)); + } +} + +TEST_F(ShapeTest, ProgramShapeToString) { + ProgramShape prog = ShapeUtil::MakeProgramShape( + {opaque_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_}, + nested_tuple_); + EXPECT_EQ( + "((unknown): opaque[], " + "(unknown): f32[], " + "(unknown): u32[1,2], " + "(unknown): s32[3,4], " + "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " + "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", + prog.ToString()); + + prog.add_parameter_names("arg0"); + prog.add_parameter_names("scalar"); + prog.add_parameter_names("matrix"); + prog.add_parameter_names("matrix2"); + prog.add_parameter_names("tuple"); + prog.add_parameter_names("nested_tuple"); + EXPECT_EQ( + "(arg0: opaque[], " + "scalar: f32[], " + "matrix: u32[1,2], " + "matrix2: s32[3,4], " + "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " + "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], " + "token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", + prog.ToString()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index df610102b4c7fa08c0b7030124939009130f89f4..7bf97729165bef98fabc29040e02203eee68a53c 100644 --- 
a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -667,12 +667,11 @@ void ShapeTree::CopySubtreeFrom(const ShapeTree& other, template bool ShapeTree::operator==(const ShapeTree& other) const { bool equal = true; - ForEachElement( - [this, &other, &equal](const ShapeIndex& index, const T& data) { - if (data != other.element(index)) { - equal = false; - } - }); + ForEachElement([&other, &equal](const ShapeIndex& index, const T& data) { + if (data != other.element(index)) { + equal = false; + } + }); return equal; } diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index c8ff55e7845785d9292516b823fb591cc28cbfad..2b6c484bc4f205be0180403eeac2dd391029b110 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -52,10 +52,10 @@ class ShapeTreeTest : public ::testing::Test { TEST_F(ShapeTreeTest, DefaultConstructor) { ShapeTree int_tree; - EXPECT_TRUE(ShapeUtil::IsNil(int_tree.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(int_tree.shape())); ShapeTree bool_tree; - EXPECT_TRUE(ShapeUtil::IsNil(bool_tree.shape())); + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(bool_tree.shape())); } void ShapeTreeTest::TestShapeConstructor(const Shape& shape, diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index d244923532d8963dcc4a7433b8d353ff5dc483f2..a4d4e1e53e727bdf7822cacaa4559fcae59d4eae 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -74,14 +74,19 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) { return out; } -namespace { +bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const { + return size() >= prefix.size() && + indices_.subspan(0, prefix.size()) == prefix.indices_; +} -// Returns whether the given primitive type corresponds to an array shape. 
-bool IsArrayPrimitiveType(PrimitiveType primitive_type) { +/* static */ bool ShapeUtil::IsArrayPrimitiveType( + PrimitiveType primitive_type) { return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && primitive_type != OPAQUE && primitive_type != TOKEN; } +namespace { + // Recursive helper for comparing the equality of two shapes. Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. @@ -116,14 +121,21 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, VLOG(3) << "CompareShapes: lhs layout != rhs layout"; return false; } - if (!absl::c_equal(lhs.layout().padded_dimensions(), - rhs.layout().padded_dimensions())) { - VLOG(3) - << "CompareShapes: lhs padded_dimensions != rhs padded_dimensions"; + + const auto& lhs_tiles = lhs.layout().tiles(); + const auto& rhs_tiles = rhs.layout().tiles(); + if (lhs_tiles.size() != rhs_tiles.size()) { return false; } - if (lhs.layout().padding_value() != rhs.layout().padding_value()) { - VLOG(3) << "CompareShapes: lhs padding value != rhs padding_value"; + for (int64 i = 0; i < lhs_tiles.size(); i++) { + if (!absl::c_equal(lhs_tiles[i].dimensions(), + rhs_tiles[i].dimensions())) { + return false; + } + } + + if (lhs.layout().element_size_in_bits() != + rhs.layout().element_size_in_bits()) { return false; } } @@ -149,7 +161,8 @@ StatusOr MakeShapeWithLayoutInternal( return InvalidArgument("Unsupported element type: %s", PrimitiveType_Name(element_type)); } - Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeUtil::MakeValidatedShape(element_type, dimensions)); auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); min2maj->Clear(); for (int64 value : minor_to_major) { @@ -207,7 +220,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ ProgramShape ShapeUtil::MakeProgramShape( std::initializer_list parameters, Shape result) { ProgramShape program_shape; - for (const 
auto& shape : parameters) { + for (const Shape& shape : parameters) { *program_shape.add_parameters() = shape; } *program_shape.mutable_result() = std::move(result); @@ -216,9 +229,14 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type, absl::Span dimensions) { + return MakeValidatedShape(element_type, dimensions).ValueOrDie(); +} + +/* static */ StatusOr ShapeUtil::MakeValidatedShape( + PrimitiveType element_type, absl::Span dimensions) { CHECK(IsArrayPrimitiveType(element_type)); Shape result; - PopulateShape(element_type, dimensions, &result); + TF_RETURN_IF_ERROR(PopulateShape(element_type, dimensions, &result)); return result; } @@ -256,22 +274,22 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return MakeShapeWithDescendingLayout(shape.element_type(), dims); } -/* static */ void ShapeUtil::PopulateShape(PrimitiveType element_type, - absl::Span dimensions, - Shape* shape) { +/* static */ Status ShapeUtil::PopulateShape(PrimitiveType element_type, + absl::Span dimensions, + Shape* shape) { shape->Clear(); shape->set_element_type(element_type); for (int64 dimension : dimensions) { shape->add_dimensions(dimension); } LayoutUtil::SetToDefaultLayout(shape); - TF_DCHECK_OK(ValidateShape(*shape)); + return ValidateShape(*shape); } /* static */ Shape ShapeUtil::MakeTupleShape(absl::Span shapes) { Shape result; result.set_element_type(TUPLE); - result.mutable_tuple_shapes()->Reserve(shapes.size()); + result.mutable_tuple_shapes()->reserve(shapes.size()); for (const auto& shape : shapes) { AppendShapeToTuple(shape, &result); } @@ -371,10 +389,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return IsTuple(shape) && TupleElementCount(shape) == 0; } -/* static */ bool ShapeUtil::IsNil(const Shape& shape) { - return IsEmptyTuple(shape); -} - /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { CHECK(IsTuple(shape)) << HumanString(shape); return 
shape.tuple_shapes_size(); @@ -461,8 +475,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; } -/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) { - return shape.element_type() == F32 && Rank(shape) == 0; +/* static */ bool ShapeUtil::IsScalarWithElementType( + const Shape& shape, PrimitiveType element_type) { + return IsScalar(shape) && shape.element_type() == element_type; } namespace { @@ -569,7 +584,7 @@ namespace { // Parses shapes with simple recursive descent structure -- consumes from the // front of s and passes that view recursively as required. StatusOr ParseShapeStringInternal(absl::string_view* s) { - *s = StripLeadingAsciiWhitespace(*s); + *s = absl::StripLeadingAsciiWhitespace(*s); if (absl::ConsumePrefix(s, "(")) { // Tuple. std::vector shapes; @@ -582,7 +597,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { } shapes.emplace_back(); TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); - *s = StripLeadingAsciiWhitespace(*s); + *s = absl::StripLeadingAsciiWhitespace(*s); must_end = !absl::ConsumePrefix(s, ","); } return ShapeUtil::MakeTupleShape(shapes); @@ -596,7 +611,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { // we convert in to the RE2-consumable type and then consume the corresponding // amount from our string_view type. static LazyRE2 shape_pattern = { - "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?"}; + "^(\\w*\\d*)\\[([\\d,\\s]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,\\s]+)})" + "?"}; tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); if (RE2::Consume(&s_consumable, *shape_pattern, &element_type_string, &dimensions_string, &format_string, &layout_string)) { @@ -641,7 +657,8 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { result = ShapeUtil::MakeTokenShape(); } else if (format_string.empty() && layout_string.empty()) { // Create a shape without a layout set. 
- result = ShapeUtil::MakeShape(primitive_type, dimensions); + TF_ASSIGN_OR_RETURN( + result, ShapeUtil::MakeValidatedShape(primitive_type, dimensions)); } else if (format_string == "sparse") { TF_ASSIGN_OR_RETURN(int64 max_elements, string_to_int64(layout_string)); result = ShapeUtil::MakeShapeWithSparseLayout(primitive_type, dimensions, @@ -784,6 +801,9 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { return byte_size; } else if (shape.element_type() == TOKEN) { return 0; + } else if (shape.element_type() == OPAQUE) { + CHECK_GT(pointer_size, 0); + return pointer_size; } LOG(FATAL) << PrimitiveType_Name(shape.element_type()) << " primitive type has no definitive size"; @@ -806,17 +826,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString(); - absl::Span padded_dimensions = - LayoutUtil::PaddedDimensions(shape); - if (!padded_dimensions.empty()) { - CHECK_EQ(Rank(shape), padded_dimensions.size()); - allocated_element_count = 1; - for (int64 dimension_size : padded_dimensions) { - allocated_element_count *= dimension_size; - } - } else { - allocated_element_count = ElementsIn(shape); - } + allocated_element_count = ElementsIn(shape); } return allocated_element_count * ByteSizeOfPrimitiveType(shape.element_type()); @@ -892,8 +902,13 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { return Status::OK(); } - int64 shape_size = [&shape]() { - if (LayoutUtil::IsSparseArray(shape)) { + // We can only reason about some aspects of array's shape if it has a valid + // layout, these aspects will be ignored otherwise. 
+ bool shape_has_valid_layout = LayoutUtil::HasLayout(shape) && + LayoutUtil::ValidateLayoutInShape(shape).ok(); + + int64 shape_size = [&]() { + if (shape_has_valid_layout && LayoutUtil::IsSparseArray(shape)) { int64 max_sparse_elements = LayoutUtil::MaxSparseElements(shape.layout()); if (max_sparse_elements < 0) { return max_sparse_elements; @@ -929,7 +944,9 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { return dense_shape_size; } - for (int64 dim : shape.dimensions()) { + absl::Span shape_max_dimensions = + AsInt64Slice(shape.dimensions()); + for (int64 dim : shape_max_dimensions) { dense_shape_size = MultiplyWithoutOverflow(dense_shape_size, dim); if (dense_shape_size < 0) { return dense_shape_size; @@ -951,11 +968,10 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout( const Shape& shape) { - if (LayoutUtil::HasLayout(shape)) { - // Since a layout is present, upgrade to the full set of invariant checks. - return ValidateShape(shape); - } - return ValidateShapeWithOptionalLayoutInternal(shape); + TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape)); + + return LayoutUtil::ValidateLayoutInShape(shape, + /*allow_missing_layouts=*/true); } /* static */ Status ShapeUtil::ValidateShape(const Shape& shape) { @@ -975,7 +991,7 @@ StatusOr ParseShapeStringInternal(absl::string_view* s) { ShapeIndexView index) { const Shape* subshape = &shape; for (auto i : index) { - if (!IsTuple(*subshape) || i >= subshape->tuple_shapes_size()) { + if (!IsTuple(*subshape) || i >= subshape->tuple_shapes_size() || i < 0) { return false; } subshape = &subshape->tuple_shapes(i); @@ -1152,7 +1168,7 @@ Status ForEachMutableSubshapeHelper( // Let the argument `permutation` be P. This is a permutation over `shape`'s // dimensions, so our return value will be a shape with dims P.I = P. 
Our // goal is to construct a layout permutation L* that we can apply to P such - // that that the physical dimension ordering of the returned shape is the same + // that the physical dimension ordering of the returned shape is the same // as that of the original shape, namely L'. // // Our returned shape has dims P and layout L*, so its in-memory layout is @@ -1171,13 +1187,6 @@ Status ForEachMutableSubshapeHelper( permutation, AsInt64Slice(shape.layout().minor_to_major()))) { new_layout->add_minor_to_major(index); } - if (shape.layout().padded_dimensions_size() > 0) { - new_layout->clear_padded_dimensions(); - for (auto dim : - Permute(permutation, shape.layout().padded_dimensions())) { - new_layout->add_padded_dimensions(dim); - } - } // The permutation accepted by TransposeIsBitcast is the inverse of the // permutation here. CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation))) @@ -1280,11 +1289,6 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return false; } - // Padding is not handled. - if (LayoutUtil::IsPadded(input_shape) && LayoutUtil::IsPadded(output_shape)) { - return false; - } - // Check the reshape permutes the positions of each dimension in the // minor-to-major order. positions[i]=k means dimension `i` is k-th minor. // input_positions = apply(dimension_mapping, output_positions) @@ -1316,11 +1320,6 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return false; } - // Padding is not handled. 
- if (LayoutUtil::IsPadded(input_shape) || LayoutUtil::IsPadded(output_shape)) { - return false; - } - CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape)); if (ElementsIn(input_shape) == 0) { return true; @@ -1603,14 +1602,19 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, Shape output_shape_with_layout = MakeShapeWithLayout( output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), output_layout); - CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout)); + CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout)) + << "reshape is not a bitcast for input_shape: " + << ShapeUtil::HumanStringWithLayout(input_shape) + << " and output_shape_with_layout: " + << ShapeUtil::HumanStringWithLayout(output_shape_with_layout); return output_shape_with_layout; } /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { CHECK(IsArray(shape)); - shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); + shape.mutable_dimensions()->erase(shape.mutable_dimensions()->begin() + + dim_to_delete); if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); layout->set_format(DENSE); @@ -1644,11 +1648,6 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return shape; } -std::ostream& operator<<(std::ostream& out, const Shape& shape) { - out << ShapeUtil::HumanString(shape); - return out; -} - /*static*/ size_t ShapeUtil::Hash(const Shape& shape) { using tensorflow::hash; using tensorflow::Hash64Combine; diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index d8bb27beae64bb665c79c2cd7134f613495529cc..84a27f662a57ba274562e2e9be57b7e971c9b477 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -37,6 +38,7 @@ limitations under the License. #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -72,7 +74,7 @@ class ShapeIndex { void push_back(int64 value) { indices_.push_back(value); } void pop_back() { indices_.pop_back(); } - // push_front is O(n^2), but shapes don't usually have a ton of dimensions. + // push_front is O(n), but shapes don't usually have a ton of dimensions. void push_front(int64 value) { indices_.insert(indices_.begin(), value); } using container_type = absl::InlinedVector; @@ -100,6 +102,11 @@ class ShapeIndex { string ToString() const; + template + friend H AbslHashValue(H h, const ShapeIndex& index) { + return H::combine(std::move(h), index.indices_); + } + private: container_type indices_; }; @@ -147,6 +154,9 @@ class ShapeIndexView { string ToString() const; + // Returns true if this shape index starts with 'prefix'. + bool StartsWith(ShapeIndexView prefix) const; + private: absl::Span indices_; }; @@ -312,7 +322,10 @@ class ShapeUtil { static bool IsEffectiveScalar(const Shape& shape) { return IsArray(shape) && TrueRank(shape) == 0; } - static bool IsScalarF32(const Shape& shape); + + // Returns whether "shape" is a scalar (array) with the given element_type. + static bool IsScalarWithElementType(const Shape& shape, + PrimitiveType element_type); // Extracts the size of the shape's dimension at dimension number // GetDimensionNumber(dimension_number). 
@@ -362,6 +375,12 @@ class ShapeUtil { static Shape MakeShape(PrimitiveType element_type, absl::Span dimensions); + // Constructs a new shape with the given element type and sequence of + // dimensions. Method checks if the element type is valid and the shape's + // size fits in std::numeric_limits::max(). + static StatusOr MakeValidatedShape(PrimitiveType element_type, + absl::Span dimensions); + // Creates a Shape with element type corresponding to T and the given // dimensions template @@ -393,8 +412,8 @@ class ShapeUtil { const Shape& shape); // As MakeShape, but the object to write to is passed in. - static void PopulateShape(PrimitiveType element_type, - absl::Span dimensions, Shape* shape); + static Status PopulateShape(PrimitiveType element_type, + absl::Span dimensions, Shape* shape); // Validates that the provided shape satisfies invariants. static Status ValidateShape(const Shape& shape); @@ -449,6 +468,9 @@ class ShapeUtil { // arrays. static bool IsArray(const Shape& shape); + // Returns whether the given primitive type corresponds to an array shape. + static bool IsArrayPrimitiveType(PrimitiveType primitive_type); + // Returns whether the shape is a tuple with at least one element which is // also a tuple. static bool IsNestedTuple(const Shape& shape); @@ -456,9 +478,6 @@ class ShapeUtil { // Returns true if shape is an empty tuple. static bool IsEmptyTuple(const Shape& shape); - // Returns true if shape is the nil shape (an empty tuple). - static bool IsNil(const Shape& shape); - // Returns the number of elements in the given tuple shape. 
// Precondition: IsTuple(shape) static int64 TupleElementCount(const Shape& shape); @@ -742,10 +761,18 @@ class ShapeUtil { pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads); } + tensorflow::mutex mu; + Status status; // Guarded by mu + while (n < rank) { if (pool != absl::nullopt) { - pool->Schedule( - [indexes, &visitor_function] { visitor_function(indexes); }); + pool->Schedule([indexes, &visitor_function, &mu, &status] { + StatusOr result = visitor_function(indexes); + if (!result.ok()) { + tensorflow::mutex_lock lock(mu); + status = status.ok() ? result.status() : status; + } + }); } else { TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes)); if (!should_continue) { @@ -763,14 +790,14 @@ class ShapeUtil { } } - return Status::OK(); + // Waits for the scheduled work to complete. + pool.reset(); + return status; } TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil); }; -std::ostream& operator<<(std::ostream& out, const Shape& shape); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index c622ecdca1fd66604d1a6ceaf705f2e70edaee55..60bdbe302045e6f3b4bae500c50bc68fb217525d 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -345,26 +345,6 @@ TEST(ShapeUtilTest, OpaqueVsArray) { EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape2, shape1)); } -TEST(ShapeUtilTest, CompareShapesWithPaddedDimensionsMismatch) { - Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30}); - shape1.mutable_layout()->add_padded_dimensions(10); - - Shape shape2 = ShapeUtil::MakeShape(F32, {20, 30}); - shape2.mutable_layout()->add_padded_dimensions(11); - - EXPECT_FALSE(ShapeUtil::Equal(shape1, shape2)); -} - -TEST(ShapeUtilTest, CompareShapesWithPaddingValueMismatch) { - Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30}); - shape1.mutable_layout()->set_padding_value(ZERO_PAD); - - Shape shape2 = 
ShapeUtil::MakeShape(F32, {20, 30}); - shape2.mutable_layout()->set_padding_value(LOWEST_PAD); - - EXPECT_FALSE(ShapeUtil::Equal(shape1, shape2)); -} - TEST(ShapeUtilTest, ScalarDefaultLayoutEqualsScalarEmptyMin2Maj) { Shape scalar_default_layout = ShapeUtil::MakeShape(F32, {}); ASSERT_TRUE(scalar_default_layout.has_layout()) @@ -395,23 +375,13 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { EXPECT_EQ(0, ShapeUtil::ByteSizeOf(ShapeUtil::MakeTokenShape())); } -TEST(ShapeUtilTest, ByteSizeOfWithPadding) { - EXPECT_EQ(4, ShapeUtil::ByteSizeOfPrimitiveType(F32)); - Shape shape = ShapeUtil::MakeShape(F32, {10, 20}); - EXPECT_EQ(800, ShapeUtil::ByteSizeOf(shape)); - - shape.mutable_layout()->add_padded_dimensions(15); - shape.mutable_layout()->add_padded_dimensions(21); - EXPECT_EQ(15 * 21 * 4, ShapeUtil::ByteSizeOf(shape)); -} - TEST(ShapeUtilTest, NilShape) { - EXPECT_TRUE(ShapeUtil::IsNil(ShapeUtil::MakeNil())); - EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {1, 2, 3}))); - EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {0, 1}))); - EXPECT_FALSE(ShapeUtil::IsNil( + EXPECT_TRUE(ShapeUtil::IsEmptyTuple(ShapeUtil::MakeNil())); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple(ShapeUtil::MakeShape(F32, {1, 2, 3}))); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple(ShapeUtil::MakeShape(F32, {0, 1}))); + EXPECT_FALSE(ShapeUtil::IsEmptyTuple( ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})}))); - EXPECT_FALSE(ShapeUtil::IsNil( + EXPECT_FALSE(ShapeUtil::IsEmptyTuple( ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {0})}))); } @@ -576,68 +546,6 @@ TEST(ShapeUtilTest, IsLeafIndex) { EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 1})); } -TEST(ShapeUtilTest, HumanString) { - Shape opaque = ShapeUtil::MakeOpaqueShape(); - Shape token = ShapeUtil::MakeTokenShape(); - Shape scalar = ShapeUtil::MakeShape(F32, {}); - Shape matrix = ShapeUtil::MakeShape(U32, {1, 2}); - Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); - Shape tuple = 
ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2}); - Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token}); - - EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque)); - EXPECT_EQ("token[]", ShapeUtil::HumanString(token)); - EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar)); - EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix)); - EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2)); - EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", - ShapeUtil::HumanString(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", - ShapeUtil::HumanString(nested_tuple)); - - EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque)); - EXPECT_EQ("f32[]", ShapeUtil::HumanStringWithLayout(scalar)); - EXPECT_EQ("u32[1,2]{1,0}", ShapeUtil::HumanStringWithLayout(matrix)); - EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2)); - EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", - ShapeUtil::HumanStringWithLayout(tuple)); - EXPECT_EQ( - "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, " - "token[])", - ShapeUtil::HumanStringWithLayout(nested_tuple)); - - ProgramShape prog = ShapeUtil::MakeProgramShape( - {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple); - EXPECT_EQ( - "((unknown): opaque[], " - "(unknown): f32[], " - "(unknown): u32[1,2], " - "(unknown): s32[3,4], " - "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " - "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) " - "-> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", - ShapeUtil::HumanString(prog)); - - prog.add_parameter_names("arg0"); - prog.add_parameter_names("scalar"); - prog.add_parameter_names("matrix"); - prog.add_parameter_names("matrix2"); - prog.add_parameter_names("tuple"); - prog.add_parameter_names("nested_tuple"); - EXPECT_EQ( - "(arg0: opaque[], " - "scalar: f32[], " - "matrix: u32[1,2], " - "matrix2: s32[3,4], " - "tuple: (opaque[], f32[], 
u32[1,2], s32[3,4]), " - "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], " - "token[])) " - "-> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", - ShapeUtil::HumanString(prog)); -} - TEST(ShapeUtilTest, ForEachSubshapeArray) { const Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); int calls = 0; diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc index 1c135dda864b3060b8bdc6369f18268d7c5c7f9e..a40bb7875e7ea53a8959a9a67ec09ec260ba9c37 100644 --- a/tensorflow/compiler/xla/sparse_index_array.cc +++ b/tensorflow/compiler/xla/sparse_index_array.cc @@ -29,7 +29,7 @@ SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, CHECK_GT(rank_, 0); CHECK_EQ(indices_.size() % rank_, 0) << "indices_.size(): " << indices_.size() << ", rank_: " << rank_; - CHECK_LT(index_count(), max_indices_); + CHECK_LE(index_count(), max_indices_); } SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 8a0ae330420531b833ed670118e6b6b1056bd358..5a7a4faa7e89b27fb537f20d94c21cb4a76e000d 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -44,7 +44,7 @@ cc_library( testonly = True, srcs = ["xla_internal_test_main.cc"], deps = [ - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/strings", @@ -79,6 +79,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", + "@com_google_absl//absl/base", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -117,12 +118,12 @@ cc_library( deps = [ ":literal_test_util", ":test_utils", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:shape_layout", 
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", @@ -135,50 +136,13 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) -cc_library( - name = "hlo_verified_test_base", - testonly = True, - srcs = ["hlo_verified_test_base.cc"], - hdrs = ["hlo_verified_test_base.h"], - deps = [ - ":hlo_test_base", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/core:lib", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/memory", - ], -) - -tf_cc_test( - name = "hlo_verified_test_base_test", - srcs = ["hlo_verified_test_base_test.cc"], - deps = [ - ":hlo_test_base", - ":hlo_verified_test_base", - ":test_macros_cpu", - ":test_utils", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - tf_cc_binary( name = "local_client_aot_test_helper", srcs = ["local_client_aot_test_helper.cc"], @@ -335,6 +299,52 @@ xla_test( ], ) +xla_test( + name = "conv_depthwise_test", + 
timeout = "long", + srcs = ["conv_depthwise_test.cc"], + shard_count = 50, + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:bfloat16_normalization", + "//tensorflow/compiler/xla/service:despecializer", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:optional", + ], +) + +xla_test( + name = "grouped_convolution_test", + timeout = "long", + srcs = ["grouped_convolution_test.cc"], + blacklisted_backends = [ + # disabled because of a break b/119590850. + "gpu", + # disabled because it times out. + "cpu", + ], + shard_count = 50, + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:bfloat16_normalization", + "//tensorflow/compiler/xla/service:despecializer", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:optional", + ], +) + xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], @@ -476,7 +486,9 @@ xla_test( name = "params_test", srcs = ["params_test.cc"], shard_count = 30, - tags = ["optonly"], + tags = [ + "optonly", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", @@ -658,6 +670,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", 
"//tensorflow/core:lib", + "@com_google_absl//absl/base", "@com_google_absl//absl/types:span", ], ) @@ -682,6 +695,7 @@ xla_test( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/base", ], ) @@ -705,6 +719,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/base", "@com_google_absl//absl/strings", ], ) @@ -863,7 +878,8 @@ xla_test( name = "convolution_test", timeout = "long", srcs = ["convolution_test.cc"], - shard_count = 25, + shard_count = 40, + tags = ["optonly"], deps = CONVOLUTION_TEST_DEPS + [ "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1172,6 +1188,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -1296,6 +1313,7 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1614,6 +1632,7 @@ xla_test( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", ], ) @@ -1858,6 +1877,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/base", "@com_google_absl//absl/types:span", ], ) @@ -1894,6 +1914,7 @@ xla_test( xla_test( name = "multioutput_fusion_test", srcs = ["multioutput_fusion_test.cc"], + backends = ["gpu"], deps = [ "//tensorflow/compiler/xla:literal", 
"//tensorflow/compiler/xla:shape_util", @@ -2150,6 +2171,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_set", ], ) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index c257566fb218d4769aec0c793efb9256b023b7ea..915b456b52215f8d6a9eb6c5b933f3502f1d3d2c 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/base/casts.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -139,7 +139,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { } // A non-canonical quiet NaN value. 
-static const float kNonCanonicalNaN = tensorflow::bit_cast(0x7FD01234); +static const float kNonCanonicalNaN = absl::bit_cast(0x7FD01234); XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) { XlaBuilder builder(TestName()); @@ -329,13 +329,13 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { Literal b_literal = LiteralUtil::CreateR1({b_values}); std::unique_ptr b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie(); - auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param"); - auto b_param = ConstantR1(&builder, b_values); + auto b_param = Parameter(&builder, 1, a_literal.shape(), "b_param"); + auto b_constant = ConstantR1(&builder, b_values); - auto sum1 = Add(a_constant, b_constant); - auto sum2 = Add(a_constant, b_param); - auto sum3 = Add(a_param, b_constant); - auto sum4 = Add(a_param, b_param); + auto sum1 = Add(a_constant, b_param); + auto sum2 = Add(a_constant, b_constant); + auto sum3 = Add(a_param, b_param); + auto sum4 = Add(a_param, b_constant); auto sum = Add(sum1, sum2); sum = Add(sum, sum3); @@ -350,6 +350,44 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { error_spec_); } +// TODO(b/119692968): This test runs OOM on the GPU and CPU backend. +XLA_TEST_F(ArrayElementwiseOpTest, + DISABLED_ON_GPU(DISABLED_ON_CPU(DeeplyNestedAddWithSlices))) { + XlaBuilder builder(TestName()); + std::vector values(30, 0.0); + auto a_literal = LiteralUtil::CreateR1(values); + auto a = Parameter(&builder, 0, a_literal.shape(), "x"); + auto b_literal = LiteralUtil::CreateR1(values); + auto b = Parameter(&builder, 1, b_literal.shape(), "x"); + + // Construct a sequence of diamond-shaped gadgets like this: + // + // add + // / \ + // slice slice + // \ / + // add + // + // Each 'left' slice removes the last element, each 'right' slice removes the + // first element. In this way, we index into the add with different + // multi-dimensional index arrays, which defeats the caching we use to avoid + // exponential compile time. 
+ std::function generate_recursive = + [&](int64 slice_size) -> XlaOp { + if (slice_size == values.size()) { + return Add(a, b); + } + XlaOp param = generate_recursive(slice_size + 1); + auto slice1 = Slice(param, {0}, {slice_size}, {1}); + auto slice2 = Slice(param, {1}, {slice_size + 1}, {1}); + return Add(slice1, slice2); + }; + generate_recursive(1); + auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie(); + ComputeAndCompareR1(&builder, {0.0}, {a_data.get(), b_data.get()}); +} + XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); @@ -2478,8 +2516,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { Ne(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,2] { - { 00 }, - { 01 } + { 0, 0 }, + { 0, 1 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2492,8 +2530,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { Ge(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { - { 1100 }, - { 0001 } + { 1, 1, 0, 0 }, + { 0, 0, 0, 1 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2506,8 +2544,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { Gt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { - { 0100 }, - { 0000 } + { 0, 1, 0, 0 }, + { 0, 0, 0, 0 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2520,8 +2558,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { Le(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { - { 1011 }, - { 1111 } + { 1, 0, 1, 1 }, + { 1, 1, 1, 1 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2534,8 +2572,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { Lt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { - { 0011 }, - { 1110 
} + { 0, 0, 1, 1 }, + { 1, 1, 1, 0 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2744,12 +2782,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { Array3D expected_3d( {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}}); const string expected = R"(pred[2,3,2] { -{ { 01 }, - { 00 }, - { 00 } }, -{ { 01 }, - { 10 }, - { 01 } } +{ + { 0, 1 }, + { 0, 0 }, + { 0, 0 } +}, +{ + { 0, 1 }, + { 1, 0 }, + { 0, 1 } +} })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index dde19fb65d65064c9452a6ac49c70e20cf113336..702fb32adfc8a0ded26845c92245776a79777c34 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -161,8 +161,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) { XlaBuilder b(TestName()); - BroadcastInDim(ConstantR1(&b, {1, 2}), - ShapeUtil::MakeShape(F32, {2, 2}), {1}); + BroadcastInDim(ConstantR1(&b, {1, 2}), {2, 2}, {1}); Array2D expected(2, 2); expected(0, 0) = 1; @@ -175,8 +174,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) { XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) { XlaBuilder b(TestName()); - BroadcastInDim(ConstantR1(&b, {1, 2}), - ShapeUtil::MakeShape(F32, {2, 2}), {0}); + BroadcastInDim(ConstantR1(&b, {1, 2}), {2, 2}, {0}); Array2D expected(2, 2); expected(0, 0) = 1; @@ -189,8 +187,8 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) { XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) { XlaBuilder b(TestName()); - BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), - ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 1}); + BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2}, + {0, 1}); Array3D expected(2, 2, 2); expected(0, 0, 0) = 1.0; @@ -207,8 +205,8 @@ XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) { 
XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) { XlaBuilder b(TestName()); - BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), - ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 2}); + BroadcastInDim(ConstantR2(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2}, + {0, 2}); Array3D expected(2, 2, 2); expected(0, 0, 0) = 1.0; @@ -225,8 +223,7 @@ XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) { XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) { XlaBuilder b(TestName()); - BroadcastInDim(ConstantR1(&b, {1, 2}), - ShapeUtil::MakeShape(F32, {3, 2}), {1}); + BroadcastInDim(ConstantR1(&b, {1, 2}), {3, 2}, {1}); Array2D expected(3, 2); expected(0, 0) = 1; diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 9966e4606ef7f104487182e0240e64e4c9e4d834..9930bfc95c297093584d427397cac042c296050f 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -42,7 +42,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { ShapeUtil::MakeShape(F32, {}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -58,7 +58,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { ShapeUtil::MakeShape(F32, {2, 2}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -81,7 +81,7 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { builder.AddInstruction(HloInstruction::CreateTuple({element1, element2})); // Create HLO module, compile, and execute. 
- auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -102,7 +102,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -121,7 +121,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -138,7 +138,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -158,7 +158,7 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -183,7 +183,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3})); // Create HLO module, compile, and execute. 
- auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -214,7 +214,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -230,7 +230,7 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); LOG(INFO) << hlo_module->ToString(); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -253,7 +253,7 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3})); // Create HLO module, compile, and execute. - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -287,7 +287,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2})); // Create HLO module, compile, and execute. 
- auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index fbdf0fcb6543f09dedefef55cfe0f8a5d9067d5a..12c029983336cc9aed0fde4ce6881c9a00a9869e 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -107,7 +107,7 @@ StatusOr ClientLibraryTestBase::ExecuteAndTransfer( ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { *execution_options.mutable_shape_with_output_layout() = - *shape_with_output_layout; + shape_with_output_layout->ToProto(); } return client_->ExecuteAndTransfer(computation, arguments, &execution_options); @@ -127,7 +127,7 @@ StatusOr ClientLibraryTestBase::ExecuteAndTransferReference( ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { *execution_options.mutable_shape_with_output_layout() = - *shape_with_output_layout; + shape_with_output_layout->ToProto(); } execution_options.clear_device_handles(); return ref_client_->ExecuteAndTransfer(computation, arguments, @@ -262,6 +262,28 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return choose(0); } +StatusOr ClientLibraryTestBase::ComputeAndTransfer( + XlaBuilder* builder, absl::Span arguments_passed_in, + const Shape* shape_with_layout) { + std::vector arguments(arguments_passed_in.begin(), + arguments_passed_in.end()); + + // Transfer and use elements of arguments_, if the AddParam() API was used. 
+ std::vector> owning_arguments; + if (!arguments_.empty()) { + CHECK(arguments.empty()); + for (const auto& argument : arguments_) { + owning_arguments.push_back( + client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)) + .ValueOrDie()); + arguments.push_back(owning_arguments.back().get()); + } + } + + TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + return ExecuteAndTransfer(computation, arguments, shape_with_layout); +} + Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, absl::Span arguments_passed_in, diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 9d32f4f5174a57a53a9d3e6477b46fa4de852f7f..65a23dd883594b9bf9c37494a37e9be39b197788 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -76,7 +76,7 @@ class ClientLibraryTestBase : public ::testing::Test { void SetFastMathDisabled(bool disabled) { auto* opts = execution_options_.mutable_debug_options(); opts->set_xla_cpu_enable_fast_math(!disabled); - opts->set_xla_gpu_enable_fast_math(!disabled); + opts->set_xla_gpu_enable_fast_min_max(!disabled); } void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } @@ -188,6 +188,13 @@ class ClientLibraryTestBase : public ::testing::Test { ErrorSpec error, const Shape* shape_with_layout = nullptr); + // Build and run the computation and return the result as a literal. + // shape_with_layout indicates the result layout to request when calling + // Execute. + StatusOr ComputeAndTransfer( + XlaBuilder* builder, absl::Span arguments, + const Shape* shape_with_layout = nullptr); + // ComputeAndCompare variant which returns an error status. 
Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 6f2ca84bb646e88af221ab80b727911ff7d990eb..363dee74b2755a6bdc3c5a5164a85378581c21d2 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -50,7 +50,8 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, - execute_layout); + execute_layout) + .ToProto(); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr data, client_->Execute(computation, {}, &execution_options)); @@ -84,7 +85,8 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { {ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{0, 1}), ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, - /*minor_to_major=*/{1, 0})}); + /*minor_to_major=*/{1, 0})}) + .ToProto(); TF_ASSERT_OK_AND_ASSIGN( auto result, diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc index 022641394f113ef28e7c53058385d77572822213..fbebe0408730f2fb37aa57a0f19291bbaa3826f9 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.cc +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -32,11 +32,10 @@ StatusOr> CodegenTestBase::CompileToAotCompilationResult( std::unique_ptr hlo_module, const AotCompilationOptions& options) { - std::vector> hlo_modules; - hlo_modules.push_back(std::move(hlo_module)); + auto module_group = absl::make_unique(std::move(hlo_module)); TF_ASSIGN_OR_RETURN( std::vector> results, - backend().compiler()->CompileAheadOfTime(std::move(hlo_modules), + backend().compiler()->CompileAheadOfTime(std::move(module_group), options)); return std::move(results.front()); } diff --git 
a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 9811a015e91d866d6f4de6ebb6dac536ed6c7e06..4f5b525a34252db9e967a55af0d1bf39a2dd830e 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -492,6 +492,32 @@ XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { ComputeAndCompareR3(&builder, expected, {p0.get(), p1.get()}); } +XLA_TEST_F(ConcatTest, ConcatDeeplyNested) { + XlaBuilder builder(TestName()); + auto a_literal = LiteralUtil::CreateR1({256.0}); + auto a = Parameter(&builder, 0, a_literal.shape(), "x"); + auto b = ConcatInDim(&builder, {a, a}, 0); + auto c = ConcatInDim(&builder, {b, b}, 0); + auto d = ConcatInDim(&builder, {c, c}, 0); + auto e = ConcatInDim(&builder, {d, d}, 0); + auto f = ConcatInDim(&builder, {e, e}, 0); + auto g = ConcatInDim(&builder, {f, f}, 0); + auto h = ConcatInDim(&builder, {g, g}, 0); + auto i = ConcatInDim(&builder, {h, h}, 0); + auto j = ConcatInDim(&builder, {i, i}, 0); + auto k = ConcatInDim(&builder, {j, j}, 0); + auto l = ConcatInDim(&builder, {k, k}, 0); + auto m = ConcatInDim(&builder, {l, l}, 0); + auto n = ConcatInDim(&builder, {m, m}, 0); + auto o = ConcatInDim(&builder, {n, n}, 0); + auto p = ConcatInDim(&builder, {o, o}, 0); + auto q = ConcatInDim(&builder, {p, p}, 0); + ConcatInDim(&builder, {q, q}, 0); + std::vector expected(131072, 256.0); + auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); + ComputeAndCompareR1(&builder, expected, {a_data.get()}); +} + // Describes a binary rank-2 concatenation test. struct R2BinarySpec { int64 lhs_dim0; diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..627a17a0ca114085240dbaf28211bb3511cf0cab --- /dev/null +++ b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc @@ -0,0 +1,234 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +string GetFloatDataType(bool use_bfloat16) { + return use_bfloat16 ? 
"bf16" : "f32"; +} + +struct DepthwiseConvolution2DSpec { + int64 output_feature, window, stride, pad, lhs_dilate; + std::vector activation_dims; + std::vector activation_layout; + std::vector kernel_dims; + std::vector kernel_layout; + std::vector output_dims; + std::vector output_layout; +}; + +class DepthwiseConvolution2DTest + : public HloTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> {}; + +static std::vector GetConv2DTestCases() { + std::vector config_set; + std::vector> config_options = { + {128, 6, 3, 64}, {256, 5, 3, 256}, {256, 5, 2, 144}, {144, 5, 3, 64}, + {144, 5, 2, 256}, {8, 48, 17, 8}, {128, 20, 6, 64}, {64, 14, 12, 172}, + {16, 9, 4, 16}, {128, 1, 2, 144}, {256, 1, 2, 64}}; + + for (auto option : config_options) { + int64 feature = option[0]; + int64 activation_size = option[1]; + int64 kernel_size = option[2]; + int64 batch = option[3]; + + std::vector kernel_layout = {3, 2, 1, 0}; + DepthwiseConvolution2DSpec config; + config.output_feature = feature; + config.window = kernel_size; + + config.activation_dims = {batch, activation_size, activation_size, feature}; + config.activation_layout = {3, 0, 2, 1}; + + config.kernel_dims = {kernel_size, kernel_size, 1, feature}; + config.kernel_layout = {3, 2, 1, 0}; + + if (activation_size == 1 && kernel_size == 2) { + // Test for outer dim. + config.output_dims = {batch, activation_size + kernel_size - 1, + activation_size + kernel_size, feature}; + } else if (feature == 256) { + // Restrict dilation-based tests only to one feature configuration. + config.stride = activation_size - 1; + config.pad = 0; + config.lhs_dilate = feature / 32; + config.output_dims = {batch, feature / 32, + activation_size - kernel_size + 1, feature}; + } else { + config.stride = config.pad = config.lhs_dilate = -1; + config.output_dims = {batch, activation_size - kernel_size + 1, + activation_size - kernel_size + 1, feature}; + } + + // Try this layout for all kernel shapes. 
+ config.output_layout = {3, 0, 2, 1}; + config_set.push_back(config); + + // Try other layouts only for certain kernel shapes. + if (kernel_size % 2 == 0) { + config.activation_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.output_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.activation_layout = {3, 0, 2, 1}; + config_set.push_back(config); + } + } + + return config_set; +} + +string DepthwiseConvolution2DTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& spec = ::testing::get<0>(data.param); + const string data_type = GetFloatDataType(::testing::get<1>(data.param)); + string str = absl::StrCat( + "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), + "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"), + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_", + absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), "_output_layout_", + absl::StrJoin(spec.output_layout, "_"), data_type); + // -1 indicates non-existence. + if (spec.stride != -1) { + absl::StrAppend(&str, "_lhs_dilation_", spec.lhs_dilate, "x1"); + } + + // Test names are not allowed to contain the '-' character. 
+ absl::c_replace(str, '-', 'n'); + return str; +} + +string BuildHloTextDepthwiseConvolution2D( + const DepthwiseConvolution2DSpec& spec, bool use_bfloat16) { + const string data_type = GetFloatDataType(use_bfloat16); + if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d pad=1_1x%d_%d rhs_dilate=1x%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.window, spec.window, spec.window, spec.output_feature); + + } else if (spec.stride == -1) { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + 
absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.output_feature); + } else { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d stride=%dx1 pad=%d_%dx0_0 lhs_dilate=%dx1}, + dim_labels=b01f_01io->b01f, feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.stride, 0, 0, spec.lhs_dilate, spec.output_feature); + } +} + +XLA_TEST_P(DepthwiseConvolution2DTest, DoIt) { + const DepthwiseConvolution2DSpec& spec = ::testing::get<0>(GetParam()); + bool use_bfloat16 = ::testing::get<1>(GetParam()); + const string hlo_text = + BuildHloTextDepthwiseConvolution2D(spec, use_bfloat16); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, + [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); +} + +INSTANTIATE_TEST_CASE_P( + DepthwiseConvolution2DTestWithRandomIndices, DepthwiseConvolution2DTest, + ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), + ::testing::Bool()), + DepthwiseConvolution2DTestDataToString); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc 
index 5f063e67847487f1d18bf4ee80b1634ebdf4183a..20bf3c317986c30c12dca7dca14dbf80c70b42f6 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/base/casts.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" @@ -429,11 +429,9 @@ TEST_F(ConvertTest, ConvertReshape) { std::vector GetInterestingF16ConversionTestCases() { float infinity = std::numeric_limits::infinity(); - float half_min_positive_normal = - tensorflow::bit_cast(0x38800000); - float half_max_subnormal = tensorflow::bit_cast(0x387fc000); - float half_min_positive_subnormal = - tensorflow::bit_cast(0x33800000); + float half_min_positive_normal = absl::bit_cast(0x38800000); + float half_max_subnormal = absl::bit_cast(0x387fc000); + float half_min_positive_subnormal = absl::bit_cast(0x33800000); float half_max = 65504.0f; std::vector test_cases( diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 070b092d18930027e215cb43ff917e36cac99f12..4a58a1ed66c438d1dd9561f4eb029b38d8c6cbdd 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -91,7 +91,14 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { XlaBuilder builder(TestName()); auto lhs = 
ConstantR4FromArray4D(&builder, *alhs); auto rhs = ConstantR4FromArray4D(&builder, *arhs); - Conv(lhs, rhs, {1, 1}, Padding::kValid); + PrecisionConfig precision; + // The left hand side of the convolution is numbers between 0 and 2304 which + // requires at least 11 mantissa bits and the DEFAULT precision config is + // allowed to round to bfloat16 which only has 7 mantissa bits. + precision.add_operand_precision(PrecisionConfig::HIGHEST); + precision.add_operand_precision(PrecisionConfig::DEFAULT); + Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1, + &precision); ComputeAndCompare(&builder, {}, error_spec_); } @@ -590,7 +597,692 @@ TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, Types) { } template -class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { +class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 5}; + std::vector filter_dims = {3, 3, 1, 5}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/5); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(6864), static_cast(7296), static_cast(7746), + static_cast(8214), static_cast(8700), static_cast(7809), + static_cast(8286), static_cast(8781), static_cast(9294), + static_cast(9825), static_cast(10644), static_cast(11256), + static_cast(11886), static_cast(12534), static_cast(13200), + static_cast(11589), static_cast(12246), static_cast(12921), + static_cast(13614), static_cast(14325)}); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + + auto filter_r = filter_r1.Reshape(filter_dims); + } +}; + 
+TYPED_TEST_CASE(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).ConsumeValueOrDie(); + + auto 
input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 512}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE( + Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x512_3x3x1x512_Depthwise_Valid_Output_Batch_In_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes + : public 
ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {256, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048 * 256, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = + expected_r1.Reshape({256, 2, 2, 512}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + 
client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Input_Batch_in_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {256, 4, 4, 512}; + std::vector filter_dims = {3, 3, 1, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/512); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + 
auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(2048 * 256, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = + expected_r1.Reshape({256, 2, 2, 512}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_256x4x4x512_3x3x1x512_Depthwise_Both_Batch_in_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 5}; + std::vector filter_dims = {3, 3, 1, 5}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/5); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(6864), static_cast(7296), static_cast(7746), + static_cast(8214), static_cast(8700), static_cast(7809), + static_cast(8286), static_cast(8781), static_cast(9294), + static_cast(9825), static_cast(10644), static_cast(11256), + static_cast(11886), static_cast(12534), static_cast(13200), + static_cast(11589), static_cast(12246), static_cast(12921), + static_cast(13614), static_cast(14325)}); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 5}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + 
ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE( + Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x5_3x3x1x5_Depthwise_Valid_Output_Batch_In_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 160}; + std::vector filter_dims = {3, 3, 1, 160}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/160); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + 
std::vector output_elems(640, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 160}; + std::vector filter_dims = {3, 3, 1, 160}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/160); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(640, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({3, 0, 2, 1})); + + auto input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Depthwise_Input_Batch_In_Lanes, + Types) { + this->RunTest(); +} + +template +class 
Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 160}; + std::vector filter_dims = {3, 3, 1, 160}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/160); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4_relaid = + input_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(640, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 160}).ConsumeValueOrDie(); + auto expected_r4_relaid = + expected_r4.Relayout(LayoutUtil::MakeLayout({0, 3, 2, 1})); + + auto 
input_literal = + client_->TransferToServer(input_r4_relaid).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4_relaid, + {input_literal.get(), filter_literal.get()}, + error_spec_, &expected_r4_relaid.shape()); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes, + TestTypes); +TYPED_TEST(Convolve2D_1x4x4x160_3x3x1x160_Dephtwise_Both_Batch_In_Lanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 4, 4, 1024}; + std::vector filter_dims = {3, 3, 1, 1024}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/1024); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(4096, static_cast(18)); + + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 2, 2, 1024}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x4x4x1024_3x3x1x1024_Depthwise_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid : public ConvolutionTest { public: void RunTest() { XlaBuilder builder(TestName()); @@ -618,7 +1310,200 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { 
dnums.set_kernel_output_feature_dimension(3); ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, - /*feature_group_count=*/3); + /*feature_group_count=*/3); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(5076), static_cast(5160), static_cast(5244), + static_cast(5328), static_cast(6164), static_cast(6264), + static_cast(6364), static_cast(6464), static_cast(7380), + static_cast(7496), static_cast(7612), static_cast(7728)}); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 2, 2, 1024}; + std::vector filter_dims = {2, 2, 128, 512}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers 
for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/8); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(512, static_cast(1024)); + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 512}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 2, 2, 1024}; + std::vector filter_dims = {2, 2, 128, 8}; + Shape 
input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/8); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1)); + + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(2)); + + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector output_elems(8, static_cast(1024)); + auto expected_r1 = LiteralUtil::CreateR1(output_elems); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 8}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x8_Grouped_Valid, Types) 
{ + this->RunTest(); +} + +template +class Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 2, 2, 12}; + std::vector filter_dims = {2, 2, 3, 4}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/4); } std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); @@ -631,12 +1516,140 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { auto filter_r1 = LiteralUtil::CreateR1(filter_elems); auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + auto expected_r1 = + LiteralUtil::CreateR1({static_cast(7712), static_cast(8816), + static_cast(9992), static_cast(11240)}); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; 
+ +TYPED_TEST_CASE(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid, Types) { + this->RunTest(); +} + +template +class Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes + : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 2, 2, 12}; + std::vector filter_dims = {2, 2, 4, 3}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(3); + dnums.set_kernel_output_feature_dimension(2); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/4); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4_relaid = + filter_r4.Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); auto expected_r1 = LiteralUtil::CreateR1( - {static_cast(5076), 
static_cast(5160), static_cast(5244), - static_cast(5328), static_cast(6164), static_cast(6264), - static_cast(6364), static_cast(6464), static_cast(7380), - static_cast(7496), static_cast(7612), static_cast(7728)}); - auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); + {static_cast(6968), static_cast(8516), static_cast(10280), + static_cast(12260)}); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r4_relaid).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes, + TestTypes); +TYPED_TEST(Convolve2D_1x2x2x12_2x2x3x4_Grouped_Valid_Filter_OF_In_Sublanes, + Types) { + this->RunTest(); +} + +template +class Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 1, 1, 12}; + std::vector filter_dims = {1, 1, 3, 4}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/4); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = + LiteralUtil::CreateR1({static_cast(38), static_cast(98), + static_cast(176), static_cast(272)}); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 4}).ConsumeValueOrDie(); auto input_literal = client_->TransferToServer(input_r4).ConsumeValueOrDie(); @@ -649,8 +1662,8 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { } }; -TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, TestTypes); -TYPED_TEST(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, Types) { +TYPED_TEST_CASE(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x1x1x12_1x1x3x4_Grouped_Valid, Types) { this->RunTest(); } @@ -876,7 +1889,7 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { // (We run this test on all platforms, because, what the heck.) 
XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( - "cudnn-convolution-algorithm-picker"); + "cudnn-conv-algorithm-picker"); XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); @@ -944,6 +1957,18 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); } +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF32ForwardReversed)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %arg0 = f32[3,56,56,16] parameter(0) + %arg1 = f32[3,3,3,32] parameter(1) + ROOT %conv = f32[54,54,16,32] convolution(%arg0, %arg1), window={size=3x3 rhs_reversal=1x1}, dim_labels=f01b_i01o->01bf +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardFilter)) { constexpr char kHlo[] = R"( HloModule TestModule diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 1407e68d9a336b6bb1c960711015430f872aa912..3622f2c1e84639baed13059b21b20609d1347da6 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -45,7 +45,7 @@ class CopyOpTest : public HloTestBase { builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopy, constant)); auto computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {}); @@ -98,7 +98,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { auto computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {&literal}); @@ -119,7 +119,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { auto computation = 
builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, result, @@ -143,7 +143,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { std::unique_ptr computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); Literal result = ExecuteAndTransfer(std::move(module), {}); @@ -175,7 +175,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { std::unique_ptr computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0})); Literal result = ExecuteAndTransfer(std::move(module), {}); @@ -209,7 +209,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, std::unique_ptr computation = builder.Build(); - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation)); Literal result = ExecuteAndTransfer(std::move(module), {}); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index a693fa35954bcb2d95074c94d0aa3eabc1d5fd62..738b6442354b01364278e3e3c713aa2cdb5cf47d 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -70,7 +70,7 @@ class CustomCallTest : public HloTestBase { }; XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto 
constant = builder.AddInstruction( @@ -85,7 +85,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { - auto module = CreateNewModule(); + auto module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); Array2D array(2, 2); @@ -105,9 +105,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); } -XLA_TEST_F(CustomCallTest, - DISABLED_ON_GPU(CustomCall_UsedInOtherComputations)) { - auto module = CreateNewModule(); +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) { + auto module = CreateNewUnverifiedModule(); auto b = HloComputation::Builder(TestName()); auto input = b.AddInstruction( @@ -130,6 +129,53 @@ XLA_TEST_F(CustomCallTest, Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); } +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) { + auto module = CreateNewUnverifiedModule(); + auto b = HloComputation::Builder(TestName()); + + auto input = + b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); + b.AddInstruction( + HloInstruction::CreateCustomCall(r2f32_, {input}, "Add1ToValues")); + + module->AddEntryComputation(b.Build()); + ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); + ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1})); + + Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); + + // Note, the expected result is transposed! This is because the input and + // output layouts of the custom call differ and the called function just + // blindly adds one to each element. 
+ Literal result = ExecuteAndTransfer(std::move(module), {&argument}); + LiteralTestUtil::ExpectR2Equal({{2.f, 4.f}, {3.f, 5.f}}, result); +} + +XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) { + // The argument and result of the computation are set to different layouts, + // but the custom call is layout constrained to a fixed operand and result + // layout, so the correct result should be produced. + auto module = CreateNewUnverifiedModule(); + auto b = HloComputation::Builder(TestName()); + + auto input = + b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p")); + + const Shape& r2f32_dim0_major = + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); + b.AddInstruction(HloInstruction::CreateCustomCall( + r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major})); + + module->AddEntryComputation(b.Build()); + ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0})); + ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1})); + + Literal argument = LiteralUtil::CreateR2({{1.f, 2.f}, {3.f, 4.f}}); + + Literal result = ExecuteAndTransfer(std::move(module), {&argument}); + LiteralTestUtil::ExpectR2Equal({{2.f, 3.f}, {4.f, 5.f}}, result); +} + class CustomCallClientAPITest : public ClientLibraryTestBase {}; // When using the client API, CustomCall targets can't begin with '$' -- these diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 6c0847a875798870b4362a99ac2ab65d99f9f3e6..c5d8b663f4abe77e05ec213d2e4e075c260a8655 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -30,7 +30,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" namespace xla { namespace { @@ -637,6 +636,76 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { {x_data.get(), y_data.get()}, this->error_spec_); } +#ifndef XLA_TEST_BACKEND_CPU +// TODO(b/74459949): failed on CPU on 2018-10-29. +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR3LhsR2Rhs) { + using T = TypeParam; + + XlaBuilder builder(this->TestName()); + auto x = + Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2, 2}), "x"); + auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2}), "y"); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_rhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + + DotGeneral(x, y, dnums); + + auto x_data = + this->client_ + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) + .ConsumeValueOrDie(); + + auto y_data = this->client_ + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( + {{1.0f, 0.0f}, {0.0f, 1.0f}})) + .ConsumeValueOrDie(); + + this->template ComputeAndCompareR2( + &builder, + /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()}, + this->error_spec_); +} + +// TODO(b/74459949): failed on CPU on 2018-10-29. 
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR2LhsR3Rhs) { + using T = TypeParam; + + XlaBuilder builder(this->TestName()); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType({2, 2}), "x"); + auto y = + Parameter(&builder, 1, ShapeUtil::MakeShapeWithType({2, 2, 2}), "y"); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(1); + dnums.add_rhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + + DotGeneral(x, y, dnums); + + auto x_data = this->client_ + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( + {{1.0f, 0.0f}, {0.0f, 1.0f}})) + .ConsumeValueOrDie(); + + auto y_data = + this->client_ + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) + .ConsumeValueOrDie(); + + this->template ComputeAndCompareR2( + &builder, + /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()}, + this->error_spec_); +} +#endif // XLA_TEST_BACKEND_CPU + XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { using T = TypeParam; diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index 51b50d456e496c9c01c38fb8539bb3737de16937..c84973e17b234c24c84f02a369ce0185f5772cca 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "absl/base/casts.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/core/casts.h" namespace xla { namespace { @@ -47,7 +47,7 @@ class ExhaustiveF32ElementwiseOpTest // input to 0 under the assumption that the op is at least correct on 0. input_literal.Set({i - begin}, 0.0f); } else { - input_literal.Set({i - begin}, tensorflow::bit_cast(i)); + input_literal.Set({i - begin}, absl::bit_cast(i)); } } diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 4d4b676a538947c8dd92a7e34db72e45766cae2c..d1fddf9d6b494a822610e41307fa103dc90bdef3 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -81,7 +81,7 @@ class FusionTest : public HloTestBase { } auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto prim_type = primitive_util::NativeToPrimitiveType(); @@ -183,7 +183,7 @@ XLA_TEST_F(FusionTest, Test) { // (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}), // {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}} auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1.0}, {2.0}, {3.0}}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -231,7 +231,7 @@ XLA_TEST_F(FusionTest, Parameter) { // Build a computation and fuse part of it so the fusion instruction has an // operand parameter. 
auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1.0, 2.0, 3.0}}))); auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -266,7 +266,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0}); // Build simple fusion computation: y = x^2 (elementwise). auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto two = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); @@ -290,7 +290,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( @@ -314,7 +314,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { XLA_TEST_F(FusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto single_element_array = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR2({{5}}))); auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( @@ -329,7 +329,7 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2}, 
{3, 4}, {5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( @@ -344,7 +344,7 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); auto reshape1 = builder.AddInstruction( @@ -359,7 +359,7 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { XLA_TEST_F(FusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR3({{{7}}}))); auto reshape1 = builder.AddInstruction( @@ -374,7 +374,7 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { XLA_TEST_F(FusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( @@ -389,7 +389,7 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { XLA_TEST_F(FusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); auto reshape1 = builder.AddInstruction( @@ -404,7 +404,7 @@ XLA_TEST_F(FusionTest, Reshape__) { XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 
LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction( @@ -419,7 +419,7 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { XLA_TEST_F(FusionTest, Transpose_2by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -434,7 +434,7 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { XLA_TEST_F(FusionTest, Transpose_3by3) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -449,7 +449,7 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { XLA_TEST_F(FusionTest, Reverse) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( @@ -465,7 +465,7 @@ XLA_TEST_F(FusionTest, Reverse) { XLA_TEST_F(FusionTest, ReverseNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( @@ -483,7 +483,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) { XLA_TEST_F(FusionTest, BroadcastNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto 
hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -501,7 +501,7 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { XLA_TEST_F(FusionTest, SliceNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice( @@ -519,7 +519,7 @@ XLA_TEST_F(FusionTest, SliceNegate) { XLA_TEST_F(FusionTest, DynamicSliceNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto const1 = builder.AddInstruction( @@ -541,7 +541,7 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { XLA_TEST_F(FusionTest, ReshapeNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1, 2, 3, 4}))); auto reshape1 = builder.AddInstruction( @@ -559,7 +559,7 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { XLA_TEST_F(FusionTest, TransposeNegate) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{1, 2}, {3, 4}}))); auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -587,7 +587,7 @@ std::unique_ptr MakeReduceTestComputation() { } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { - auto hlo_module = CreateNewModule(); + 
auto hlo_module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -607,7 +607,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -630,7 +630,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); auto const1 = builder.AddInstruction( @@ -682,7 +682,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { // into a fusion, it should remain shared, rather than being duplicated // within the fusion. XLA_TEST_F(FusionTest, SharedConstant) { - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/tests/grouped_convolution_test.cc b/tensorflow/compiler/xla/tests/grouped_convolution_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8f7049910e70c4e591636a47c1b6ba72cf2c234f --- /dev/null +++ b/tensorflow/compiler/xla/tests/grouped_convolution_test.cc @@ -0,0 +1,245 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +string GetFloatDataType(bool use_bfloat16) { + return use_bfloat16 ? "bf16" : "f32"; +} + +struct GroupedConvolution2DSpec { + int64 input_feature, output_feature, window, stride, pad, lhs_dilate; + int64 group_size, group_count; + std::vector activation_dims; + std::vector activation_layout; + std::vector kernel_dims; + std::vector kernel_layout; + std::vector output_dims; + std::vector output_layout; +}; + +class GroupedConvolution2DTest + : public HloTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> {}; + +static std::vector GetConv2DTestCases() { + std::vector config_set; + // Add to this set if you want a new test configuration. + // Rule : the penultimate number must be divisible by the last number. 
+ std::vector> config_options = {{8, 2, 2, 1, 1024, 128}, + {512, 3, 3, 144, 1024, 16}, + {256, 3, 3, 129, 512, 64}, + {64, 1, 2, 127, 32, 8}, + {256, 3, 3, 256, 1024, 4}}; + + for (auto option : config_options) { + int64 output_feature = option[0]; + int64 activation_size = option[1]; + int64 kernel_size = option[2]; + int64 batch = option[3]; + int64 input_feature = option[4]; + int64 group_size = option[5]; + + std::vector kernel_layout = {3, 2, 1, 0}; + GroupedConvolution2DSpec config; + config.group_size = group_size; + config.group_count = input_feature / group_size; + config.output_feature = output_feature; + config.window = kernel_size; + + config.activation_dims = {batch, activation_size, activation_size, + input_feature}; + config.activation_layout = {3, 0, 2, 1}; + + config.kernel_dims = {kernel_size, kernel_size, group_size, output_feature}; + config.kernel_layout = {3, 2, 1, 0}; + + if (activation_size == 1 && kernel_size == 2) { + // Test for outer dim. + config.output_dims = {batch, activation_size + kernel_size - 1, + activation_size + kernel_size, output_feature}; + } else if (output_feature == 256) { + // Restrict dilation-based tests only to one feature configuration. + config.stride = activation_size - 1; + config.pad = 0; + config.lhs_dilate = output_feature / 32; + config.output_dims = {batch, output_feature / 32, + activation_size - kernel_size + 1, output_feature}; + } else { + config.stride = config.pad = config.lhs_dilate = -1; + config.output_dims = {batch, activation_size - kernel_size + 1, + activation_size - kernel_size + 1, output_feature}; + } + + // Try this layout for all kernel shapes. + config.output_layout = {3, 0, 2, 1}; + config_set.push_back(config); + + // Try other layouts only for certain kernel shapes. 
+ if (kernel_size % 2 == 0) { + config.activation_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.output_layout = {0, 3, 2, 1}; + config_set.push_back(config); + + config.activation_layout = {3, 0, 2, 1}; + config_set.push_back(config); + } + } + + return config_set; +} + +string GroupedConvolution2DTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& spec = ::testing::get<0>(data.param); + const string data_type = GetFloatDataType(::testing::get<1>(data.param)); + string str = absl::StrCat( + "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), + "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"), + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_", + absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), "_output_layout_", + absl::StrJoin(spec.output_layout, "_"), data_type); + // -1 indicates non-existence. + if (spec.stride != -1) { + absl::StrAppend(&str, "_lhs_dilation_", spec.lhs_dilate, "x1"); + } + + // Test names are not allowed to contain the '-' character. + absl::c_replace(str, '-', 'n'); + return str; +} + +string BuildHloTextGroupedConvolution2D(const GroupedConvolution2DSpec& spec, + bool use_bfloat16) { + const string data_type = GetFloatDataType(use_bfloat16); + if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) { + // Check for outer dim. 
+ return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d pad=1_1x%d_%d rhs_dilate=1x%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.window, spec.window, spec.window, spec.group_count); + + } else if (spec.stride == -1) { + // Check for basic, non-dilated cases. + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.group_count); + } else { + // Check for base dilations. 
+ return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d stride=%dx1 pad=%d_%dx0_0 lhs_dilate=%dx1}, + dim_labels=b01f_01io->b01f, feature_group_count=%d + } + )", + data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.stride, 0, 0, spec.lhs_dilate, spec.group_count); + } +} + +XLA_TEST_P(GroupedConvolution2DTest, DoIt) { + const GroupedConvolution2DSpec& spec = ::testing::get<0>(GetParam()); + bool use_bfloat16 = ::testing::get<1>(GetParam()); + const string hlo_text = BuildHloTextGroupedConvolution2D(spec, use_bfloat16); + + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.01, 0.01}, + [](HloModule* module) -> Status { + BFloat16MixedPrecisionRemoval remover; + TF_RETURN_IF_ERROR(remover.Run(module).status()); + Despecializer despecializer; + return despecializer.Run(module).status(); + })); +} + +INSTANTIATE_TEST_CASE_P( + GroupedConvolution2DTestWithRandomIndices, GroupedConvolution2DTest, + ::testing::Combine(::testing::ValuesIn(GetConv2DTestCases()), + ::testing::Bool()), + GroupedConvolution2DTestDataToString); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index bdd4fd7e3d0f585d81e94a3326e6d24bb5c42f39..989a7c705a8254f99e5cc0e97dfde5942f146964 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ 
b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -23,8 +23,8 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -85,26 +85,74 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace +Status VerifiedHloModule::Verify() { + if (computation_count() == 0) { + // The computation was never built. Nothing to verify. + return Status::OK(); + } + return verifier_.Run(this).status(); +} + +void VerifiedHloModule::VerifyOrAddFailure(const string& message) { + Status status = Verify(); + if (!status.ok()) { + ADD_FAILURE() << "HloVerifier failed on module " << name() + << (message.empty() ? 
"" : absl::StrCat(" (", message, ")")) + << ": " << status; + LOG(ERROR) << "Contents of bad module:"; + XLA_LOG_LINES(tensorflow::ERROR, ToString()); + } +} + HloTestBase::HloTestBase(bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function + instruction_can_change_layout_func) : HloTestBase(GetTestPlatform(), GetReferencePlatform(), verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} + allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func) {} HloTestBase::HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) - : test_runner_(test_platform), reference_runner_(reference_platform) { + bool allow_mixed_precision_in_hlo_verifier, + std::function + instruction_can_change_layout_func) + : test_runner_(test_platform), + reference_runner_(reference_platform), + verifier_layout_sensitive_(verifier_layout_sensitive), + allow_mixed_precision_in_hlo_verifier_( + allow_mixed_precision_in_hlo_verifier) { hlo_verifier_ = absl::make_unique( /*layout_sensitive=*/verifier_layout_sensitive, - /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier); + /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func); } -std::unique_ptr HloTestBase::CreateNewModule(const string& name) { +std::unique_ptr HloTestBase::CreateNewUnverifiedModule( + const string& name) { return absl::make_unique(name, GetModuleConfigForTest()); } +std::unique_ptr HloTestBase::CreateNewVerifiedModule( + const string& name) { + return absl::make_unique( + name, GetModuleConfigForTest(), verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_); +} + +StatusOr> +HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config) { + auto module = absl::make_unique( + TestName(), config, 
verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_); + TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); + TF_RETURN_IF_ERROR(module->Verify()); + return std::move(module); +} + /* static */ StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, HloModule* module) { @@ -129,7 +177,7 @@ PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) { } DebugOptions HloTestBase::GetDebugOptionsForTest() { - auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + auto debug_options = GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); debug_options.set_xla_gpu_max_kernel_unroll_factor(1); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 0ae4bdc104d656946d45008adec9ea3960984545..1d1e7f437296a7493ef7da07039fcf6d273f35bc 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/base/macros.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -38,6 +39,31 @@ limitations under the License. namespace xla { +// An HLO module derived class which verifies itself on destruction. This class +// is intended to be used in unit tests. Any verification errors are raised via +// ADD_FAILURE. +class VerifiedHloModule : public HloModule { + public: + VerifiedHloModule(const string& name, const HloModuleConfig& config, + bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) + : HloModule(name, config), + verifier_(verifier_layout_sensitive, + allow_mixed_precision_in_hlo_verifier) {} + + ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } + + // Verifies the module using HloVerifier and returns the status. 
+ Status Verify(); + + // Verifies the module and flags any error with ADD_FAILURE. 'message' is + // included in the failure message. + void VerifyOrAddFailure(const string& message); + + private: + HloVerifier verifier_; +}; + // A base class for tests which build and/or run HLO code. The class includes // support for running an HLO module on two platforms and compare the results. // This is a lower level of abstraction than using the client interface and @@ -72,7 +98,22 @@ class HloTestBase : public ::testing::Test { // options from command-line flags. If you want a fresh HloModule object and // then add HloComputations to it, it's recommended to use this method in your // tests. - std::unique_ptr CreateNewModule(const string& name = TestName()); + // + // This returns a vanilla HloModule that doesn't run the HLO verifier on + // destruction. + ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.") + std::unique_ptr CreateNewUnverifiedModule( + const string& name = TestName()); + + // Like CreateNewUnverifiedModule, except the HloModule returned here runs the + // HLO verifier on destruction. + std::unique_ptr CreateNewVerifiedModule( + const string& name = TestName()); + + // Parses the given string and returns module as a VerifiedHloModule. + StatusOr> ParseAndReturnVerifiedModule( + absl::string_view hlo_text, + const HloModuleConfig& config = HloModuleConfig()); // Runs the hlo_pass with the provided module and returns the result. This // function also verifies that the module remains unchanged when hlo_pass @@ -88,14 +129,18 @@ class HloTestBase : public ::testing::Test { // interpreter is the only supported backend, it will be both the test backend // and the reference backend. 
HloTestBase(bool verifier_layout_sensitive = false, - bool allow_mixed_precision_in_hlo_verifier = true); + bool allow_mixed_precision_in_hlo_verifier = true, + std::function + instruction_can_change_layout_func = {}); // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, bool verifier_layout_sensitive = false, - bool allow_mixed_precision_in_hlo_verifier = true); + bool allow_mixed_precision_in_hlo_verifier = true, + std::function + instruction_can_change_layout_func = {}); ~HloTestBase() override {} @@ -243,6 +288,8 @@ class HloTestBase : public ::testing::Test { HloRunner test_runner_; HloRunner reference_runner_; + bool verifier_layout_sensitive_; + bool allow_mixed_precision_in_hlo_verifier_; std::unique_ptr hlo_verifier_; ErrorSpec error_spec_{0.0001}; diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc deleted file mode 100644 index 8bd0a729b77f3ec14204952cb0062103c823883e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" - -#include "absl/memory/memory.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { - -Status VerifiedHloModule::Verify() { - if (computation_count() == 0) { - // The computation was never built. Nothing to verify. - return Status::OK(); - } - return verifier_.Run(this).status(); -} - -void VerifiedHloModule::VerifyOrAddFailure(const string& message) { - Status status = Verify(); - if (!status.ok()) { - ADD_FAILURE() << "HloVerifier failed on module " << name() - << (message.empty() ? "" : absl::StrCat(" (", message, ")")) - << ": " << status; - } -} - -HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, - bool allow_mixed_precision) - : HloTestBase( - /*verifier_layout_sensitive=*/layout_sensitive, - /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision), - verifier_layout_sensitive_(layout_sensitive), - allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {} - -HloModule& HloVerifiedTestBase::module() { - if (!module_) { - module_ = CreateNewVerifiedModule(TestName()); - } - return *module_; -} - -HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { - modules_.emplace_back(CreateNewVerifiedModule(name)); - return modules_.back().get(); -} - -void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text, - const HloModuleConfig& config) { - CHECK(!module_) << "Called ParseModule when test already has a module."; - module_ = CreateNewVerifiedModule(TestName()); - TF_CHECK_OK(ParseHloString(hlo_text, module_.get())); - module_->VerifyOrAddFailure("after parsing"); -} - -StatusOr> 
-HloVerifiedTestBase::ParseAndReturnVerifiedModule( - absl::string_view hlo_text, const HloModuleConfig& config) { - auto module = CreateNewVerifiedModule(TestName()); - TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); - TF_RETURN_IF_ERROR(module->Verify()); - return std::move(module); -} - -std::unique_ptr HloVerifiedTestBase::CreateNewVerifiedModule( - const string& name) { - return absl::make_unique( - name, GetModuleConfigForTest(), verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h deleted file mode 100644 index 388a99bb36408665edbc20ade6c6a733d64db88d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ /dev/null @@ -1,105 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ -#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ - -#include -#include -#include - -#include "absl/base/macros.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" - -namespace xla { - -// An HLO module derived class which verifies itself on destruction. 
This class -// is intended to be used in unit tests. Any verification errors are raised via -// ADD_FAILURE. -class VerifiedHloModule : public HloModule { - public: - VerifiedHloModule(const string& name, const HloModuleConfig& config, - bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) - : HloModule(name, config), - verifier_(verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} - - ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } - - // Verifies the module using HloVerifier and returns the status. - Status Verify(); - - // Verifies the module and flags any error with ADD_FAILURE. 'message' is - // included in the failure message. - void VerifyOrAddFailure(const string& message); - - private: - HloVerifier verifier_; -}; - -// A base class for HLO tests that stores a default VerifiedHloModule. -class HloVerifiedTestBase : public HloTestBase { - protected: - HloVerifiedTestBase(bool layout_sensitive = false, - bool allow_mixed_precision = false); - - // Constructs a default shape verifier. - std::unique_ptr MakeShapeVerifier(); - - // Returns the default HloModule, lazily creating it if necessary via - // HloTestBase::CreateNewModule(). - ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") - HloModule& module(); - - ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.") - void ParseAndVerifyModule(absl::string_view hlo_text, - const HloModuleConfig& config = HloModuleConfig()); - - // Parses the given string and returns module as a VerifiedHloModule. - StatusOr> ParseAndReturnVerifiedModule( - absl::string_view hlo_text, - const HloModuleConfig& config = HloModuleConfig()); - - // Creates a new module for a test, and stores it in modules_ so it can be - // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent - // creation of unverified modules. 
- ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") - HloModule* CreateNewModule(const string& name = TestName()); - - // Creates and returns a verified HLO module with the given name. - std::unique_ptr CreateNewVerifiedModule( - const string& name = TestName()); - - private: - // It is confusing to store modules created by module() and CreateNewModule() - // in different fields, but it allows us to migrate tests to - // HloVerifiedTestBase more easily, so it's a win because we can verify more - // modules. See b/80488902. - // - // Lazily populated. Access via module(). - std::unique_ptr module_; - - // Populated by calls to CreateNewModule. - std::vector> modules_; - - bool verifier_layout_sensitive_; - bool allow_mixed_precision_in_hlo_verifier_; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc deleted file mode 100644 index 5c0263e811f94c90a69a460525ffa0c65127ebb5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc +++ /dev/null @@ -1,158 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" - -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { -namespace { - -// This class includes unit tests which are expected to fail because invalid HLO -// modules are intentionally built. Unfortunately, Tensorflow doesn't appear to -// include the necessary gunit parts to test this test machinery (needs the -// macro EXPECT_NONFATAL_FAILURE). The disabled tests can be run with the -// disabled tests enabled and failures can be manually compared against -// expectations. -class HloVerifiedTestBaseTest : public HloVerifiedTestBase {}; - -XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) { - // Test shouldn't fail if no module is created at all. -} - -XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) { - // Use module() to lazily create an empty module, build it up, and verify no - // failures. - HloModule& hlo_module = module(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - hlo_module.AddEntryComputation(builder.Build()); -} - -// This test is expected to fail. See test class comment. -XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) { - // Use module() to lazily create an empty module and build up an invalid - // module. 
- HloModule& hlo_module = module(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - hlo_module.AddEntryComputation(builder.Build()); - - *hlo_module.entry_computation()->root_instruction()->mutable_shape() = - ShapeUtil::MakeShape(PRED, {1, 2, 3}); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) { - // Call CreateNewModule and build up a valid module. - HloModule* module = CreateNewModule(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - module->AddEntryComputation(builder.Build()); -} - -// This test is expected to fail. See test class comment. -XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) { - // Call CreateNewModule and build up a invalid module. 
- HloModule* module = CreateNewModule(); - auto builder = HloComputation::Builder(TestName()); - auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); - module->AddEntryComputation(builder.Build()); - - *module->entry_computation()->root_instruction()->mutable_shape() = - ShapeUtil::MakeShape(PRED, {1, 2, 3}); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) { - const char* const hlo_string = R"( -HloModule ParseAndVerifyModuleGood - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x,y) -} -)"; - - ParseAndVerifyModule(hlo_string); - EXPECT_EQ(module().entry_computation()->instruction_count(), 3); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleGood - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x,y) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - EXPECT_EQ(module->entry_computation()->instruction_count(), 3); -} - -XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleGood - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[] add(x,y) -} - -RANDOM GARBAGE -)"; - - ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); -} - -// This test is expected to fail. See test class comment. 
-XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) { - const char* const hlo_string = R"( -HloModule ParseAndReturnVerifiedModuleBad - -ENTRY entry { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT add = f32[1234] add(x,y) -} -)"; - - ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index 310f3495922250d68aa463fcbb24ef0b04603d09..65205f53ddc582ae477d67705f161fef1e31b857 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -113,5 +113,26 @@ INSTANTIATE_TEST_CASE_P(IotaR3TestInstantiation, IotaR3Test, /*step=*/10), ::testing::Values(0, 1, 2))); +class IotaR3PredTest : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +TEST_P(IotaR3PredTest, DoIt) { + const auto element_type = PRED; + const int64 num_elements = 2; + const int64 iota_dim = GetParam(); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + std::vector dimensions = {42, 19}; + dimensions.insert(dimensions.begin() + iota_dim, num_elements); + Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + if (primitive_util::IsFloatingPointType(element_type)) { + ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); + } else { + ComputeAndCompare(&builder, {}); + } +} + +INSTANTIATE_TEST_CASE_P(IotaR3PredTestInstantiation, IotaR3PredTest, + ::testing::Values(0, 1, 2)); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 8d658695576035cdc34a213847460dd80de5f67e..a78ccacec114858740bf1b9c04e9b688bca5818d 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -68,7 +68,7 @@ class LLVMCompilerTest : public ::testing::Test { 
builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); compiler->SetPreOptimizationHook(pre_opt_hook); @@ -90,18 +90,19 @@ class LLVMCompilerTest : public ::testing::Test { builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - std::unique_ptr hlo_module = CreateNewModule(); + std::unique_ptr hlo_module = CreateNewUnverifiedModule(); hlo_module->AddEntryComputation(builder.Build()); - std::vector> modules; - modules.push_back(hlo_module->Clone()); - modules.push_back(std::move(hlo_module)); + auto module_group = absl::make_unique("test_module_group"); + module_group->push_back(hlo_module->Clone()); + module_group->push_back(std::move(hlo_module)); std::vector> executors; executors.push_back({backend_->default_stream_executor()}); executors.push_back({backend_->default_stream_executor()}); - EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors), + EXPECT_IS_OK(compiler->Compile(std::move(module_group), + std::move(executors), /*device_allocator=*/nullptr)); } @@ -123,9 +124,9 @@ class LLVMCompilerTest : public ::testing::Test { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } - static std::unique_ptr CreateNewModule() { + static std::unique_ptr CreateNewUnverifiedModule() { HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + config.set_debug_options(GetDebugOptionsFromFlags()); return absl::make_unique(TestName(), config); } }; @@ -150,12 +151,12 @@ TEST_F(GpuCompilerTest, HooksTest) { TestCompilerHooks(&compiler); } -TEST_F(CpuCompilerTest, MultiModuleCompilation) { +TEST_F(CpuCompilerTest, CpuMultiModuleCompilation) { cpu::CpuCompiler compiler; TestMultiModuleCompilation(&compiler); } -TEST_F(GpuCompilerTest, MultModuleCompilation) { +TEST_F(GpuCompilerTest, 
NVPTXMultiModuleCompilation) { gpu::NVPTXCompiler compiler; TestMultiModuleCompilation(&compiler); } diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 56aaeb0e6878737e6c689e8065d8f1e1871b3472..3f5135438fc59bea98527b1be30ee49339edd455 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -62,7 +62,7 @@ class MultiOutputFusionTest : public HloTestBase { void RunTest2D(bool manual_fusion, int64 size) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {}); const Shape elem_shape2 = @@ -122,7 +122,7 @@ class MultiOutputFusionTest : public HloTestBase { void RunTest1D(bool manual_fusion, int size) { auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); + auto hlo_module = CreateNewUnverifiedModule(); const Shape elem_shape_F32 = ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}); @@ -192,7 +192,7 @@ XLA_TEST_F(MultiOutputFusionTest, DiffentTypesFusion) { RunTest1D(true, 8); } XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { const char* testcase = R"( - HloModule m + HloModule m, is_scheduled=true fused_computation { x.param_0 = (((s32[]), f32[]), (f32[], s32[])) parameter(0) @@ -224,7 +224,7 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { const char* testcase = R"( - HloModule m + HloModule m, is_scheduled=true fused_computation { p = f32[4] parameter(0) @@ -251,7 +251,7 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { const char* testcase = R"( - HloModule m + HloModule m, is_scheduled=true fused_computation { p = f32[] parameter(0) @@ -282,7 +282,7 @@ 
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { } const char* const kScalarOps = R"( - HloModule m + HloModule m, is_scheduled=true Add { lhsadd = f32[] parameter(0) diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 58539e6b061b0cec1cc660b52e78894e5deeea56..774eb8d2a85914c52597144e70838ee117ee1134 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -87,8 +87,8 @@ TEST_F(PredTest, ConstantR2Pred) { XlaBuilder builder(TestName()); ConstantR2(&builder, {{false, true, true}, {true, false, false}}); const string expected = R"(pred[2,3] { - { 011 }, - { 100 } + { 0, 1, 1 }, + { 1, 0, 0 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 26e2bfde5cdc19657640f24f31bc008d09ad7106..f80d29b9de440b11c36e8c9bc65d4a93353a6267 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/base/casts.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -216,14 +216,13 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { const uint32_t sign_bit = 1u << 31; for (const auto& test_value : test_values) { // Add positive values. 
- input_values.push_back(tensorflow::bit_cast(test_value[0])); - expected_values.push_back(tensorflow::bit_cast(test_value[index])); + input_values.push_back(absl::bit_cast(test_value[0])); + expected_values.push_back(absl::bit_cast(test_value[index])); // Add negative values. We do this in the bitwise representation so as to // avoid problems with NaN handling. - input_values.push_back( - tensorflow::bit_cast(test_value[0] ^ sign_bit)); + input_values.push_back(absl::bit_cast(test_value[0] ^ sign_bit)); expected_values.push_back( - tensorflow::bit_cast(test_value[index] ^ sign_bit)); + absl::bit_cast(test_value[index] ^ sign_bit)); } // This is required for proper handling of NaN values. @@ -283,7 +282,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) { XlaBuilder builder(TestName()); - Literal a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001, 1.00001}); std::unique_ptr a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal.shape(), "a"); @@ -301,7 +300,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS, 5, 10, [](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; }); - ComputeAndCompareR1(&builder, {-1.00001f}, {a_data.get()}); + ComputeAndCompareR1(&builder, {-1.00001f, -1.00001f}, {a_data.get()}); } // The interpreter has no fusion pass, so skip this test. 
@@ -309,7 +308,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) { XlaBuilder builder(TestName()); - Literal a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001, 1.00001}); std::unique_ptr a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal.shape(), "a"); @@ -325,7 +324,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS, 5, 10, [](const HloOpcode opcode) { return opcode == HloOpcode::kFusion; }); - ComputeAndCompareR1(&builder, {-1.0f}, {a_data.get()}); + ComputeAndCompareR1(&builder, {-1.0f, -1.0f}, {a_data.get()}); } // The interpreter has no fusion pass, so skip this test. @@ -358,7 +357,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) { XlaBuilder builder(TestName()); - Literal a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001, 1.00001}); std::unique_ptr a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie(); auto a = Parameter(&builder, 0, a_literal.shape(), "a"); @@ -375,7 +374,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT, 5, 10, [](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; }); - ComputeAndCompareR1(&builder, {-1.0f}, {a_data.get()}); + ComputeAndCompareR1(&builder, {-1.0f, -1.0f}, {a_data.get()}); } } // namespace diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 83997cdac21c437d460dabdbdfdb31100b1359af..18c99490a387923aaf68e06041cd11ed3b954aa5 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" @@ -980,5 +981,25 @@ XLA_TEST_F(ReduceTest, OrReduceU64) { ComputeAndCompareR1(&builder, expected, {}); } +XLA_TEST_F(ReduceTest, R0ReduceInDisguise) { + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); + constexpr int element_count = 127; + const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count, 1}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto zero = ConstantR0(&builder, 0.0); + Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0}); + + Array2D input_data(element_count, 1); + input_data.FillRandom(3.0f); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); + std::unique_ptr input_global_data = + client_->TransferToServer(input_literal).ConsumeValueOrDie(); + + float expected = absl::c_accumulate(input_data, 0.0f); + ComputeAndCompareR1(&builder, {expected}, {input_global_data.get()}, + ErrorSpec(0.001)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index c25ccafaf83cf1b29095a77eefa357d9af08dc60..22fe4a2670e2e0e1fedc45036a1ceec19f44e42e 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -638,6 +638,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, /*padding=*/padding); CHECK(reducer == kAdd || reducer == kMax); @@ -1158,7 +1160,10 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/padding); + 
/*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, + /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? +[](float a, float b) { return a + b; } @@ -1369,7 +1374,10 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/padding); + /*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, + /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? +[](float a, float b) { return a + b; } diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 5cf87e565bf493167f5173588e7afa3b96282488..34c7dc7c46427b2d18ea21fc286ee03175f70800 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -55,7 +55,8 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { client_->GetComputationShape(computation).ConsumeValueOrDie(); std::unique_ptr replayed_shape = client_->GetComputationShape(replayed).ConsumeValueOrDie(); - ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(), + replayed_shape->ToProto())); // Run it. Literal literal = @@ -87,7 +88,8 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { client_->GetComputationShape(computation).ConsumeValueOrDie(); std::unique_ptr replayed_shape = client_->GetComputationShape(replayed).ConsumeValueOrDie(); - ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(), + replayed_shape->ToProto())); // Run it. 
std::unique_ptr x_data = @@ -133,7 +135,8 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { client_->GetComputationShape(computation).ConsumeValueOrDie(); std::unique_ptr replayed_shape = client_->GetComputationShape(replayed).ConsumeValueOrDie(); - ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(), + replayed_shape->ToProto())); // Run it. Literal literal = diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index dedc95b5ae8315185a35f786af42aad53bd7ad96..298136002e9ef47188e0bae95af3f596596e6062 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -618,7 +618,8 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8}, - {1, 0}); + {1, 0}) + .ToProto(); Literal actual = client_ ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) @@ -767,7 +768,8 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5}, - {2, 3, 0, 1}); + {2, 3, 0, 1}) + .ToProto(); Literal output_literal = client_ ->ExecuteAndTransfer(computation, {input_data.get()}, diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 091a5d2cacce6ac5bf986776e5ec96612d08cc75..606a099ecbc4e5677034c6d57e7ba5c398c69ab9 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include +#include "absl/base/casts.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -47,7 +47,7 @@ class RoundTripPackedLiteralTest : public ClientLibraryTestBase { TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { string data(sizeof(float) * 2, 0); - absl::Span floats(tensorflow::bit_cast(data.data()), 2); + absl::Span floats(absl::bit_cast(data.data()), 2); floats[0] = 42.0; floats[1] = 24.0; @@ -69,7 +69,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { string data(sizeof(float) * 4, 0); - absl::Span floats(tensorflow::bit_cast(data.data()), 4); + absl::Span floats(absl::bit_cast(data.data()), 4); // With x as the minor dimension, these will become: floats[0] = 42.0; // y=0,x=0 floats[1] = 24.0; // y=0,x=1 @@ -102,7 +102,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { string data(sizeof(float) * 4, 0); - absl::Span floats(tensorflow::bit_cast(data.data()), 4); + absl::Span floats(absl::bit_cast(data.data()), 4); // With y as the minor dimension, these will become: floats[0] = 42.0; // y=0,x=0 floats[1] = 24.0; // y=1,x=0 diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index b21dd56045e1dc11847e213852dea60cd033be7b..32de0fdf78f9c442e17c55e1b951e39122dac5ef 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ 
b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -69,6 +69,37 @@ ENTRY main { RunTest(hlo_text, &operand, &scatter_indices, &updates); } +XLA_TEST_F(ScatterTest, TensorFlowScatterV1_WithFusedAdds) { + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + p0 = s32[3,3] parameter(0) + operand = s32[3,3] add(p0, p0) + p1 = s32[2] parameter(1) + indices = s32[2] add(p1, p1) + p2 = s32[2,3] parameter(2) + updates = s32[2,3] add(p2, p2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + Literal operand = + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 1}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) { const char* hlo_text = R"( HloModule TensorFlowScatterV2 @@ -98,6 +129,73 @@ ENTRY main { RunTest(hlo_text, &operand, &scatter_indices, &updates); } +XLA_TEST_F(ScatterTest, TensorFlowScatterV2_InversePermutation) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV2 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + permutation = s32[3,4] parameter(0) + reshape = s32[3,4,1] reshape(permutation) + operand = s32[3,4] iota(), iota_dimension=1 + updates = s32[3,4,1,1] iota(), iota_dimension=1 + iota = s32[3,4,1] iota(), iota_dimension=0 + indices = s32[3,4,2] concatenate(iota, reshape), dimensions={2} + ROOT scatter = s32[3,4] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={2,3}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=2 
+} +)"; + Literal permutation = + LiteralUtil::CreateR2({{1, 3, 2, 0}, {3, 0, 2, 1}, {2, 3, 1, 0}}); + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + auto actual = ExecuteAndTransfer(std::move(module), {&permutation}); + Literal expected = + LiteralUtil::CreateR2({{3, 0, 2, 1}, {1, 3, 2, 0}, {3, 2, 0, 1}}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); +} + +XLA_TEST_F(ScatterTest, SimpleR4) { + const char* hlo_text = R"( +HloModule SimpleR4 + +add_f32 (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(f32[] lhs, f32[] rhs) +} + +ENTRY main { + operand = f32[1,2,2,1] parameter(0) + indices = s32[1,3] parameter(1) + updates = f32[1,2,2,1] parameter(2) + ROOT scatter = f32[1,2,2,1] scatter(operand, indices, updates), + to_apply=add_f32, + update_window_dims={1,2,3}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0, 2, 1}, + index_vector_dim=1 +} +)"; + + Literal operand = + LiteralUtil::CreateR4({{{{0.f}, {0.f}}, {{0.f}, {0.f}}}}); + Literal updates = + LiteralUtil::CreateR4({{{{0.12}, {0.28}}, {{0.018}, {0.42}}}}); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0, 0}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); +} + XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) { const string hlo_text = R"( HloModule TensorFlowScatter_Add diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 2cc33ab0963afe8ba2d8e9a6972dcf0622e27c48..3fb69419e735bfd9c5054673e0687f5139a410cb 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -166,6 +166,26 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); } +TEST_F(SliceTest, SliceOfReshape) { + Array2D values(2 * 3 * 24, 7); + 
values.FillIota(1); + XlaBuilder builder(TestName()); + auto original = ConstantR2FromArray2D(&builder, values); + auto reshape = Reshape(original, {24, 3, 2, 7}); + Slice(reshape, {0, 0, 0, 0}, {11, 3, 2, 7}, {1, 1, 1, 1}); + ComputeAndCompare(&builder, {}); +} + +TEST_F(SliceTest, SliceOfCollapsingReshape) { + Array4D values(2, 3, 5, 7); + values.FillIota(1); + XlaBuilder builder(TestName()); + auto original = ConstantR4FromArray4D(&builder, values); + auto reshape = Reshape(original, {2 * 3 * 5, 7}); + Slice(reshape, {0, 0}, {4, 7}, {1, 1}); + ComputeAndCompare(&builder, {}); +} + XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { Array4D values(2, 4, 6, 8); values.FillRandom(3.14f); @@ -253,7 +273,6 @@ XLA_TEST_P(SliceR1LargeTest, DoIt_S64) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run(GetParam()); } - // Tests for R1 slice ops. // The format for each testcase is {input size, start, limit, stride}. // clang-format off diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 5155f0c652c7c6dbba60c421159494fa28072090..eafa48ed7b8cf2bd67fe767ad36082661dbbd66e 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include +#include "absl/base/casts.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -28,65 +29,113 @@ namespace xla { namespace { template -void PopulateWithRandomFloatingPointDataImpl(Literal* literal, - std::minstd_rand0* engine, - bool no_duplicates) { +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { + std::uniform_real_distribution generator(-0.1f, 0.2f); + for (FloatT& value : literal->data()) { + value = static_cast(generator(*engine)); + } +} + +template +void PopulateWithIntNext(Literal* literal); + +template <> +void PopulateWithIntNext(Literal* literal) { + // Duplicates may be generated if we don't have enough bits. + uint16 next_value = 0; + for (half& value : literal->data()) { + // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into + // the sign bit. We could be less wasteful, but this is best-effort anyway. + uint16 exponent_msb = next_value & 0x4000; + value.x = (next_value & 0xBFFF) | (exponent_msb << 1); + next_value++; + } +} + +template <> +void PopulateWithIntNext(Literal* literal) { + // Duplicates may be generated if we don't have enough bits. + // Start at 0x80 rather than 0 to avoid denormals. + uint16 next_value = 0x80; + for (bfloat16& value : literal->data()) { + // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into + // the sign bit. We could be less wasteful, but this is best-effort anyway. + uint16 exponent_msb = next_value & 0x4000; + value.value = (next_value & 0xBFFF) | (exponent_msb << 1); + next_value++; + } +} + +template +void PopulateWithNextAfter(Literal* literal) { + // Duplicates may be generated if the number of elements in the literal + // exceeds the number of positive values supported by the type. 
+ float next_value = std::numeric_limits::min(); + for (float& value : literal->data()) { + value = next_value; + next_value = std::nextafter(next_value, std::numeric_limits::max()); + } +} + +template ::value || + std::is_same::value, + int>::type = 0> +void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) { + PopulateWithIntNext(literal); + std::shuffle(literal->data().begin(), literal->data().end(), + *engine); +} + +template ::value && + !std::is_same::value, + int>::type = 0> +void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) { + PopulateWithNextAfter(literal); + std::shuffle(literal->data().begin(), literal->data().end(), + *engine); +} + +template +void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); if (no_duplicates) { - // Duplicates may be generated if the number of elements in the literal - // exceeds the number of positive values supported by the type. 
- FloatT next_value = std::numeric_limits::min(); - for (FloatT& value : literal->data()) { - value = next_value; - next_value = - std::nextafter(next_value, std::numeric_limits::max()); - } - std::shuffle(literal->data().begin(), literal->data().end(), - *engine); + PopulateWithNoDuplicateData(literal, engine); } else { - std::uniform_real_distribution generator(-0.1f, 0.2f); - for (FloatT& value : literal->data()) { - value = static_cast(generator(*engine)); - } + PopulateWithRandomFloatingPointData(literal, engine); } } -template -void PopulateWithRandomFloatingPointData(Literal* literal, +template <> +void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine, bool no_duplicates) { CHECK(engine != nullptr); - PopulateWithRandomFloatingPointDataImpl(literal, engine, - no_duplicates); -} - -template <> -void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine, - bool no_duplicates) { - // no_duplicates is ignored for half types. Unique values can only be - // generated for arrays with fewer than ~2**16 elements and no_duplicates is - // best-effort anyway. - CHECK(engine != nullptr); - std::uniform_real_distribution generator(-0.1f, 0.2f); - for (half& value : literal->data()) { - value = static_cast(generator(*engine)); + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + if (no_duplicates) { + PopulateWithNoDuplicateData(literal, engine); + } else { + PopulateWithRandomFloatingPointData(literal, engine); } } template <> -void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine, - bool no_duplicates) { - // no_duplicates is ignored for bfloat types. Unique values can only be - // generated for arrays with fewer than ~2**16 elements and no_duplicates is - // best-effort anyway. 
+void PopulateWithFloatingPointData(Literal* literal, + std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); - std::uniform_real_distribution generator(-0.1f, 0.2f); - for (bfloat16& value : literal->data()) { - value = static_cast(generator(*engine)); + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + if (no_duplicates) { + PopulateWithNoDuplicateData(literal, engine); + } else { + PopulateWithRandomFloatingPointData(literal, engine); } } @@ -135,20 +184,16 @@ StatusOr MakeFakeLiteralInternal(const Shape& shape, Literal literal(shape); switch (shape.element_type()) { case BF16: - PopulateWithRandomFloatingPointData(&literal, engine, - no_duplicates); + PopulateWithFloatingPointData(&literal, engine, no_duplicates); break; case F16: - PopulateWithRandomFloatingPointData(&literal, engine, - no_duplicates); + PopulateWithFloatingPointData(&literal, engine, no_duplicates); break; case F32: - PopulateWithRandomFloatingPointData(&literal, engine, - no_duplicates); + PopulateWithFloatingPointData(&literal, engine, no_duplicates); break; case F64: - PopulateWithRandomFloatingPointData(&literal, engine, - no_duplicates); + PopulateWithFloatingPointData(&literal, engine, no_duplicates); break; case S8: PopulateWithRandomIntegralData(&literal, engine, no_duplicates); @@ -272,9 +317,11 @@ std::vector FindConstrainedUses( constrained_uses.insert(constrained_uses.end(), converted_uses.begin(), converted_uses.end()); } else if (opcode == HloOpcode::kSort && - instruction->operand_count() == 2 && op_num == 0) { + instruction->operand_count() >= 2 && op_num == 0) { // Operand 0 of sort is the array of keys used for key/value - // (two-operand) kSort instructions. + // (two-operand) kSort instructions. Since sort stability is not + // guaranteed, constrain keys of key-value sort not to have duplicates, + // since otherwise the value order may legitimately differ. 
constrained_uses.push_back(instruction); } } diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index bc433eac8fcb02087d8e4eb10f638c85dc141b22..e8f5d7a9a79ebddea3cb989dbe8eab90b630d5e7 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -15,13 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" +#include "absl/base/casts.h" #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -148,7 +148,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> ( absl::flat_hash_set key_set; for (const float& value : key_arg.data()) { - EXPECT_TRUE(key_set.insert(tensorflow::bit_cast(value)).second); + EXPECT_TRUE(key_set.insert(absl::bit_cast(value)).second); } } @@ -171,7 +171,30 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> ( absl::flat_hash_set key_set; for (const int32& value : key_arg.data()) { - EXPECT_TRUE(key_set.insert(tensorflow::bit_cast(value)).second); + EXPECT_TRUE(key_set.insert(absl::bit_cast(value)).second); + } +} + +XLA_TEST_F(TestUtilsTest, NoDuplicatesBfloat16) { + // Inputs which are sort keys in key/value sorts should have no duplicates. + auto module = ParseHloString(R"( +HloModule sort, is_scheduled=true + +ENTRY %sort. 
(parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,1452], s32[2,1452]) { + %parameter.0 = bf16[2,1452]{1,0} parameter(0) + %parameter.1 = s32[2,1452]{1,0} parameter(1) + ROOT %sort = (bf16[2,1452]{1,0}, s32[2,1452]{1,0}) sort(bf16[2,1452]{1,0} %parameter.0, s32[2,1452]{1,0} %parameter.1), dimensions={1} +} +)") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + const Literal& key_arg = args[0]; + + absl::flat_hash_set key_set; + for (const bfloat16& value : key_arg.data()) { + EXPECT_TRUE(key_set.insert(absl::bit_cast(value)).second); } } diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index b34fd0f2e873214c509533f29553af914ddc984d..601c6b06938fef1f1ae809b33209ae59b24c70a2 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -28,7 +29,7 @@ namespace { class TokenHloTest : public HloTestBase {}; XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); builder.AddInstruction(HloInstruction::CreateToken()); @@ -38,8 +39,22 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } +XLA_TEST_F(TokenHloTest, TokenInTuple) { + std::unique_ptr module = CreateNewUnverifiedModule(); + auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateToken()); + 
builder.AddInstruction(HloInstruction::CreateTuple({token})); + + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {})); + Literal token_literal = LiteralUtil::CreateToken(); + EXPECT_TRUE( + LiteralTestUtil::Equal(result, LiteralUtil::MakeTuple({&token_literal}))); +} + XLA_TEST_F(TokenHloTest, TokenTree) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); @@ -54,7 +69,7 @@ XLA_TEST_F(TokenHloTest, TokenTree) { } XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); @@ -75,7 +90,7 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { } XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { - std::unique_ptr module = CreateNewModule(); + std::unique_ptr module = CreateNewUnverifiedModule(); auto builder = HloComputation::Builder(TestName()); builder.AddInstruction(HloInstruction::CreateParameter( 0, @@ -94,26 +109,6 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { ::testing::HasSubstr("Entry parameter 0 is or contains a token shape")); } -XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { - std::unique_ptr module = CreateNewModule(); - auto builder = HloComputation::Builder(TestName()); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); - builder.AddInstruction(HloInstruction::CreateAfterAll({param})); - builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); - 
module->AddEntryComputation(builder.Build()); - - Status status = - HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) - .Run(module.get()) - .status(); - ASSERT_IS_NOT_OK(status); - EXPECT_THAT(status.error_message(), - ::testing::HasSubstr( - "Operands of token instructions must be TOKEN types")); -} - XLA_TEST_F(TokenHloTest, TokenInWhileLoop) { // Thread a token around a while loop. Token is created and consumed by a // AfterAll instruction in the while body. @@ -206,5 +201,95 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { } } +XLA_TEST_F(TokenHloTest, AddDependency) { + string module_string = R"( +HloModule AddDependency, is_scheduled=true + +// Computes (p0 + 42) * (-p1) +// where there is a dependency from the add to the negation using a token +// with after-all and add-dependency instructions. +ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + + %forty_two = f32[] constant(42.0) + %add = f32[] add(f32[] %p0, f32[] %forty_two) + %token = token[] after-all(f32[] %add) + %p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token) + %neg = f32[] negate(f32[] %p1_after_token) + ROOT %product = f32[] multiply(f32[] %add, f32[] %neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + auto p0 = LiteralUtil::CreateR0(10.0); + auto p1 = LiteralUtil::CreateR0(3.0); + auto expected = LiteralUtil::CreateR0(-156.0); + EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0, &p1})); +} + +XLA_TEST_F(TokenHloTest, AddDependencyOfConstant) { + string module_string = R"( +HloModule AddDependencyOfConstant, is_scheduled=true + +ENTRY %AddDependency (p0: f32[]) -> f32[] { + %p0 = f32[] parameter(0) + %forty_two = f32[] constant(42.0) + %token = token[] after-all(f32[] %p0) + %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token) + ROOT %product = f32[] multiply(f32[] %p0, 
f32[] %forty_two_after_token) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + auto p0 = LiteralUtil::CreateR0(10.0); + auto expected = LiteralUtil::CreateR0(420.0); + EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0})); +} + +XLA_TEST_F(TokenHloTest, AddDependencyAsRoot) { + string module_string = R"( +HloModule AddDependencyAsRoot, is_scheduled=true +ENTRY %AddDependency (p: f32[3]) -> f32[3] { + %p = f32[3] parameter(0) + %neg = f32[3] negate(f32[3] %p) + %token = token[] after-all() + ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + auto input = LiteralUtil::CreateR1({1.0, 3.0, 7.0}); + auto expected = LiteralUtil::CreateR1({-1.0, -3.0, -7.0}); + EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&input})); +} + +XLA_TEST_F(TokenHloTest, TupleShapedAddDependency) { + string module_string = R"( +HloModule TupleShapedAddDependency, is_scheduled=true +ENTRY %TupleShapedAddDependency (p0: f32[3], p1: f32[3]) -> f32[3] { + %p0 = f32[3] parameter(0) + %p1 = f32[3] parameter(1) + %forty_two = f32[] constant(42.0) + %token = token[] after-all() + %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token, f32[3] %p1, f32[] %forty_two) + %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token) + %elem0 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=0 + %elem2 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=2 + ROOT %diff = f32[3] subtract(f32[3] %elem0, f32[3] %elem2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseHloString(module_string, GetModuleConfigForTest())); + auto p0 = LiteralUtil::CreateR1({3.0, 3.0, 47.0}); + auto p1 = LiteralUtil::CreateR1({1.0, -2.0, 2.0}); + auto expected 
= LiteralUtil::CreateR1({2.0, 5.0, 45.0}); + EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0, &p1})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 619d2a388b5646c31f0a61f709a2ab3067e39c03..27ce243e9bd4afbdcc1fdc5b6873d4968086e459 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -515,7 +515,7 @@ class TupleHloTest : public HloTestBase {}; // Disabled on the interpreter because bitcast doesn't exist on the interpreter. XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { const char* testcase = R"( - HloModule m + HloModule m, is_scheduled=true ENTRY test { name.1 = (f32[3]{0}) parameter(0) diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 8b1b9e151992296b9d022ae1d9d974eadd2074a8..6d5f276e82087cedc356691b0ff08df24cec8d20 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -48,7 +48,7 @@ class WhileTest : public ClientLibraryTestBase {}; // while (result < 5) { // result = result + 1; // } -TEST_F(WhileTest, WhileWithScalarS32Result) { +XLA_TEST_F(WhileTest, WhileWithScalarS32Result) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. @@ -84,7 +84,7 @@ TEST_F(WhileTest, WhileWithScalarS32Result) { // while (result < 5) { // result = result + 1; // } -TEST_F(WhileTest, WhileWithScalarS64Result) { +XLA_TEST_F(WhileTest, WhileWithScalarS64Result) { auto result_shape = ShapeUtil::MakeShape(S64, {}); // Create a computation for the condition: repeat for 5 iterations. 
@@ -114,7 +114,7 @@ TEST_F(WhileTest, WhileWithScalarS64Result) { ComputeAndCompareR0(&builder, 5, {}); } -TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { +XLA_TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { auto result_shape = ShapeUtil::MakeShape(S32, {}); auto orig_shape = ShapeUtil::MakeShape(S32, {2}); @@ -147,7 +147,7 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { ComputeAndCompareR0(&builder, 5, {}); } -TEST_F(WhileTest, WhileWithPredicateResult) { +XLA_TEST_F(WhileTest, WhileWithPredicateResult) { auto result_shape = ShapeUtil::MakeShape(PRED, {}); // Create a computation for the condition: run until condition is true. @@ -184,7 +184,7 @@ TEST_F(WhileTest, WhileWithPredicateResult) { // while (result.sum() < 15.5f) { // result = result + vector(0); // } -TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) { +XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) { Shape result_shape = ShapeUtil::MakeShape(F32, {0}); // Create a computation for the reduction. @@ -238,7 +238,7 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) { // while (result.sum() < 15.5f) { // result = result + vector(8, 0.125f); // } -TEST_F(WhileTest, WhileWithVectorResult) { +XLA_TEST_F(WhileTest, WhileWithVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. @@ -298,7 +298,7 @@ TEST_F(WhileTest, WhileWithVectorResult) { // result = result + vector(8, 0.125f); // } // tuple = tuple { while } -TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { +XLA_TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. 
@@ -353,7 +353,7 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } -TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { +XLA_TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { std::vector shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -407,7 +407,7 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } -TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { +XLA_TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { std::vector shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -465,7 +465,7 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { // get<0>(result) = get<0>(result) + 1; // get<1>(result) = get<1>(result) + vector(10, 1.0f); // } -TEST_F(WhileTest, WhileWithTupleResult) { +XLA_TEST_F(WhileTest, WhileWithTupleResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -514,7 +514,7 @@ TEST_F(WhileTest, WhileWithTupleResult) { ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } -TEST_F(WhileTest, WhileWithPredicateTupleResult) { +XLA_TEST_F(WhileTest, WhileWithPredicateTupleResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(PRED, {})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -560,7 +560,7 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0)); } -TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { +XLA_TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), 
ShapeUtil::MakeShape(S32, {})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -619,7 +619,7 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { // get<1>(w1) = get<1>(w1) + vector(10, 1.0f); // } // result = get<1>(w0) + get<1>(w1) -TEST_F(WhileTest, TwoWhileWithTupleResult) { +XLA_TEST_F(WhileTest, TwoWhileWithTupleResult) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -698,7 +698,7 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { } // Test while nodes that share the while body computation. -TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { +XLA_TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -763,7 +763,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } -TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) { +XLA_TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) { std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); @@ -901,7 +901,7 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Per backend the values generated can be different as the different backends // use different random number generators. // TODO(b/32240857): Extend test to verify outputs. -TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { +XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { auto v6s32 = ShapeUtil::MakeShape(S32, {6}); // Create a computation for the condition: repeat for count iterations. 
@@ -947,7 +947,7 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { } } -TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { +XLA_TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); XlaBuilder outer("outer"); @@ -979,7 +979,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { ErrorSpec(1e-6)); } -TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { +XLA_TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); XlaBuilder outer("outer"); @@ -1004,7 +1004,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { ErrorSpec(1e-6)); } -TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { +XLA_TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {}); XlaBuilder outer("outer"); @@ -1038,7 +1038,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { // result[0] = result[0] + 1; // result[1] = result[1] + 1; // } -TEST_F(WhileTest, WhileWithMixedTupleElements) { +XLA_TEST_F(WhileTest, WhileWithMixedTupleElements) { auto result_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); @@ -1146,7 +1146,7 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // while (f(result).get<0>()) { // result = result + 1; // } -TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { +XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. 
@@ -1186,7 +1186,7 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { ComputeAndCompareR0(&builder, 5, {}); } -TEST_F(WhileTest, WhileWithLoopInvariantOperation) { +XLA_TEST_F(WhileTest, WhileWithLoopInvariantOperation) { auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); auto while_shape = ShapeUtil::MakeTupleShape( @@ -1230,7 +1230,7 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { {param_value.get()}, ErrorSpec(4e-5)); } -TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { +XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { auto while_shape = ShapeUtil::MakeShape(S32, {}); XlaComputation condition; diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index a6e70eb6ca25ffac24a8ebaf0420238e109e4fad..e57d072a0632b492b8b6e34439f4e80332b843b6 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -91,16 +91,16 @@ Status ParseOneProfileOutputLine( string match_usecs = "([0-9.]+) usec"; string match_flops = "([^ ]*)"; string match_trops = "([^ ]*)"; - string match_bytes_per_sec = "([0-9.TGMKi]+)B/s"; - string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle"; + string match_bytes_per_sec = "([0-9.TGMKi]*)(?:B/s)?"; + string match_bytes_per_cycle = "([0-9.TGMKi]*)(?:B/cycle)?"; // The underlined part is what we're trying to match with match_opcode: // // %dot33 = f32[256,256]{1,0} dot(...) // ^^^ - string match_opcode = - expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*" : "(\\[total\\])"; + string match_opcode = expect_hlo ? 
"%[^=]+= [^ ]+ ([^(]+)\\(.*" + : "(\\[total\\])( \\[entry\\])?"; string regexp_pattern = absl::StrCat( " +", match_cycles, separator, match_usecs, separator, match_flops, separator, match_trops, separator, match_bytes_per_sec, separator, @@ -125,6 +125,10 @@ Status ParseOneProfileOutputLine( return Status::OK(); } +bool IsExtraMetricProfileOutputLine(const string& line) { + return RE2::FullMatch(line, "Extra metric \\S+: \\d+"); +} + // Returns void so that we can ASSERT. void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, const XlaComputation& computation, @@ -153,10 +157,12 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); + ExecutableBuildOptions build_options; + build_options.mutable_debug_options()->set_xla_hlo_profile(true); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape}, - ExecutableBuildOptions().set_hlo_profile(true))); + build_options)); Executable* executable = local_executable->executable(); HloExecutionProfile hlo_execution_profile( @@ -204,20 +210,32 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { string profile_output; ExecuteAndFetchProfile(&profile_output, client, computation, lhs_shape, rhs_shape); - + VLOG(4) << "Profile Output:\n" << profile_output; std::vector profile_output_lines = absl::StrSplit(profile_output, '\n'); absl::flat_hash_map parsed_profile_lines; - TF_ASSERT_OK(ParseOneProfileOutputLine( - profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines)); + int line_no = 0; + + // Skip extra metrics. + while (IsExtraMetricProfileOutputLine(profile_output_lines[line_no])) { + line_no++; + } + + line_no++; // Skip 'Execution profile for ....' 
- TF_ASSERT_OK(ParseOneProfileOutputLine( - profile_output_lines[2], /*expect_hlo=*/true, &parsed_profile_lines)); + TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], + /*expect_hlo=*/false, + &parsed_profile_lines)); - TF_ASSERT_OK(ParseOneProfileOutputLine( - profile_output_lines[3], /*expect_hlo=*/true, &parsed_profile_lines)); + TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], + /*expect_hlo=*/true, + &parsed_profile_lines)); + + TF_ASSERT_OK(ParseOneProfileOutputLine(profile_output_lines[line_no++], + /*expect_hlo=*/true, + &parsed_profile_lines)); TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_profile, MaybeFind(parsed_profile_lines, "[total]")); @@ -291,6 +309,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { string profile_output; ExecuteAndFetchProfile(&profile_output, client, computation, matrix_shape, matrix_shape); + SCOPED_TRACE(profile_output); std::vector profile_output_lines = absl::StrSplit(profile_output, '\n'); @@ -302,14 +321,13 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { ASSERT_NE(while_body_profile_start, profile_output_lines.cend()); - auto while_body_profile_end = std::find_if( - while_body_profile_start, profile_output_lines.end(), - [](absl::string_view s) { - return absl::StartsWith(s, "********** microseconds report **********"); - }); + auto while_body_profile_end = + std::find_if(while_body_profile_start, profile_output_lines.end(), + [](absl::string_view s) { + return absl::StartsWith(s, "********** microseconds "); + }); - // We emit a blank line before the "********** microseconds report **********" - // line. + // We emit a blank line before the "microseconds report" line. 
while_body_profile_end--; ASSERT_NE(while_body_profile_end, profile_output_lines.end()); @@ -364,7 +382,7 @@ static std::pair AddXlaHloProfileFlag(int argc, char** argv) { GTEST_API_ int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); std::tie(argc, argv) = AddXlaHloProfileFlag(argc, argv); auto usage = tensorflow::Flags::Usage(argv[0], flag_list); diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index 15603619b62d8f45cdce97ac7d83924a78f88cf3..dca0aa52a533130372759156a3238f1a3b10ca42 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -15,14 +15,14 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/string_view.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" GTEST_API_ int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); auto usage = tensorflow::Flags::Usage(argv[0], flag_list); if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) { LOG(ERROR) << "\n" << usage; @@ -49,7 +49,7 @@ GTEST_API_ int main(int argc, char** argv) { // different API than Tensorflow's. 
testing::InitGoogleTest(&argc, argv); #if defined(PLATFORM_GOOGLE) - base::SetFlag(&FLAGS_benchmarks, pattern); + absl::SetFlag(&FLAGS_benchmarks, pattern); RunSpecifiedBenchmarks(); #else tensorflow::testing::Benchmark::Run(pattern); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 3a086c66bbb37965b1ad7c83a93f0054ae723e87..8926bbed2b54fceaaf0e6e991f0e881d35731ef4 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -33,6 +33,7 @@ cc_library( name = "dumped_computation_to_graphviz_library", srcs = ["dumped_computation_to_graphviz.cc"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -40,7 +41,6 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", @@ -78,6 +78,7 @@ cc_library( name = "replay_computation_library", srcs = ["replay_computation.cc"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -91,7 +92,6 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:testing", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service/gpu:infeed_manager", @@ -207,13 +207,13 @@ tf_cc_binary( name = "dumped_computation_to_tf_graphdef", srcs = ["dumped_computation_to_tf_graphdef.cc"], deps = [ + 
"//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_proto", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index c866a13de7543fc948311f94708bc6b904717b62..b623556468fb4a5d96be614b6c067d5a1df51a6f 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -33,7 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" @@ -54,7 +54,7 @@ void RealMain(absl::Span args) { tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); XlaComputation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + DebugOptions debug_options = GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); ComputationStats stats = client->GetComputationStats(computation, debug_options) @@ -68,7 +68,7 @@ void RealMain(absl::Span args) { int main(int argc, char** argv) { std::vector flag_list; - 
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index 07ef5ff656bb48519a700a1d7d6c60b655a40ed6..f8bb9a6b1e217fc4e6e15c8a3302be61ed339c82 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" @@ -53,7 +53,7 @@ void RealMain(absl::Span args) { tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); XlaComputation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); - DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + DebugOptions debug_options = GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); debug_options.set_xla_hlo_dump_as_graphdef(true); ComputationStats stats = @@ -68,7 +68,7 @@ void RealMain(absl::Span args) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git 
a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 0c41f227b31ebe1f01073785ea2a666093aefdb3..ff2c3399928c0e6339304323c4f93e212933a340 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -47,8 +47,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -82,12 +82,17 @@ struct Options { std::unique_ptr CompileExecutable(const HloSnapshot& module, LocalClient* client) { XlaComputation computation(module.hlo().hlo_module()); - std::vector argument_layouts; - for (const auto& param : computation.proto().program_shape().parameters()) { - argument_layouts.push_back(&param); + std::vector argument_layouts; + argument_layouts.reserve( + computation.proto().host_program_shape().parameters_size()); + std::vector argument_layout_ptrs; + for (const ShapeProto& param : + computation.proto().host_program_shape().parameters()) { + argument_layouts.push_back(Shape(param)); + argument_layout_ptrs.push_back(&argument_layouts.back()); } return client - ->Compile(computation, argument_layouts, ExecutableBuildOptions()) + ->Compile(computation, argument_layout_ptrs, ExecutableBuildOptions()) .ValueOrDie(); } @@ -148,7 +153,7 @@ StatusOr ReplayComputation(const HloSnapshot& module, << "--generate_fake_infeed only works if the model has 0 or 1 " "infeed ops, but this one has >= 2."; provide_infeed = true; - infeed_shape = instruction.shape(); + infeed_shape = 
Shape(instruction.shape()); LOG(INFO) << "Generating fake infeed shape for inferred shape: " << ShapeUtil::HumanString(infeed_shape); } @@ -190,16 +195,16 @@ StatusOr ReplayComputation(const HloSnapshot& module, // Run the computation num_runs times, and return the result from the last // execution. - const bool xla_hlo_profile = - legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); + const bool xla_hlo_profile = GetDebugOptionsFromFlags().xla_hlo_profile(); StreamExecutorMemoryAllocator allocator( client->platform(), {client->platform()->ExecutorForDevice(0).ValueOrDie()}); - absl::optional result; + absl::optional final_result; for (int i = 0; i < opts.num_runs; ++i) { // If xla_hlo_profile is enabled, print a noisy message before the last run, // making it easier to separate this profile from the others in the logspam. - if (xla_hlo_profile && i == opts.num_runs - 1) { + bool is_final_result = i == opts.num_runs - 1; + if (xla_hlo_profile && is_final_result) { LOG(INFO) << "\n\n***** Final run below ******"; } ExecutionProfile profile; @@ -207,14 +212,22 @@ StatusOr ReplayComputation(const HloSnapshot& module, run_options.set_execution_profile(&profile); run_options.set_allocator(&allocator); - TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + executable->Run(argument_ptrs, run_options)); LOG(INFO) << "Done executing in " << static_cast(profile.compute_time_ns()) / 1e9 << "s: " << module.hlo().hlo_module().name(); + + // Save the result if this is for the final iteration. Otherwise discard + // the result before rerunning the computation, so as to free up the + // relevant memory. 
+ if (is_final_result) { + final_result = std::move(result); + } } TF_ASSIGN_OR_RETURN(Literal result_literal, - client->ShapedBufferToLiteral(*result)); + client->ShapedBufferToLiteral(*final_result)); return result_literal; } @@ -306,9 +319,10 @@ int RealMain(absl::Span args, const Options& opts) { if (snapshot.has_result()) { Literal literal = Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); - fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(snapshot.result().shape()).c_str(), - literal.ToString().c_str()); + fprintf( + stdout, "was %s:%s\n", + ShapeUtil::HumanString(Shape(snapshot.result().shape())).c_str(), + literal.ToString().c_str()); } } } diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 8ce741647414a1fa75e6d706ec1e719ace7b7cc8..6722641e9d2c177440361e6f0d1f6c0804eb7cda 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -152,6 +152,13 @@ static inline absl::Span AsInt64Slice( slice.size()); } +// TODO(b/29771030): This nop overload was added to simplify the migration of +// Shape from a proto to a C++ class. Remove after class has been migrated. +static inline absl::Span AsInt64Slice( + absl::Span slice) { + return slice; +} + // As above, but for uint64 types. 
static inline absl::Span AsUInt64Slice( const tensorflow::protobuf::RepeatedField& v) { @@ -387,6 +394,19 @@ T CeilOfRatio(T dividend, T divisor) { return tensorflow::MathUtil::CeilOfRatio(dividend, divisor); } +template +std::vector ElementWiseCeilOfRatio(absl::Span dividends, + absl::Span divisors) { + std::vector ceil_of_ratios; + CHECK_EQ(dividends.size(), divisors.size()); + ceil_of_ratios.reserve(dividends.size()); + absl::c_transform(dividends, divisors, std::back_inserter(ceil_of_ratios), + [](const T dividend, const T divisor) { + return CeilOfRatio(dividend, divisor); + }); + return ceil_of_ratios; +} + // Rounds the value up to a multiple of the divisor by first calling CeilOfRatio // then multiplying by the divisor. For example: RoundUpToNearest(13, 8) => 16 template diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 8ea8dbab2574ca1e24271e7c1c7762d4a6b6a8de..51c73b3d17e4c32d9a8a14d3055ab56f02922af3 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -185,6 +185,17 @@ bool HasWindowReversal(const Window& window) { return false; } +bool AllOrNoneReversed(const Window& window) { + if (window.dimensions().empty()) { + return true; + } + bool reversed = window.dimensions()[0].window_reversal(); + return std::all_of(window.dimensions().begin(), window.dimensions().end(), + [&](const WindowDimension& dim) { + return dim.window_reversal() == reversed; + }); +} + bool HasDilation(const Window& window) { return HasBaseDilation(window) || HasWindowDilation(window); } diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h index 1fb9e855fc16f334eb0e83dfd27b307b2149628f..099d7ecdd5c732ffc8c6ff6370288a2fc4144fa2 100644 --- a/tensorflow/compiler/xla/window_util.h +++ b/tensorflow/compiler/xla/window_util.h @@ -56,6 +56,7 @@ bool HasWindowDilation(const Window& window); bool HasDilation(const Window& window); bool HasWindowReversal(const 
Window& window); +bool AllOrNoneReversed(const Window& window); // Returns true if the given logical dimension is inactive in the sense that it // has window bound 1, no striding and no padding. diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 60d25a6407476cddba77aadd1df2e3939f5e40ac..a37eac7fe441d91aa71e1b6fd7b84099fee2215b 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -193,7 +193,11 @@ message DebugOptions { // - Assuming that operations never produce or consume NaN or +/- Inf. // - Assuming that +0 and -0 are indistinguishable. bool xla_cpu_enable_fast_math = 99; - bool xla_gpu_enable_fast_math = 100; + + // When true we lower the Minimum and Maximum hlos in the GPU backend such + // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if this + // flag is true we don't propagate NaNs through Min and Max. + bool xla_gpu_enable_fast_min_max = 100; // Crashes the program when any kind of verification fails, instead of just // logging the failures. One example is cross checking of convolution results @@ -209,6 +213,9 @@ message DebugOptions { // the host that run models in parallel across multiple devices. int32 xla_force_host_platform_device_count = 102; + // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3). + bool xla_gpu_disable_ptxas_optimizations = 103; + // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. map xla_backend_extra_options = 500; @@ -224,7 +231,7 @@ message ExecutionOptions { // may be faster when using this layout. // // We use a Shape here to accommodate computations that return a tuple. - Shape shape_with_output_layout = 2; + ShapeProto shape_with_output_layout = 2; // Used to seed random-number generators used in this computation. If this is // 0, we generate a seed ourselves. 
@@ -253,7 +260,7 @@ message TransferToClientRequest { // This optional field directs the service to return the literal in this // layout. A shape is used to hold the layout to accommodate tuples. - Shape shape_with_layout = 2; + ShapeProto shape_with_layout = 2; } message TransferToClientResponse { @@ -281,7 +288,7 @@ message TransferToInfeedResponse { message TransferFromOutfeedRequest { // This optional field directs the service to return the literal in this // layout. A shape is used to hold the layout to accommodate tuples. - Shape shape_with_layout = 1; + ShapeProto shape_with_layout = 1; int64 replica_id = 2; DeviceHandle device_handle = 3; @@ -316,12 +323,40 @@ message CreateChannelHandleResponse { } message UnregisterRequest { - GlobalDataHandle data = 1; + repeated GlobalDataHandle data = 1; } message UnregisterResponse { } +message CompileRequest { + // The graph to be compiled. + HloModuleProto computation = 1; + + // Options that affect how XLA compiles code to service this request. + ExecutionOptions execution_options = 2; + + // The layouts of the input arguments. If not set, the default layout will be + // used. Although the real arguments are not needed in compilation, the + // layouts of the arguments can affect the compilation. + repeated ShapeProto input_shape_with_layout = 3; +} + +message CompileResponse { + // The handle to the executable. + ExecutionHandle handle = 1; +} + +message ExecuteRequest { + ExecutionHandle handle = 1; + + // The shape and layout of the arguments must be the same as those of the + // executable's parameters. + repeated GlobalDataHandle arguments = 2; +} + +// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace +// the uses with calls to Compile and Execute. message ExecuteGraphRequest { HloModuleProto computation = 1; repeated GlobalDataHandle arguments = 2; @@ -378,7 +413,7 @@ message LoadDataRequest { string columnio_field = 2; // Individual element shape, excluding rows. 
- Shape element_shape = 3; + ShapeProto element_shape = 3; // Warning: ColumnIO does not support random-access, so use offset with // caution in performance-critical scenarios. @@ -394,7 +429,7 @@ message LoadDataRequest { message LoadDataResponse { GlobalDataHandle data = 1; - Shape data_shape = 2; + ShapeProto data_shape = 2; int64 available_rows = 3; int64 rows_loaded = 4; int64 nanoseconds = 5; @@ -405,7 +440,7 @@ message GetShapeRequest { } message GetShapeResponse { - Shape shape = 1; + ShapeProto shape = 1; } message UnpackRequest { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 73b3589dbf12341ddb3f3e819a550467a7b4d166..85ec83437a10d973687a7fb84285c2e2541a53c7 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -78,28 +78,6 @@ enum PrimitiveType { // Next = 18 } -// Describes the value held inside padding elements. -enum PaddingValue { - INVALID_PAD = 0; - - // Zero padding must be 0-values that correspond to the shape's element type. - ZERO_PAD = 1; - - // One padding must be 1-values that correspond to the shape's element type. - ONE_PAD = 2; - - // "Lowest" padding must be the lowest values in the shape's element type, - // used as padding for operations like max-accumulation. - LOWEST_PAD = 3; - - // "Highest" padding must be the largest values in the shape's element type, - // used as padding for operations like min-accumulation. - HIGHEST_PAD = 4; - - // Unknown padding could be anything; e.g. floating NaNs! - UNKNOWN_PAD = 5; -} - // Describes the padding configuration for Pad operation. The padding amount on // both edges as well as between the elements are specified for each dimension. message PaddingConfig { @@ -123,17 +101,25 @@ message PaddingConfig { // A format specifies the method used by a layout to store an array in memory. 
enum Format { INVALID_FORMAT = 0; - // The default layout, with exactly one storage location per element (ignoring - // padding). + // The default layout, with exactly one storage location per element. DENSE = 1; // A sparsely encoded layout, providing only the index/value pairs of non-zero // elements. SPARSE = 2; } +// Describes a tile used in tiling-based layout. Refer to +// g3doc/layout_with_tiling.md for details about tiling-based layout. +message Tile { + // Number of elements in each dimension of the tile. It's ordered from the + // most major dimension of the tile to the most minor dimension of the tile. + // The dimensions correspond to a suffix of the dimensions of the shape being + // tiled. + repeated int64 dimensions = 1; +} + // A layout describes how the array is placed in (1D) memory space. This -// includes the minor-to-major ordering of dimensions within a shape, as well as -// any padding present in those dimensions. +// includes the minor-to-major ordering of dimensions within a shape. // // Clients must specify the layouts of input Literals to the // computation. Layouts specified in interior operations which take Shapes (for @@ -151,22 +137,31 @@ message Layout { // (slowest varying index). This field is required. repeated int64 minor_to_major = 1; - // The width to which the layout of each dimension is padded up to. If - // present, the size of the padded_dimensions must equal the rank of the - // shape. The padding appears at the end of a dimension, not at the - // beginning. This kind of padding, unlike padding in e.g. convolution, is not - // part of the shape. This field must be unset unless the format is DENSE. - repeated int64 padded_dimensions = 2; + reserved 2; + reserved "padded_dimensions"; - // Describes the values in the padding specified by padded_dimensions. This - // field must be unset unless the format is DENSE. 
- PaddingValue padding_value = 3; + reserved 3; + reserved "padding_value"; // The maximum number of elements that can be stored for SPARSE formats. This // can be used to determine the maximum size in bytes of arrays stored in // memory. This field must be unset unless the format is SPARSE. int64 max_sparse_elements = 5; + // A sequence of tiles, starting from the tile that's applied first to the + // Shape. + // + // TODO(b/119839262): implement tiling in each backend or add Unimplemented + // error. + repeated Tile tiles = 6; + + // Bit size of each element. If the size is bigger than what the element + // type requires, the value is stored in the least significant + // bits and the additional most significant bits are filled with 0's. + // + // TODO(b/119839262): implement in each backend or add Unimplemented error. + int64 element_size_in_bits = 7; + // Important: if any field is added, be sure to modify ShapeUtil::Equal() and // LayoutUtil::Hash appropriately to account for the new field. } @@ -183,7 +178,7 @@ message Layout { // See the XLA documentation for more information on shapes and layouts. // // LINT.IfChange -message Shape { +message ShapeProto { reserved 1; reserved "rank"; @@ -198,7 +193,7 @@ message Shape { repeated int64 dimensions = 3; // For tuples only, the shapes of constitutent shapes in the tuple sequence. - repeated Shape tuple_shapes = 4; + repeated ShapeProto tuple_shapes = 4; // The layout used to back this shape. Layout layout = 5; @@ -212,9 +207,9 @@ message Shape { // Shape of the parameters and output of a computation (like a traditional // function signature). 
-message ProgramShape { - repeated Shape parameters = 1; - Shape result = 2; +message ProgramShapeProto { + repeated ShapeProto parameters = 1; + ShapeProto result = 2; repeated string parameter_names = 3; } @@ -349,7 +344,7 @@ message DeviceAssignmentProto { // Transfers to/from the client are encoded in literal form, and the structure // of the repeated fields is implied by the shape. message LiteralProto { - Shape shape = 1; + ShapeProto shape = 1; repeated bool preds = 2; bytes s8s = 15; bytes u8s = 3; @@ -361,11 +356,13 @@ message LiteralProto { repeated double f64s = 9; repeated float c64s = 12; // Stored as interleaved real, imag floats. repeated LiteralProto tuple_literals = 10; - // The F16s and BF16s are encoded in little endian byte order + // The F16s, BF16s, U16s and S16s are encoded in little endian byte order bytes f16s = 11; bytes bf16s = 13; + bytes u16s = 16; + bytes s16s = 17; repeated int64 sparse_indices = 14; - // Next = 16 + // Next = 18 } message WindowDimension { @@ -548,7 +545,7 @@ message OpSharding { } Type type = 1; // The shape of the sharded tile. - Shape tile_shape = 2; + ShapeProto tile_shape = 2; // The shape of the tile assignment tensor - this must be the same rank as // tile_shape and the product of its dimensions must equal // tile_assignment_devices.size(). 
diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index 2ff97914f862e0ec30fc54602ec5fee2a0a5ebca..2dae746d034a1bf52e84de74dfb0c6e23aaed4d1 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -22,6 +22,7 @@ xla_proto_library( deps = [ "//tensorflow/compiler/tf2xla:host_compute_metadata_proto", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/service:hlo_proto", ], ) @@ -32,20 +33,25 @@ cc_library( "xrt_compilation_cache.cc", "xrt_device.cc", "xrt_state.cc", + "xrt_util.cc", ], hdrs = [ "xrt_compilation_cache.h", "xrt_device.h", "xrt_state.h", + "xrt_util.h", ], deps = [ "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:device_memory_allocator", diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index 9e3d2454d16730c1d1f93cb384db88544380f77e..67f475846e5f16060c1080759b0acb4216c4e72b 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -12,6 +12,7 @@ cc_library( hdrs = ["xrt_state_ops.h"], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -21,7 +22,6 @@ cc_library( "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/client:local_client", 
"//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:hlo_proto", diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 1d4f8d97f2ed8b263878b94b365b7fb5b949b1a2..2ccdf0f02d840600d5e0649c4805e3672d4a1286 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt.pb.h" #include "tensorflow/compiler/xrt/xrt_compilation_cache.h" #include "tensorflow/compiler/xrt/xrt_device.h" +#include "tensorflow/compiler/xrt/xrt_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" @@ -108,19 +109,26 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx, TF_ASSIGN_OR_RETURN(xla::XlaComputation computation, client->LoadSnapshot(computation_proto.hlo_snapshot())); - std::vector argument_layouts( + std::vector argument_layouts( + config.program_shape().parameters_size()); + std::vector argument_layout_ptrs( config.program_shape().parameters_size()); for (int i = 0; i < config.program_shape().parameters_size(); ++i) { - argument_layouts[i] = &config.program_shape().parameters(i); + argument_layouts[i] = xla::Shape(config.program_shape().parameters(i)); + argument_layout_ptrs[i] = &argument_layouts[i]; } xla::ExecutableBuildOptions build_options; build_options.set_device_ordinal(client->default_device_ordinal()); - build_options.set_result_layout(config.program_shape().result()); + build_options.set_result_layout(xla::Shape(config.program_shape().result())); build_options.set_device_allocator(device_ref.backend()->memory_allocator()); + if (config.has_debug_options()) { + 
*build_options.mutable_debug_options() = + BuildXlaDebugOptions(config.debug_options()); + } VLOG(1) << "Building executable"; auto compile_result = - client->Compile(computation, argument_layouts, build_options); + client->Compile(computation, argument_layout_ptrs, build_options); if (!compile_result.ok()) { return compile_result.status(); } @@ -166,10 +174,23 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) { VLOG(1) << "Compiling XLA executable"; return Compile(ctx, computation_proto, program); })); - - Tensor output(DT_INT64, TensorShape({})); - output.scalar()() = uid; - ctx->set_output(0, output); + std::unique_ptr entry; + OP_REQUIRES_OK(ctx, cache->Lookup(uid, &entry)); + + Tensor handle_output(DT_INT64, TensorShape({})); + handle_output.scalar()() = uid; + ctx->set_output(0, handle_output); + + xla::LocalExecutable* executable = entry->get().get_executable(); + xla::ProgramShapeProto program_shape = executable->executable() + ->module() + .config() + .entry_computation_layout() + .ComputeProgramShape() + .ToProto(); + Tensor program_shape_output(DT_STRING, TensorShape({1})); + program_shape_output.vec()(0) = program_shape.SerializeAsString(); + ctx->set_output(1, program_shape_output); } XRTCompileOp::~XRTCompileOp() = default; diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 257b054f16a49f3e14e1d76746c9fe0ba7fa8658..751329eefc33f3372335c805233dafabbf42bf36 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -64,28 +64,36 @@ uint32 GetXLARandomSeed() { return counter.fetch_add(2); } -// Looks up the input `key` in the compilation cache. -Status GetComputationCacheEntry( - XRTCompilationCache* cache, int64 key, - std::unique_ptr* entry) { - TF_RETURN_IF_ERROR(cache->Lookup(key, entry)); - return Status::OK(); -} - // Populates `inputs` with the input tensors to the computation. 
Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm, bool release_inputs, std::vector* input_tuples, std::vector* input_allocations, std::vector* input_pointers) { + std::vector input_uids; OpInputList arg_list; TF_RETURN_IF_ERROR(context->input_list("input_handles", &arg_list)); - input_tuples->resize(arg_list.size()); - input_pointers->resize(arg_list.size()); + // Concatenate all input uids from list of scalars-or-vectors carrying them. for (int i = 0; i < arg_list.size(); ++i) { - TF_RET_CHECK(TensorShapeUtils::IsScalar(arg_list[i].shape())); - int64 input_uid = arg_list[i].scalar()(); + const Tensor& arg = arg_list[i]; + if (TensorShapeUtils::IsScalar(arg.shape())) { + input_uids.push_back(arg.scalar()()); + } else { + TF_RET_CHECK(TensorShapeUtils::IsVector(arg.shape())); + auto arg_vec = arg.vec(); + const int64 num_elts = arg.shape().dim_size(0); + for (int i = 0; i < num_elts; ++i) { + input_uids.push_back(arg_vec(i)); + } + } + } + + // Retrieve allocations for the uids. 
+ input_tuples->resize(input_uids.size()); + input_pointers->resize(input_uids.size()); + for (int i = 0; i < input_uids.size(); ++i) { + const int64 input_uid = input_uids[i]; TF_RETURN_IF_ERROR( XRTTupleAllocation::Lookup(rm, input_uid, &(*input_tuples)[i])); if (release_inputs) { @@ -98,7 +106,7 @@ Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm, XRTTupleAllocation* tuple = (*input_tuples)[i]; input_allocations->emplace_back(tuple->ToShapedBuffer()); } - for (int i = 0; i < arg_list.size(); ++i) { + for (int i = 0; i < input_uids.size(); ++i) { (*input_pointers)[i] = &(*input_allocations)[i]; } return Status::OK(); @@ -220,14 +228,35 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( shaped_buffer, device_ref.backend(), device_ref.device_ordinal(), &output_tuple)); - - Tensor* output_tensor; - TF_RETURN_IF_ERROR( - context->allocate_output(0, TensorShape({}), &output_tensor)); - int64 key; - TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); - output_tensor->scalar()() = key; - + if (config_proto.return_exploded_tuple() && + xla::ShapeUtil::IsTuple(output_tuple->on_device_shape())) { + int64 tuple_element_count = + xla::ShapeUtil::TupleElementCount(output_tuple->on_device_shape()); + Tensor* output_tensor; + TF_RETURN_IF_ERROR(context->allocate_output( + 0, TensorShape({tuple_element_count}), &output_tensor)); + + for (int64 i = 0; i < tuple_element_count; ++i) { + xla::ShapeIndex shape_index; + shape_index.push_back(i); + + XRTTupleAllocation* suballocation; + TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( + output_tuple, shape_index, &suballocation, + /*alias_parent_allocation=*/false)); + int64 key; + TF_RETURN_IF_ERROR(suballocation->Intern(rm, &key)); + output_tensor->vec()(i) = key; + } + output_tuple->Unref(); + } else { + Tensor* output_tensor; + TF_RETURN_IF_ERROR( + context->allocate_output(0, TensorShape({}), &output_tensor)); + int64 key; + 
TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); + output_tensor->scalar()() = key; + } return Status::OK(); } diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc index ffea592491d43788b876a51866dc8a6611e8c734..3258286c10665225aab917107ffa614459c53f3d 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc @@ -87,6 +87,19 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral") .HostMemory("literal"), XRTReadLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") + .Device(DEVICE_XLA_GPU) + .HostMemory("handle") + .HostMemory("literal") + .HostMemory("output_handle"), + XRTWriteLiteralOp); +REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") + .Device(DEVICE_XLA_CPU) + .HostMemory("handle") + .HostMemory("literal") + .HostMemory("output_handle"), + XRTWriteLiteralOp); + REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") .Device(DEVICE_XLA_GPU) .HostMemory("handle") diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 54b06558adcd8ef1f8f1bee52d210d558801afea..26a58fa42d8b730b365b11d2e5608e9945497763 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -393,6 +393,56 @@ class XRTReadLiteralOp : public OpKernel { } }; +// Op that writes a new literal value into device-resident memory. 
+template +class XRTWriteLiteralOp : public OpKernel { + public: + explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~XRTWriteLiteralOp() override = default; + XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete; + XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete; + + void Compute(OpKernelContext* ctx) override { + VLOG(1) << "XRTWriteLiteralOp::Compute"; + + const Tensor& handle_tensor = ctx->input(0); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()), + errors::Internal("computation input should be an int64 scalar")); + int64 allocation_handle = handle_tensor.scalar()(); + + const Tensor& literal_info = ctx->input(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()), + errors::Internal("literal input should be a string scalar")); + xla::LiteralProto literal_proto; + OP_REQUIRES(ctx, + literal_proto.ParseFromString(literal_info.scalar()()), + errors::InvalidArgument( + "Unable to parse allocation input to LiteralProto")); + xla::Literal literal; + OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal)); + + ResourceMgr* rm; + OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm)); + + XRTTupleAllocation* allocation; + OP_REQUIRES_OK( + ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation)); + core::ScopedUnref allocation_unref(allocation); + // We are guaranteed that the underlying device object won't be deleted out + // from under us, while the ScopedRef is live. + typename DeviceAccessor::ScopedRef device_ref; + OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( + ctx, allocation->device_ordinal(), &device_ref)); + OP_REQUIRES_OK(ctx, + allocation->WriteLiteral(device_ref.backend(), literal)); + + Tensor output(DT_INT64, TensorShape({})); + output.scalar()() = allocation_handle; + ctx->set_output(0, output); + } +}; + // Op that discards a handle to device memory. 
template class XRTReleaseAllocationOp : public OpKernel { diff --git a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc index 5cfc8711f9f4b4d54016156dd53471cadb34b581..7b3b50c69559f6003a108fdf6a1325dbdbaa80a6 100644 --- a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc @@ -23,7 +23,12 @@ namespace tensorflow { REGISTER_OP("XRTCompile") .Input("computation: string") .Output("handle: int64") - .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Output("program_shape: string") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->UnknownShapeOfRank(1)); + return Status::OK(); + }) .Doc( R"( Reads a computation proto, compiles it, and places it in the global compilation diff --git a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc index 40ec1b0ba9b336f5b6407c79c8d63e31219f9b84..4f59fccaf120e2358fa49518b030f0b0f42c322e 100644 --- a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc @@ -26,7 +26,16 @@ REGISTER_OP("XRTExecute") .Input("execution_config: string") .Input("input_handles: Ninputs * int64") .Output("output_handle: int64") - .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + std::vector input_handle_shapes; + TF_RETURN_IF_ERROR(c->input("input_handles", &input_handle_shapes)); + for (size_t i = 0; i < input_handle_shapes.size(); ++i) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR( + c->WithRankAtMost(input_handle_shapes[i], 1, &unused)); + } + return tensorflow::shape_inference::ScalarShape(c); + }) .Doc( R"( Runs a previously-compiled computation on a core. 
If diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc index 07d025ce343f229097b557d33ad41bf9612b0696..a3d63106fa14674a9f5887ccfd908ce17dbc6384 100644 --- a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc +++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc @@ -95,6 +95,20 @@ Copies an allocated tuple from device memory and returns it as a literal. 'literal' is a serialized xla::LiteralProto proto. )"); +REGISTER_OP("XRTWriteLiteral") + .Input("handle: int64") + .Input("literal: string") + .Output("output_handle: int64") + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc( + R"( +Copies the input literal into the device memory pointed to by handle. +Returns the handle itself. + +'handle' is the id returned from the Op that produced the on-device allocation. +'literal' is a serialized xla::LiteralProto proto to be written to device memory. +)"); + REGISTER_OP("XRTReadLiteralAndRelease") .Input("handle: int64") .Output("literal: string") diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD index b6dcfc4eb96316b5dad95a65b04d0ae69e4485f6..be44a3474acdeb9905c1d21b932fa0dd10b5a212 100644 --- a/tensorflow/compiler/xrt/tests/BUILD +++ b/tensorflow/compiler/xrt/tests/BUILD @@ -29,8 +29,11 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt:xrt_server", "//tensorflow/compiler/xrt/cc:xrt_ops", @@ -49,7 +52,10 @@ tf_cc_test( name = "raw_api_test_cpu", size = "medium", srcs = [], - args = ["--xla_test_device=XLA_CPU"], + args = [ + "--xla_test_device=XLA_CPU", + 
"--xla_platform=CPU", + ], deps = [ ":raw_api_test_lib", "//tensorflow/compiler/jit:xla_cpu_device", @@ -60,7 +66,10 @@ tf_cuda_cc_test( name = "raw_api_test_gpu", size = "medium", srcs = [], - args = ["--xla_test_device=XLA_GPU"], + args = [ + "--xla_test_device=XLA_GPU", + "--xla_platform=GPU", + ], tags = tf_cuda_tests_tags(), deps = [ ":raw_api_test_lib", diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index f590fbf0d9d85e6e8b041f6719ab6a14ec9e2191..abaa17e50e3f5e47a45f5a8a45fa2090d3efee39 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -22,10 +22,13 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" @@ -43,6 +46,7 @@ namespace tensorflow { namespace { string* xla_test_device_ptr; // initial value set in main() +string* xla_platform_ptr; // initial value set in main() string DeviceFromFlag() { string xla_test_device = *xla_test_device_ptr; @@ -85,13 +89,20 @@ xla::LiteralProto FloatVector(absl::Span v) { return array.ToProto(); } +xla::LiteralProto FloatMatrix( + std::initializer_list> v, + const xla::Layout& layout) { + auto array = xla::LiteralUtil::CreateR2WithLayout(v, layout); + return array.ToProto(); +} + bool CompareLiteralProtos(const xla::LiteralProto& a, const xla::LiteralProto& b) { 
auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie(); auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); bool equal = l_a == l_b; if (!equal) { - LOG(INFO) << "LiteralProtos don't match " << a.DebugString() + LOG(INFO) << "LiteralProtos don't match: " << a.DebugString() << " != " << b.DebugString(); } return equal; @@ -128,6 +139,31 @@ xla::XlaComputation AddAndScale() { return builder.Build().ValueOrDie(); } +xla::XlaComputation Dot() { + xla::XlaBuilder builder("Dot"); + auto p0 = xla::Parameter( + &builder, 0, + xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}), "P0"); + auto p1 = xla::Parameter( + &builder, 1, + xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}), "P1"); + xla::DotDimensionNumbers ddn; + ddn.add_lhs_contracting_dimensions(1); + ddn.add_rhs_contracting_dimensions(0); + xla::DotGeneral(p0, p1, ddn); + return builder.Build().ValueOrDie(); +} + +xla::XlaComputation AddS64() { + xla::XlaBuilder builder("AddS64"); + auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::S64, {}), + "P0"); + auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::S64, {}), + "P1"); + xla::Add(p0, p1); + return builder.Build().ValueOrDie(); +} + xla::XlaComputation AddAndTuple() { xla::XlaBuilder builder("AddAndTuple"); auto p0 = xla::Parameter(&builder, 0, @@ -139,12 +175,96 @@ xla::XlaComputation AddAndTuple() { return builder.Build().ValueOrDie(); } +xla::XlaComputation AddAndSubTuple() { + xla::XlaBuilder builder("AddAndSubTuple"); + auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {}), + "P0"); + auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {}), + "P1"); + auto sum = xla::Add(p0, p1); + auto sub = xla::Sub(p0, p1); + xla::Tuple(&builder, {sum, sub}); + return builder.Build().ValueOrDie(); +} + void StoreComputationSnapshot(const xla::XlaComputation& computation, xla::HloSnapshot* dst) { auto snapshot = computation.Snapshot().ValueOrDie(); *dst = 
*snapshot; } +xla::ProgramShape XlaCompiledProgramShape( + const xla::XlaComputation& computation, + const xla::ProgramShape& input_program_shape) { + se::Platform* platform = + xla::PlatformUtil::GetPlatform(*xla_platform_ptr).ValueOrDie(); + xla::LocalClient* client = + xla::ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie(); + xla::ExecutableBuildOptions exec_options; + exec_options.set_result_layout(input_program_shape.result()); + std::vector parameters_shapes; + for (int64 i = 0; i < input_program_shape.parameters_size(); ++i) { + parameters_shapes.push_back(&input_program_shape.parameters(i)); + } + auto local_executable = + client->Compile(computation, parameters_shapes, exec_options) + .ValueOrDie(); + return local_executable->executable() + ->module() + .entry_computation() + ->ComputeProgramShape(); +} + +TEST(RawApiTest, AllocAndRewrite) { + xrt::XLAAllocation alloc; + alloc.set_device_ordinal(0); + *alloc.mutable_value() = + xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto(); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto value = + ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString()); + auto handle = ops::XRTAllocate(root, value); + auto read_back = ops::XRTReadLiteral(root, handle); + TF_ASSERT_OK(root.status()); + + tensorflow::ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, handle}, &outputs)); + EXPECT_EQ(outputs.size(), 2); + + int64 allocation_handle = outputs[1].scalar()(); + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); + outputs.clear(); + + xla::LiteralProto new_literal = + xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto(); + auto new_value = ops::Const(root.WithDevice("/device:CPU:0"), + new_literal.SerializeAsString()); + auto write_op = + ops::XRTWriteLiteral(root, Input(allocation_handle), new_value); + 
TF_ASSERT_OK(root.status()); + TF_EXPECT_OK(session.Run({write_op}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + EXPECT_EQ(allocation_handle, outputs[0].scalar()()); + outputs.clear(); + + auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle)); + TF_EXPECT_OK(session.Run({read_after_write}, &outputs)); + EXPECT_EQ(outputs.size(), 1); + + xla::LiteralProto new_response; + EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response)); + + auto release = + ops::XRTReleaseAllocationHandle(root, Input(allocation_handle)); + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release}, + &outputs)); +} + TEST(RawApiTest, ReadAndWriteState) { xrt::XLAAllocation alloc; alloc.set_device_ordinal(0); @@ -317,9 +437,12 @@ TEST(RawApiTest, CompileAndExecute) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {2}); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot()); xrt::XRTExecutionConfig e; @@ -338,27 +461,207 @@ TEST(RawApiTest, CompileAndExecute) { auto p1_value = ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); auto p1_handle = ops::XRTAllocate(root, p1_value); - auto result = ops::XRTExecute(root, c_handle, e_config, + auto result = ops::XRTExecute(root, c_handle.handle, e_config, {Output(p0_handle), Output(p1_handle)}); auto read_back = ops::XRTReadLiteralAndRelease(root, result); TF_ASSERT_OK(root.status()); ClientSession 
session(root); std::vector outputs; - TF_EXPECT_OK(session.Run({read_back}, &outputs)); + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); + + xla::ProgramShapeProto program_shape; + EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); + EXPECT_EQ(program_shape.parameters_size(), 2); +} + +TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { + xrt::XLAAllocation p0; + p0.set_device_ordinal(0); + *p0.mutable_value() = FloatVector({1.0f, 2.0f}); + xrt::XLAAllocation p1; + p1.set_device_ordinal(0); + *p1.mutable_value() = FloatVector({8.0f, 5.0f}); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = + ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = + ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto packed_args = 
ops::Stack(root.WithDevice("/device:CPU:0"), + {Output(p0_handle), Output(p1_handle)}); + auto result = + ops::XRTExecute(root, c_handle.handle, e_config, {Output(packed_args)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); xla::LiteralProto response; EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); + + xla::ProgramShapeProto program_shape; + EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); + EXPECT_EQ(program_shape.parameters_size(), 2); +} + +TEST(RawApiTest, CompileWithXlaReturnShapes) { + xla::XlaBuilder builder("XrtXlaShapes"); + auto input_shape = xla::ShapeUtil::MakeShape(xla::BF16, {32, 3, 128, 128}); + auto kernel_shape = xla::ShapeUtil::MakeShape(xla::BF16, {3, 3, 5, 5}); + // Clear layouts to signal XLA we are ready to get whatever are coming out of + // the compilation process. + xla::LayoutUtil::ClearLayout(&input_shape); + xla::LayoutUtil::ClearLayout(&kernel_shape); + auto param_shape = + xla::ShapeUtil::MakeTupleShape({input_shape, kernel_shape}); + auto param = xla::Parameter(&builder, 0, param_shape, "param"); + auto input = xla::GetTupleElement(param, 0); + auto kernel = xla::GetTupleElement(param, 1); + xla::Conv(input, kernel, {1, 1}, xla::Padding::kSame); + TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation xla_computation, builder.Build()); + + auto result_shape = xla_computation.GetProgramShape().ValueOrDie().result(); + // Clear the result shape layout to tell XLA we are accepting whatever are + // coming out of the compilation process. 
+ xla::LayoutUtil::ClearLayout(&result_shape); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = param_shape.ToProto(); + *shapes->mutable_result() = result_shape.ToProto(); + StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot()); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto release = ops::XRTReleaseCompilationHandle(root, c_handle.handle); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), + {c_handle.program_shape}, {release}, &outputs)); + + xla::ProgramShapeProto program_shape_proto; + EXPECT_TRUE(program_shape_proto.ParseFromString(outputs[0].vec()(0))); + xla::ProgramShape program_shape(program_shape_proto); + EXPECT_EQ(program_shape.parameters_size(), 1); + + VLOG(2) << "Param: " + << xla::ShapeUtil::HumanStringWithLayout(program_shape.parameters(0)); + VLOG(2) << "Result: " + << xla::ShapeUtil::HumanStringWithLayout(program_shape.result()); + + xla::ProgramShape xla_program_shape = + XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes)); + EXPECT_TRUE(xla::LayoutUtil::Equal( + xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(), + xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0}) + .layout())); + EXPECT_TRUE(xla::LayoutUtil::Equal( + xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(), + xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1}) + .layout())); + EXPECT_TRUE(xla::LayoutUtil::Equal(program_shape.result().layout(), + xla_program_shape.result().layout())); +} + +TEST(RawApiTest, DotGeneralWithLayoutTest) { + auto layout = xla::LayoutUtil::MakeLayout({0, 1}); + + xrt::XLAAllocation p0; + 
p0.set_device_ordinal(0); + *p0.mutable_value() = FloatMatrix({{1.0f, 2.0f}, {3.0f, 4.0f}}, layout); + xrt::XLAAllocation p1; + p1.set_device_ordinal(0); + *p1.mutable_value() = FloatMatrix({{8.0f}, {5.0f}}, layout); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto(); + StoreComputationSnapshot(Dot(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = + ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = + ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = + xla::LiteralUtil::CreateR2WithLayout({{18.0f}, {44.0f}}, layout); + + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } TEST(RawApiTest, 
CompileAndExecuteZeroArg) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}); + *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); xrt::XRTExecutionConfig e; e.set_release_input_handles(true); @@ -371,7 +674,7 @@ TEST(RawApiTest, CompileAndExecuteZeroArg) { auto computation = ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); auto c_handle = ops::XRTCompile(root, computation); - auto result = ops::XRTExecute(root, c_handle, e_config, + auto result = ops::XRTExecute(root, c_handle.handle, e_config, std::initializer_list({})); auto read_back = ops::XRTReadLiteralAndRelease(root, result); TF_ASSERT_OK(root.status()); @@ -398,10 +701,13 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}); - *shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::F32, {2})}); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})}) + .ToProto(); StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot()); xrt::XRTExecutionConfig e; @@ -420,7 +726,7 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { auto p1_value = ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); auto p1_handle = ops::XRTAllocate(root, p1_value); - auto result = ops::XRTExecute(root, c_handle, e_config, + auto result = ops::XRTExecute(root, c_handle.handle, e_config, {Output(p0_handle), Output(p1_handle)}); auto read_back = 
ops::XRTReadLiteralAndRelease(root, result); TF_ASSERT_OK(root.status()); @@ -437,15 +743,160 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } +TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) { + xrt::XLAAllocation p0; + p0.set_device_ordinal(0); + *p0.mutable_value() = xla::LiteralUtil::CreateR0(12.0f).ToProto(); + + xrt::XLAAllocation p1; + p1.set_device_ordinal(0); + *p1.mutable_value() = xla::LiteralUtil::CreateR0(3.0f).ToProto(); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}), + xla::ShapeUtil::MakeShape(xla::F32, {})}) + .ToProto(); + StoreComputationSnapshot(AddAndSubTuple(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + e.set_return_exploded_tuple(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = + ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = + ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle)}); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({result}, 
&outputs)); + EXPECT_EQ(outputs.size(), 1); + + auto handles_vec = outputs.front().vec(); + EXPECT_EQ(handles_vec.size(), 2); + + const float kResults[2] = {15.0f, 9.0f}; + for (int64 i = 0; i < handles_vec.size(); ++i) { + auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(i))); + std::vector voutputs; + TF_EXPECT_OK(session.Run({read_back}, &voutputs)); + EXPECT_EQ(voutputs.size(), 1); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(voutputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR0(kResults[i]); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); + } +} + +TEST(RawApiTest, LeakCompilationReference) { + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})}) + .ToProto(); + StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot()); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({c_handle.handle}, &outputs)); +} + +TEST(RawApiTest, CompileAndExecuteWithS64Argument) { + xrt::XLAAllocation p0; + p0.set_device_ordinal(0); + *p0.mutable_value() = xla::LiteralUtil::CreateR0(11031965).ToProto(); + xrt::XLAAllocation p1; + p1.set_device_ordinal(0); + *p1.mutable_value() = xla::LiteralUtil::CreateR0(4091934).ToProto(); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = 
xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto(); + *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto(); + StoreComputationSnapshot(AddS64(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = + ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = + ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR0(15123899); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); + + xla::ProgramShapeProto program_shape; + EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); + EXPECT_EQ(program_shape.parameters_size(), 2); + EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType( + xla::Shape(program_shape.result()), xla::S64)); +} + } // namespace } // namespace tensorflow int main(int argc, char** argv) { tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU"); + tensorflow::xla_platform_ptr = new tensorflow::string("CPU"); 
std::vector flag_list = { tensorflow::Flag("xla_test_device", tensorflow::xla_test_device_ptr, "Tensorflow device type to use for test, e.g., XLA_CPU"), + tensorflow::Flag("xla_platform", tensorflow::xla_platform_ptr, + "The XLA platform to select for the device"), }; tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index 5678f0905ff5b8956e0811026e7450acba8815e9..378bb9246f27b8106310d565435404d7ac260a87 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -3,9 +3,28 @@ syntax = "proto3"; package xrt; import "tensorflow/compiler/tf2xla/host_compute_metadata.proto"; +import "tensorflow/compiler/xla/xla.proto"; import "tensorflow/compiler/xla/xla_data.proto"; import "tensorflow/compiler/xla/service/hlo.proto"; +message DeviceAssignment { + message ComputationDevice { + message DeviceMeshCoordinates { + // The mesh coordinates for the device. Usually (X, Y, Core), in the order + // in which they are returned in the TopologyProto. + // X = value(0) + // Y = value(1) + // Core = value(2) + repeated int32 value = 1; + } + // As many replicas as there are in the replicated computation. + repeated DeviceMeshCoordinates replica_devices = 1; + } + // As many ComputationDevice as many there are computations (number + // of cores per replica). + repeated ComputationDevice computation_devices = 1; +} + // Options for an XLA compilation. message XLAComputationConfig { // The number of replicas the computation will be run on. If this is @@ -18,11 +37,18 @@ message XLAComputationConfig { tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3; // The arg/result shapes for the whole computation. - xla.ProgramShape program_shape = 4; + xla.ProgramShapeProto program_shape = 4; // The arg/result shapes for each core of a model-parallel // computation. 
per_core_args_and_result_shapes is optional for a // single-core computation. - repeated xla.ProgramShape per_core_program_shape = 5; + repeated xla.ProgramShapeProto per_core_program_shape = 5; + // Describes how replicated computation instances should be assigned to + // devices. There are num_cores_per_replica computations, and each one will be + // sent and executed to the set of replica device numbers described in the + // DeviceAssignment proto. + DeviceAssignment device_assignment = 6; + // The debugging options to be passed to the XLA compilation process. + xla.DebugOptions debug_options = 7; } // Options and XLA computation for a compilation. @@ -75,4 +101,8 @@ message XRTExecutionConfig { bool release_input_handles = 5; // If true, release the handle to the computation after running. bool release_compilation_handle = 6; + // If set to true, and the result shape is a tuple, then instead of returning + // a single tuple allocation the execution will return a vector of + // allocations, one for each of the first-level elements of the result tuple. + bool return_exploded_tuple = 7; } diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.cc b/tensorflow/compiler/xrt/xrt_compilation_cache.cc index 4844c7fb7106862dd42b3b3d07245350c9d2383c..d1405eae468492748ae88d842334a922dce272c6 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.cc +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.cc @@ -18,9 +18,19 @@ limitations under the License. 
#include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace { + +int64 get_uid() { + uint64 unsigned_rand = random::New64() & INT64_MAX; + return static_cast(unsigned_rand); +} + +} // namespace + const char* kXRTCompilationCacheResourceName = "xrt_compilation_cache"; XRTCompilationCache::EntryRefImpl::EntryRefImpl(XRTCompilationCache* parent, @@ -46,12 +56,17 @@ XRTCompilationCache::XRTCompilationCache(int max_number_of_entries) XRTCompilationCache::~XRTCompilationCache() { VLOG(1) << "XRTCompilationCache::~XRTCompilationCache()"; + // A buggy client may be holding onto a reference, or a client might have + // crashed while holding onto a reference. In either case, discard all + // outstanding client references to avoid leaking storage. + for (const auto& entry : entries_by_uid_) { + while (!entry.second->RefCountIsOne()) { + entry.second->Unref(); + } + } while (!entries_by_last_use_.empty()) { MarkOldestEntryForEviction(); } - // By the time the cache is deleted all reference holders should have already - // been deleted, since they were holding references to the cache. So all - // entries should be gone at this point. CHECK_EQ(cache_.size(), 0); CHECK_EQ(entries_by_uid_.size(), 0); CHECK_EQ(cache_entries_, 0); @@ -148,7 +163,7 @@ XRTCompilationCache::CompiledSubgraph* XRTCompilationCache::InitializeEntry( CompiledSubgraph* entry = new CompiledSubgraph(); entry->parent = this; entry->key = key; - entry->uid = next_uid_++; + entry->uid = get_uid(); // Add the entry to the cache. Once the computation has been compiled, // UpdateEntryAfterCompilation will be called to potentially mark old entries // that don't fit any more for eviction. 
diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.h b/tensorflow/compiler/xrt/xrt_compilation_cache.h index c505299a454506e2136e36fb26833c28ed0d47bc..c43d0fc47873abdc82ee937c155bebc346a05f17 100644 --- a/tensorflow/compiler/xrt/xrt_compilation_cache.h +++ b/tensorflow/compiler/xrt/xrt_compilation_cache.h @@ -211,8 +211,6 @@ class XRTCompilationCache : public ResourceBase { const int max_cache_entries_; mutable absl::Mutex mu_; - // The uid to assign to the next new entry created. - int64 next_uid_ GUARDED_BY(mu_) = 0; // The total number of entries that are stored and not marked for eviction. int cache_entries_ GUARDED_BY(mu_) = 0; // The total number of entries that are marked for eviction. diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index d05a1e7dcbff440e0daf03bd25535c26d82b6a0b..31603e044d17baa3ae0ae583f61837811bb12495 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xrt/xrt_state.h" #include +#include #include #include #include @@ -33,6 +34,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -40,14 +43,44 @@ namespace tensorflow { namespace { +class BufferAllocStats { + public: + struct Stats { + int64 count = 0; + int64 size = 0; + }; + + Stats ReportAlloc(int64 device, int64 msize) { + mutex_lock lock(lock_); + Stats* device_stats = &stats_[device]; + device_stats->count += 1; + device_stats->size += msize; + return *device_stats; + } + + Stats ReportFree(int64 device, int64 msize) { + mutex_lock lock(lock_); + Stats* device_stats = &stats_[device]; + device_stats->count -= 1; + device_stats->size -= msize; + return *device_stats; + } + + private: + mutable mutex lock_; + std::map stats_; +}; + const char* kTupleContainer = "tuples"; -// Counter used to assign unique handles. -mutex _uid_mutex(tensorflow::LINKER_INITIALIZED); -int64 _uid GUARDED_BY(_uid_mutex) = 0; int64 get_uid() { - mutex_lock l(_uid_mutex); - return _uid++; + uint64 unsigned_rand = random::New64() & INT64_MAX; + return static_cast(unsigned_rand); +} + +BufferAllocStats* GetAllocStats() { + static BufferAllocStats* stats = new BufferAllocStats(); + return stats; } Status AllocateScopedShapedBuffer( @@ -67,6 +100,9 @@ Status AllocateScopedShapedBuffer( // requests the host-shape sub-buffer at index i, that will correspond to the // right device-shape sub-buffer at the same index. xla::Shape on_device_shape = transfer_manager->HostShapeToDeviceShape(shape); + VLOG(3) << "Allocating literal buffer: host_shape=" + << xla::ShapeUtil::HumanStringWithLayout(shape) << " device_shape=" + << xla::ShapeUtil::HumanStringWithLayout(on_device_shape); // The ScopedShapedBuffer frees the buffers that have so far been allocated if // it goes out of scope. 
That's useful if we return early as the result of an @@ -99,9 +135,19 @@ XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation, xla::DeviceMemoryAllocator* allocator) : allocation_(allocation), device_ordinal_(device_ordinal), - allocator_(allocator) {} + allocator_(allocator) { + if (VLOG_IS_ON(2)) { + auto stats = + GetAllocStats()->ReportAlloc(device_ordinal_, allocation_.size()); + LOG(INFO) << "XRT Allocation Stats: device=" << device_ordinal_ + << " count=" << stats.count << " size=" << stats.size; + } +} XRTBufferAllocation::~XRTBufferAllocation() { + if (VLOG_IS_ON(2)) { + GetAllocStats()->ReportFree(device_ordinal_, allocation_.size()); + } // Deallocate explicitly allows allocation_ to be null. Status s = allocator_->Deallocate(device_ordinal_, allocation_); // Nothing to do but check fail here if memory datastructures are corrupted. @@ -182,6 +228,20 @@ Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, return Status::OK(); } +Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend, + const xla::Literal& literal) { + if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) { + return errors::InvalidArgument( + "New literal shape not matching the existing one: literal=", + xla::ShapeUtil::HumanStringWithLayout(literal.shape()), + " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape())); + } + auto transfer_manager = backend->transfer_manager(); + TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal())); + return transfer_manager->TransferLiteralToDevice(stream.get(), literal, + ToShapedBuffer()); +} + void XRTTupleAllocation::DiscardAllocation( const xla::ShapeIndex& buffer_index) { buffers_.element(buffer_index)->DiscardAllocation(); diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 73b5584e38f781343fe6793af7ad28232fbfc184..3664c0cd4e6ad26945ae1012208fdb006164a066 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ 
b/tensorflow/compiler/xrt/xrt_state.h @@ -137,6 +137,9 @@ class XRTTupleAllocation : public ResourceBase { Status ToLiteral(xla::Backend* backend, int device_ordinal, xla::Literal* literal); + // Write a new literal value to the allocation. + Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal); + // True if none of the buffers in the allocation are aliased by any other live // handle. bool IsExclusiveOwner(); diff --git a/tensorflow/compiler/xrt/xrt_util.cc b/tensorflow/compiler/xrt/xrt_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..3ef8bedc7324696cd255c72a851f0f2410e03848 --- /dev/null +++ b/tensorflow/compiler/xrt/xrt_util.cc @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xrt/xrt_util.h" + +#include +#include + +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace { + +bool DebugOptionsPassThroughEnabled() { + const char* env = getenv("TF_XLA_DEBUG_OPTIONS_PASSTHROUGH"); + bool enabled = + env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0); + if (enabled) { + LOG(WARNING) << "Passing through XLA debug options!"; + } else { + LOG(WARNING) << "TF_XLA_DEBUG_OPTIONS_PASSTHROUGH not set, not all options " + "will be retained"; + } + return enabled; +} + +string SafeDebugPath(const string& path) { + if (path.empty() || path.compare(0, 5, "gs://") == 0 || + path.compare(0, 11, "bigstore://") == 0) { + return path; + } + LOG(WARNING) << "Invalid config path (will be dropped): " << path; + return string(); +} + +} // namespace + +xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) { + static const bool options_passthrough = DebugOptionsPassThroughEnabled(); + if (options_passthrough) { + return ref_options; + } + xla::DebugOptions options = xla::GetDebugOptionsFromFlags(); + options.set_xla_generate_hlo_text_to( + SafeDebugPath(ref_options.xla_generate_hlo_text_to())); + options.set_xla_dump_optimized_hlo_proto_to( + SafeDebugPath(ref_options.xla_dump_optimized_hlo_proto_to())); + options.set_xla_dump_computations_to( + SafeDebugPath(ref_options.xla_dump_computations_to())); + options.set_xla_dump_executions_to( + SafeDebugPath(ref_options.xla_dump_executions_to())); + for (auto& pass : ref_options.xla_disable_hlo_passes()) { + options.add_xla_disable_hlo_passes(pass); + } + options.set_xla_dump_unoptimized_hlo_proto_to( + SafeDebugPath(ref_options.xla_dump_unoptimized_hlo_proto_to())); + options.set_xla_dump_per_pass_hlo_proto_to( + 
SafeDebugPath(ref_options.xla_dump_per_pass_hlo_proto_to())); + return options; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/xrt/xrt_util.h b/tensorflow/compiler/xrt/xrt_util.h new file mode 100644 index 0000000000000000000000000000000000000000..d9c05a7f3406313f99ae214d67b34e8e7de8be3e --- /dev/null +++ b/tensorflow/compiler/xrt/xrt_util.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utility functions in support of the XRT API. + +#ifndef TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ +#define TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ + +#include "tensorflow/compiler/xla/xla.pb.h" + +namespace tensorflow { + +// Filters the debug options provided as argument according to the value of the +// TF_XLA_DEBUG_OPTIONS_PASSTHROUGH environment variable. If such variable is +// set to "1" or "true", the debug options will be returned as is. Otherwise +// only a subset of them will be set in the returned ones, and all the paths +// contained in it, will be limited to gs:// and bigstore:// ones. 
+xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_ diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index fa06d351d4e64bfc2fc5e64c81c810185600000a..832db0f4ab46911e067d17b4a125706c276cf798 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -72,7 +72,6 @@ py_library( "//tensorflow/contrib/metrics:metrics_py", "//tensorflow/contrib/mixed_precision:mixed_precision", "//tensorflow/contrib/model_pruning", - "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", @@ -113,22 +112,52 @@ py_library( "//tensorflow/python:util", "//tensorflow/python/estimator:estimator_py", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], "//tensorflow:linux_s390x": [], "//tensorflow:windows": [], + "//tensorflow:no_kafka_support": [], "//conditions:default": [ - "//tensorflow/contrib/bigtable", - "//tensorflow/contrib/cloud:cloud_py", - "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols "//tensorflow/contrib/kafka", + ], + }) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//tensorflow:no_aws_support": [], + "//conditions:default": [ "//tensorflow/contrib/kinesis", + ], + }) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//conditions:default": [ + "//tensorflow/contrib/fused_conv:fused_conv_py", "//tensorflow/contrib/tensorrt:init_py", "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", ], }) + select({ - "//tensorflow:with_ignite_support": [ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + 
"//tensorflow:windows": [], + "//tensorflow:no_gcp_support": [], + "//conditions:default": [ + "//tensorflow/contrib/bigtable", + "//tensorflow/contrib/cloud:cloud_py", + ], + }) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//tensorflow:no_ignite_support": [], + "//conditions:default": [ "//tensorflow/contrib/ignite", ], - "//conditions:default": [], }), ) @@ -149,17 +178,27 @@ cc_library( "//tensorflow/contrib/tensor_forest:stats_ops_kernels", "//tensorflow/contrib/tensor_forest:tensor_forest_kernels", "//tensorflow/contrib/text:all_kernels", - ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([ - "//tensorflow/contrib/nccl:nccl_kernels", - ]) + select({ + ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], "//tensorflow:linux_s390x": [], "//tensorflow:windows": [], + "//tensorflow:no_kafka_support": [], "//conditions:default": [ "//tensorflow/contrib/kafka:dataset_kernels", + ], + }) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//tensorflow:no_aws_support": [], + "//conditions:default": [ "//tensorflow/contrib/kinesis:dataset_kernels", - "//tensorflow/contrib/tensorrt:trt_engine_op_kernel", ], - }), + }) + if_not_windows([ + "//tensorflow/contrib/tensorrt:trt_engine_op_kernel", + ]), ) cc_library( @@ -173,7 +212,6 @@ cc_library( "//tensorflow/contrib/hadoop:dataset_ops_op_lib", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", - "//tensorflow/contrib/nccl:nccl_ops_op_lib", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib", "//tensorflow/contrib/rnn:all_ops", "//tensorflow/contrib/seq2seq:beam_search_ops_op_lib", @@ -183,17 +221,33 @@ cc_library( 
"//tensorflow/contrib/text:all_ops", "//tensorflow/contrib/tpu:all_ops", ] + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], "//tensorflow:linux_s390x": [], "//tensorflow:windows": [], + "//tensorflow:no_kafka_support": [], "//conditions:default": [ "//tensorflow/contrib/kafka:dataset_ops_op_lib", - "//tensorflow/contrib/kinesis:dataset_ops_op_lib", - "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib", ], }) + select({ - "//tensorflow:with_ignite_support": [ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//tensorflow:no_aws_support": [], + "//conditions:default": [ + "//tensorflow/contrib/kinesis:dataset_ops_op_lib", + ], + }) + if_not_windows([ + "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib", + ]) + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//tensorflow:no_ignite_support": [], + "//conditions:default": [ "//tensorflow/contrib/ignite:dataset_ops_op_lib", ], - "//conditions:default": [], }), ) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index e71b0e0ae33f9c2dd48643e557447372bc67b3e3..4f1a2a5693235183c8f486817b82c8c81fa389ec 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -21,14 +21,6 @@ from __future__ import print_function import os -from tensorflow.python.tools import component_api_helper -component_api_helper.package_hook( - parent_package_str=( - "tensorflow.contrib"), - child_package_str=( - "tensorflow_estimator.contrib.estimator")) -del component_api_helper - # Add projects here, they will show up under tf.contrib. 
from tensorflow.contrib import autograph from tensorflow.contrib import batching @@ -70,7 +62,6 @@ from tensorflow.contrib import memory_stats from tensorflow.contrib import metrics from tensorflow.contrib import mixed_precision from tensorflow.contrib import model_pruning -from tensorflow.contrib import nccl from tensorflow.contrib import nn from tensorflow.contrib import opt from tensorflow.contrib import periodic_resample diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index 881808a98bfd688c2efaa8beb5b8f11a2527fee8..f6c6560c1c354ed8a36b98b1f564835eb9958e55 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -9,8 +9,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "tf_py_test") - py_library( name = "all_reduce_py", srcs = ["__init__.py"], @@ -29,29 +27,6 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/nccl:nccl_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - ], -) - -tf_py_test( - name = "all_reduce_test", - srcs = ["python/all_reduce_test.py"], - additional_deps = [ - ":all_reduce", - "//third_party/py/numpy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:client_testlib", - "//tensorflow/python:platform", - "//tensorflow/python:platform_test", - "//tensorflow/python:state_ops", + "//tensorflow/python/distribute:all_reduce", ], ) diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 3b539734a236804026826a8117d9c668c0dd089a..238cdaf8a79812df3f043d9d070bbcfd443f6e1e 100644 --- 
a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -18,842 +18,5 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import math - -from tensorflow.contrib import nccl -from tensorflow.python.framework import device as device_lib -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops - - -def _flatten_tensors(tensors): - """Check tensors for isomorphism and flatten. - - Args: - tensors: list of T `tf.Tensor` which must all have the same shape. - - Returns: - tensors: a list of T `tf.Tensor` which are flattened (1D) views of tensors - shape: the original shape of each element of input tensors - - Raises: - ValueError: tensors are empty or non-isomorphic or have unknown shape. - """ - if not tensors: - raise ValueError("tensors cannot be empty") - shape = tensors[0].shape - for tensor in tensors: - shape = shape.merge_with(tensor.shape) - if not shape.is_fully_defined(): - raise ValueError("Tensors must have statically known shape.") - if len(shape) != 1: - reshaped = [] - for t in tensors: - with ops.colocate_with(t): - reshaped.append(array_ops.reshape(t, [-1])) - tensors = reshaped - return tensors, shape - - -def _reshape_tensors(tensors, shape): - """Reshape tensors flattened by _flatten_tensors. - - Args: - tensors: list of T `tf.Tensor` of identical length 1D tensors. - shape: list of integers describing the desired shape. Product of - the elements must equal the length of each tensor. - - Returns: - list of T `tf.Tensor` which are the reshaped inputs. - """ - reshaped = [] - for t in tensors: - with ops.colocate_with(t): - reshaped.append(array_ops.reshape(t, shape)) - return reshaped - - -def _padded_split(tensor, pieces): - """Like split for 1D tensors but pads-out case where len % pieces != 0. 
- - Args: - tensor: T `tf.Tensor` that must be 1D. - pieces: a positive integer specifying the number of pieces into which - tensor should be split. - - Returns: - list of T `tf.Tensor` of length pieces, which hold the values of - thin input tensor, in order. The final tensor may - be zero-padded on the end to make its size equal to those of all - of the other tensors. - - Raises: - ValueError: The input tensor is not 1D. - """ - shape = tensor.shape - if 1 != len(shape): - raise ValueError("input tensor must be 1D") - tensor_len = shape[0].value - with ops.colocate_with(tensor): - if tensor_len % pieces != 0: - # pad to an even length - chunk_size = 1 + tensor_len // pieces - if pieces > tensor_len: - # This is an edge case that should not come up in practice, - # i.e. a different reduction algorithm would be better, - # but we'll make it work just for completeness. - pad_len = pieces - tensor_len - extended_whole = array_ops.concat( - [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0) - parts = array_ops.split(extended_whole, pieces) - return parts, pad_len - elif (pieces - 1) * chunk_size >= tensor_len: - # Another edge case of limited real interest. - pad_len = (pieces * chunk_size) % tensor_len - extended_whole = array_ops.concat( - [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0) - parts = array_ops.split(extended_whole, pieces) - return parts, pad_len - else: - last_chunk_size = tensor_len - (pieces - 1) * chunk_size - pad_len = chunk_size - last_chunk_size - piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size] - parts = array_ops.split(tensor, piece_lens) - parts[-1] = array_ops.concat( - [parts[-1], array_ops.zeros([pad_len], dtype=tensor.dtype)], 0) - return parts, pad_len - else: - return array_ops.split(tensor, pieces), 0 - - -def _strip_padding(tensors, pad_len): - """Strip the suffix padding added by _padded_split. - - Args: - tensors: list of T `tf.Tensor` of identical length 1D tensors. 
- pad_len: number of elements to be stripped from the end of each tensor. - - Returns: - list of T `tf.Tensor` which are the stripped inputs. - - Raises: - ValueError: tensors must be a non-empty list of 1D tensors, and - each must be longer than pad_len. - """ - if not tensors: - raise ValueError("tensors cannot be empty") - shape = tensors[0].shape - if len(shape) > 1: - raise ValueError("tensors must be 1D") - prefix_len = int(shape[0] - pad_len) - if prefix_len < 0: - raise ValueError("pad_len longer than tensor") - stripped = [] - for t in tensors: - with ops.colocate_with(t): - stripped.append(array_ops.slice(t, [0], [prefix_len])) - return stripped - - -def _ragged_split(tensor, pieces): - """Like split for 1D tensors but allows case where len % pieces != 0. - - Args: - tensor: T `tf.Tensor` that must be 1D. - pieces: a positive integer specifying the number of pieces into which - tensor should be split. - - Returns: - list of T `tf.Tensor` of length pieces, which hold the values of - the input tensor, in order. The final tensor may be shorter - than the others, which will all be of equal length. - - Raises: - ValueError: input tensor must be 1D. - """ - shape = tensor.shape - if 1 != len(shape): - raise ValueError("input tensor must be 1D") - tensor_len = shape[0].value - chunk_size = tensor_len // pieces - with ops.colocate_with(tensor): - if tensor_len != (pieces * chunk_size): - # last piece will be short - assert pieces > 1 - last_chunk_size = tensor_len - ((pieces - 1) * chunk_size) - assert last_chunk_size > 0 - piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size] - return array_ops.split(tensor, piece_lens) - else: - return array_ops.split(tensor, pieces) - - -def _ring_permutations(num_workers, num_subchunks, gpu_perm): - """"Generate an array of device index arrays, one for each subchunk. - - In the basic ring reduction algorithm there are size(T)/num_devices - data chunks and each device process one chunk per tick, i.e. 
sending - one chunk and receiving one chunk. The idea of subchunking is that - each device processes num_subchunks smaller data regions per tick, - and the ring rank permutation is different for each subchunk index - so that a device is potentially sending to and receiving from - num_subchunks different other devices at each tick. Where multiple - independent data channels exist between devices, this strategy - supplies a method of using them in parallel. - - Args: - num_workers: number of worker tasks - num_subchunks: number of subchunks into which to divide each per-GPU chunk. - gpu_perm: an array of integers in [0, num_gpus-1] giving the default - ring order of GPUs at each worker. Other permutations will be generated - by rotating this array and splicing together per-worker instances. - - Raises: - ValueError: the number of subchunks may not exceed the number of GPUs. - - Returns: - pred_by_s_d: list of lists that maps (by index) from (subchunk, dev) to - preceding device in the permutation for that subchunk. The - device index of GPU i at worker j is i + (j * num_gpus). - rank_by_s_d: list of lists that maps (by index) from (subchunk, dev) to - local rank of device d in the permutation for that subchunk. 
- """ - num_gpus = len(gpu_perm) - devices = num_workers * num_gpus - if devices == 0: - return [], [] - if num_subchunks > num_gpus: - raise ValueError( - "num_subchunks %d must be <= num_gpus %d" % (num_subchunks, num_gpus)) - rotation_interval = max(1, int(num_gpus / num_subchunks)) - perms_by_s = [] - for s in range(0, num_subchunks): - full_order = [] - offset = s * rotation_interval - for w in range(0, num_workers): - default_order = [(w * num_gpus) + i for i in gpu_perm] - dev_order = default_order[offset:] + default_order[:offset] - full_order += dev_order - perms_by_s.append(full_order) - pred_by_s_d = [[-1 for d in range(0, devices)] - for s in range(0, num_subchunks)] - rank_by_s_d = [[-1 for d in range(0, devices)] - for s in range(0, num_subchunks)] - for s in range(0, num_subchunks): - for d in range(0, devices): - for t in range(0, devices): - if d == perms_by_s[s][t]: - rank_by_s_d[s][d] = t - pred_by_s_d[s][d] = perms_by_s[s][(t + devices - 1) % devices] - break - return (pred_by_s_d, rank_by_s_d) - - -def build_ring_all_reduce(input_tensors, num_workers, num_subchunks, - gpu_perm, red_op, un_op=None): - """Construct a subgraph performing a ring-style all-reduce of input_tensors. - - Args: - input_tensors: a list of T `tf.Tensor` objects, which must all - have the same shape and type. - num_workers: number of worker tasks spanned by input_tensors. - num_subchunks: number of subchunks each device should process in one tick. - gpu_perm: a list of ints giving a ring-wise rank ordering of GPUs at - each worker. All workers must have the same number of - GPUs with the same rank ordering. If NVLINK is available, this should - be a ring order supported by NVLINK edges. - red_op: a binary operator for elementwise reduction. - un_op: an optional unary operator to apply to fully reduced values. - - Raises: - ValueError: empty input_tensors or they don't all have same - size. - - Returns: - a list of T `tf.Tensor` identical sum-reductions of input_tensors. 
- """ - if len(input_tensors) < 2: - raise ValueError("input_tensors must be length 2 or longer") - input_tensors, shape = _flatten_tensors(input_tensors) - devices = [t.device for t in input_tensors] - (pred_by_s_d, rank_by_s_d) = _ring_permutations( - num_workers, num_subchunks, gpu_perm) - chunks_by_dev, pad_len = _build_ring_gather( - input_tensors, devices, - num_subchunks, pred_by_s_d, rank_by_s_d, red_op) - if un_op: - chunks_by_dev = _apply_unary_to_chunks(un_op, chunks_by_dev) - output_tensors = _build_ring_scatter(pred_by_s_d, rank_by_s_d, - chunks_by_dev) - if pad_len > 0: - output_tensors = _strip_padding(output_tensors, pad_len) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _build_ring_gather(input_tensors, devices, num_subchunks, - pred_by_s_d, rank_by_s_d, red_op): - """Construct a subgraph for the first (reduction) pass of ring all-reduce. - - Args: - input_tensors: a list of T `tf.Tensor` 1D input tensors of same - shape and type. - devices: array of device name strings - num_subchunks: number of subchunks each device should process in one tick. - pred_by_s_d: as produced by _ring_permutations - rank_by_s_d: as produced by _ring_permutations - red_op: a binary operator for elementwise reduction - - Raises: - ValueError: tensors must all be one dimensional. - - Returns: - list of list of T `tf.Tensor` of (partially) reduced values where - exactly num_subchunks chunks at each device are fully reduced. - """ - num_devices = len(input_tensors) - if num_devices == 0: - return [] - if num_devices == 1: - return input_tensors - shape = input_tensors[0].shape - if 1 != len(shape): - raise ValueError("input tensors must be 1D") - num_chunks = num_devices * num_subchunks - num_ticks = num_devices - 1 - # Initialize chunks_by_dev with splits of the input tensors. 
- chunks_by_dev = [] - split_pad_len = 0 - for d in range(0, num_devices): - with ops.device(devices[d]): - splits, split_pad_len = _padded_split(input_tensors[d], num_chunks) - chunks_by_dev.append(splits) - # Reduction phase - for tick in range(0, num_ticks): - # One new partial reduction for every chunk - new_partial_reductions = [None for _ in range(0, num_chunks)] - # Compute reductions with respect to last tick's values - for d in range(0, num_devices): - with ops.device(devices[d]): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (2 + tick)) % num_devices - pred_dev = pred_by_s_d[s][d] - chunk_index = (seg_index * num_subchunks) + s - new_partial_reductions[chunk_index] = red_op( - chunks_by_dev[pred_dev][chunk_index], - chunks_by_dev[d][chunk_index]) - # Update chunks_by_dev with the new values at the end of the tick. - for d in range(0, num_devices): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (2 + tick)) % num_devices - chunk_index = (seg_index * num_subchunks) + s - chunks_by_dev[d][chunk_index] = new_partial_reductions[chunk_index] - return chunks_by_dev, split_pad_len - - -def _apply_unary_to_chunks(f, chunks_by_dev): - """Apply a unary op to each tensor in chunks_by_dev, on same device. - - Args: - f: a unary function over T `tf.Tensor`. - chunks_by_dev: list of lists of T `tf.Tensor`. - - Returns: - new list of lists of T `tf.Tensor` with the same structure as - chunks_by_dev containing the derived tensors. - """ - output = [] - for x in chunks_by_dev: - with ops.colocate_with(x[0]): - output.append([f(t) for t in x]) - return output - - -def _build_ring_scatter(pred_by_s_d, rank_by_s_d, - chunks_by_dev): - """Construct subgraph for second (scatter) pass of ring all-reduce. 
- - Args: - pred_by_s_d: as produced by _ring_permutations - rank_by_s_d: as produced by _ring_permutations - chunks_by_dev: list of list of T `tf.Tensor` indexed by ints - (device, chunk) - - Raises: - ValueError: chunks_by_dev is not well-formed - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors, one - at each device corresponding to the outer dimension of chunks_by_dev. - """ - num_devices = len(chunks_by_dev) - num_chunks = len(chunks_by_dev[0]) - if 0 != num_chunks % num_devices: - raise ValueError( - "Expect number of chunks per device to be divisible by num_devices") - num_subchunks = int(num_chunks / num_devices) - num_ticks = num_devices - 1 - for tick in range(0, num_ticks): - passed_values = [None for _ in range(0, num_chunks)] - for d in range(0, num_devices): - with ops.colocate_with(chunks_by_dev[d][0]): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (1 + tick)) % num_devices - pred_dev = pred_by_s_d[s][d] - chunk_index = (seg_index * num_subchunks) + s - passed_values[chunk_index] = array_ops.identity( - chunks_by_dev[pred_dev][chunk_index]) - for d in range(0, num_devices): - for s in range(0, num_subchunks): - rank = rank_by_s_d[s][d] - seg_index = (rank + num_devices - (1 + tick)) % num_devices - chunk_index = (seg_index * num_subchunks) + s - chunks_by_dev[d][chunk_index] = passed_values[chunk_index] - # Join chunks at each device. - output = [] - for x in chunks_by_dev: - with ops.colocate_with(x[0]): - output.append(array_ops.concat(x, 0)) - return output - - -def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None): - """Construct a subgraph for recursive halving-doubling all-reduce. - - The recursive halving-doubling algorithm is described in - http://www.mcs.anl.gov/~thakur/papers/ijhpca-coll.pdf - - The concept is to arrange the participating n devices in - a linear sequence where devices exchange data pairwise - with one other device in each round. 
During the gather - phase there are lg(n) rounds where devices exchange - increasingly smaller sub-tensors with another device - at increasingly greater distances, until at the top - each device has 1/n of the fully reduced values. During the - scatter phase each device exchanges its fully reduced - sub-tensor (which doubles in length at each round) - with one other device at increasingly smaller distances - until each device has all of the fully reduced values. - - Note: this preliminary version requires that len(input_tensors) be a - power of 2. TODO(tucker): relax this restriction. Also, the - number of elements in each tensor must be divisible by 2^h where h - is the number of hops in each phase. This will also be relaxed in - the future with edge-case specific logic. - - Args: - input_tensors: list of T `tf.Tensor` to be elementwise reduced. - red_op: a binary elementwise reduction Op. - un_op: an optional unary elementwise Op to apply to reduced values. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors, one - at each device of input_tensors. - - Raises: - ValueError: num_devices not a power of 2, or tensor len not divisible - by 2 the proper number of times. - """ - devices = [t.device for t in input_tensors] - input_tensors, shape = _flatten_tensors(input_tensors) - reduced_shards = _build_recursive_hd_gather(input_tensors, devices, red_op) - if un_op: - reduced_shards = [un_op(t) for t in reduced_shards] - output_tensors = _build_recursive_hd_scatter(reduced_shards, devices) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _build_recursive_hd_gather(input_tensors, devices, red_op): - """Construct the gather phase of recursive halving-doubling all-reduce. - - Args: - input_tensors: list of T `tf.Tensor` to be elementwise reduced. - devices: a list of strings naming the devices hosting input_tensors, - which will also be used to host the (partial) reduction values. 
- red_op: a binary elementwise reduction Op. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensor shards. - - Raises: - ValueError: num_devices not a power of 2, or tensor len not divisible - by 2 the proper number of times. - """ - num_devices = len(devices) - num_hops = int(math.log(num_devices, 2)) - if num_devices != (2 ** num_hops): - raise ValueError("num_devices must be a power of 2") - chunks = input_tensors - for h in range(0, num_hops): - span = 2 ** h - group_size = span * 2 - new_chunks = [[] for _ in devices] - for d in range(0, num_devices): - if (d % group_size) >= (group_size / 2): - # skip right half of a pair - continue - left_dev = devices[d] - right_dev = devices[d + span] - left_split = array_ops.split(chunks[d], 2) - right_split = array_ops.split(chunks[d+span], 2) - with ops.device(left_dev): - new_chunks[d] = red_op(left_split[0], right_split[0]) - with ops.device(right_dev): - new_chunks[d + span] = red_op(left_split[1], right_split[1]) - chunks = new_chunks - return chunks - - -def _build_recursive_hd_scatter(input_tensors, devices): - """Construct the scatter phase of recursive halving-doublng all-reduce. - - Args: - input_tensors: list of T `tf.Tensor` that are fully-reduced shards. - devices: a list of strings naming the devices on which the reconstituted - full tensors should be placed. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors. 
- """ - num_devices = len(devices) - num_hops = int(math.log(num_devices, 2)) - assert num_devices == (2 ** num_hops), "num_devices must be a power of 2" - chunks = input_tensors - for h in reversed(range(0, num_hops)): - span = 2 ** h - group_size = span * 2 - new_chunks = [[] for _ in devices] - for d in range(0, num_devices): - if (d % group_size) >= (group_size / 2): - # skip right half of a pair - continue - left_idx = d - right_idx = d + span - left_dev = devices[left_idx] - right_dev = devices[right_idx] - with ops.device(left_dev): - new_chunks[left_idx] = array_ops.concat([chunks[left_idx], - chunks[right_idx]], 0) - with ops.device(right_dev): - new_chunks[right_idx] = array_ops.concat([chunks[left_idx], - chunks[right_idx]], 0) - chunks = new_chunks - return chunks - - -def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None): - """Construct a subgraph for shuffle all-reduce. - - Shuffle reduce is essentially the algorithm implemented when using - parameter servers. Suppose tensor length is n, there are d devices - and g gather shards. Each device sends a n/g length sub-tensor to - each gather shard. The gather shards perform a reduction across d - fragments, then broadcast the result back to each device. The - devices then join the g fully reduced fragments they receive from - the shards. The gather shards could perform d-1 pairwise - reductions, or one d-way reduction. The first is better where - reduction Op time is low compared to transmission time, the second - better in the other case. - - Args: - input_tensors: list of T @(tf.Tensor} values to be reduced. - gather_devices: list of names of devices on which reduction shards - should be placed. - red_op: an n-array elementwise reduction Op - un_op: optional elementwise unary Op to be applied to fully-reduced values. - - Returns: - list of T `tf.Tensor` which are the fully reduced tensors. 
- """ - input_tensors, shape = _flatten_tensors(input_tensors) - dst_devices = [t.device for t in input_tensors] - reduced_shards = _build_shuffle_gather(input_tensors, gather_devices, - red_op, un_op) - output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None): - """Construct the gather (concentrate and reduce) phase of shuffle all-reduce. - - Args: - input_tensors: list of T @(tf.Tensor} values to be reduced. - gather_devices: list of names of devices on which reduction shards - should be placed. - red_op: the binary reduction Op - un_op: optional elementwise unary Op to be applied to fully-reduced values. - - Returns: - list of T `tf.Tensor` which are the fully reduced shards. - - Raises: - ValueError: inputs not well-formed. - """ - num_source_devices = len(input_tensors) - num_gather_devices = len(gather_devices) - shape = input_tensors[0].shape - if len(shape) != 1: - raise ValueError("input_tensors must be 1D") - shards_by_source = [] - for d in range(0, num_source_devices): - with ops.colocate_with(input_tensors[d]): - shards_by_source.append( - _ragged_split(input_tensors[d], num_gather_devices)) - reduced_shards = [] - for d in range(0, num_gather_devices): - with ops.device(gather_devices[d]): - values = [s[d] for s in shards_by_source] - red_shard = red_op(values) - if un_op: - red_shard = un_op(red_shard) - reduced_shards.append(red_shard) - return reduced_shards - - -def _build_shuffle_scatter(reduced_shards, dst_devices): - """Build the scatter phase of shuffle all-reduce. - - Args: - reduced_shards: list of T @(tf.Tensor} fully reduced shards - dst_devices: list of names of devices at which the fully-reduced value - should be reconstituted. - - Returns: - list of T `tf.Tensor` scattered tensors. 
- """ - num_devices = len(dst_devices) - out_tensors = [] - for d in range(0, num_devices): - with ops.device(dst_devices[d]): - out_tensors.append(array_ops.concat(reduced_shards, 0)) - return out_tensors - - -def _split_by_task(devices, values): - """Partition devices and values by common task. - - Args: - devices: list of device name strings - values: list of T `tf.tensor` of same length as devices. - - Returns: - (per_task_devices, per_task_values) where both values are - lists of lists with isomorphic structure: the outer list is - indexed by task, and the inner list has length of the number - of values belonging to that task. per_task_devices contains - the specific devices to which the values are local, and - per_task_values contains the corresponding values. - - Raises: - ValueError: devices must be same length as values. - """ - num_devices = len(devices) - if num_devices != len(values): - raise ValueError("len(devices) must equal len(values)") - per_task_devices = collections.OrderedDict() - per_task_values = collections.OrderedDict() - for d in range(num_devices): - d_spec = device_lib.DeviceSpec.from_string(devices[d]) - if not hasattr(d_spec, "task") or d_spec.task is None: - assert False, "failed to parse device %s" % devices[d] - index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task) - if index not in per_task_devices: - per_task_devices[index] = [] - per_task_values[index] = [] - per_task_devices[index].append(devices[d]) - per_task_values[index].append(values[d]) - - return (list(per_task_devices.values()), list(per_task_values.values())) - - -def build_nccl_all_reduce(input_tensors, red_op, un_op=None): - """Build a subgraph that does one full all-reduce, using NCCL. - - Args: - input_tensors: list of T `tf.Tensor` of same-shape and type values to - be reduced. - red_op: binary elementwise reduction operator. Must be one of - {tf.add} - un_op: optional unary elementwise Op to apply to fully-reduce values. 
- - Returns: - list of T `tf.Tensor` of reduced values. - - Raises: - ValueError: red_op not supported. - """ - if red_op == math_ops.add: - output_tensors = nccl.all_sum(input_tensors) - else: - raise ValueError("red_op not supported by NCCL all-reduce: ", red_op) - if un_op: - un_op_wrapped = [] - for t in output_tensors: - with ops.colocate_with(t): - un_op_wrapped.append(un_op(t)) - output_tensors = un_op_wrapped - return output_tensors - - -def _build_nccl_hybrid(input_tensors, red_op, upper_level_f): - """Construct a subgraph for NCCL hybrid all-reduce. - - Args: - input_tensors: list of T `tf.Tensor` of same-shape and type values to - be reduced. - red_op: binary elementwise reduction operator. - upper_level_f: function for reducing one value per worker, across - workers. - - Returns: - list of T `tf.Tensor` of reduced values. - - Raises: - ValueError: inputs not well-formed. - """ - input_tensors, shape = _flatten_tensors(input_tensors) - devices = [t.device for t in input_tensors] - per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors) - num_workers = len(per_worker_devices) - up_values = [None for w in range(0, num_workers)] - up_devices = up_values[:] - down_values = up_values[:] - # First stage: reduce within each worker using NCCL - for w in range(0, num_workers): - worker_values = build_nccl_all_reduce(per_worker_values[w], red_op) - # NOTE: these reductions will not run to completion unless - # every output value is used. Since we only need one, we - # need to put control dependencies on the rest. 
- with ops.control_dependencies(worker_values): - with ops.device(worker_values[0].device): - up_values[w] = array_ops.identity(worker_values[0]) - up_devices[w] = per_worker_devices[w][0] - # Second stage: Apply upper_level_f to reduce across first device at - # each worker - level_2_output = upper_level_f(up_values) - # Third stage: propagate within each worker using NCCL Broadcast - for w in range(0, num_workers): - dst_tensors = [] - with ops.device(per_worker_devices[w][0]): - broadcast_src = nccl.broadcast(array_ops.identity(level_2_output[w])) - for d in per_worker_devices[w]: - with ops.device(d): - dst_tensors.append(array_ops.identity(broadcast_src)) - down_values[w] = dst_tensors - output_tensors = [v for sublist in down_values for v in sublist] - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def _reduce_non_singleton(input_tensors, red_f, un_op): - """If input_tensors has more than one element apply red_f, else apply un_op.""" - if len(input_tensors) > 1: - return red_f(input_tensors) - else: - if not un_op: - return input_tensors - output_tensors = [] - for t in input_tensors: - with ops.colocate_with(t): - output_tensors.append(un_op(t)) - return output_tensors - - -def build_nccl_then_ring(input_tensors, subdiv, red_op, un_op=None): - """Construct hybrid of NCCL within workers, Ring across workers.""" - def upper_builder(y): - return build_ring_all_reduce(y, len(y), subdiv, [0], red_op, un_op) - def upper_level_f(x): - return _reduce_non_singleton(x, upper_builder, un_op) - return _build_nccl_hybrid(input_tensors, red_op, upper_level_f) - - -def build_nccl_then_recursive_hd(input_tensors, red_op, un_op=None): - """Construct hybrid of NCCL within workers, Recursive-HD across workers.""" - upper_level_f = lambda x: build_recursive_hd_all_reduce(x, red_op, un_op) - return _build_nccl_hybrid(input_tensors, red_op, upper_level_f) - - -def build_nccl_then_shuffle(input_tensors, gather_devices, 
nccl_red_op, - shuffle_red_op, un_op=None): - """Construct hybrid of NCCL within workers, Shuffle across workers.""" - upper_level_f = lambda x: build_shuffle_all_reduce(x, gather_devices, - shuffle_red_op, un_op) - return _build_nccl_hybrid(input_tensors, nccl_red_op, upper_level_f) - - -def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f): - """Construct a subgraph for Shuffle hybrid all-reduce. - - Args: - input_tensors: list of T `tf.Tensor` of same-shape and type values to - be reduced. - gather_devices: list of device names on which to host gather shards. - red_op: binary elementwise reduction operator. - upper_level_f: function for reducing one value per worker, across - workers. - - Returns: - list of T `tf.Tensor` of reduced values. - - Raises: - ValueError: inputs not well-formed. - """ - input_tensors, shape = _flatten_tensors(input_tensors) - # First stage, reduce across each worker using gather_devices. - devices = [t.device for t in input_tensors] - per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors) - num_workers = len(per_worker_devices) - up_values = [] - if len(gather_devices) != num_workers: - raise ValueError("For shuffle hybrid, gather_devices must contain one " - "device per worker. ") - for w in range(0, num_workers): - reduced_shards = _build_shuffle_gather( - per_worker_values[w], [gather_devices[w]], red_op) - up_values.append(reduced_shards[0]) - # Second stage, apply upper_level_f. - level_2_output = upper_level_f(up_values) - # Third stage, apply shuffle scatter at each worker. 
- output_tensors = [] - for w in range(0, num_workers): - output_tensors += _build_shuffle_scatter( - [level_2_output[w]], per_worker_devices[w]) - if len(shape) != 1: - output_tensors = _reshape_tensors(output_tensors, shape) - return output_tensors - - -def build_shuffle_then_ring(input_tensors, gather_devices, subdiv, - red_n_op, red_op, un_op=None): - """Construct hybrid of Shuffle within workers, Ring across workers.""" - def upper_builder(tensors): - return build_ring_all_reduce(tensors, len(tensors), subdiv, [0], - red_op, un_op) - def upper_level_f(tensors): - return _reduce_non_singleton(tensors, upper_builder, un_op) - return _build_shuffle_hybrid( - input_tensors, gather_devices, red_n_op, upper_level_f) - - -def build_shuffle_then_shuffle(input_tensors, first_gather_devices, - second_gather_devices, red_op, un_op=None): - """Construct hybrid of Shuffle within workers, Shuffle across workers.""" - def upper_builder(tensors): - return build_shuffle_all_reduce(tensors, second_gather_devices, - red_op, un_op) - def upper_level_f(tensors): - return _reduce_non_singleton(tensors, upper_builder, un_op) - return _build_shuffle_hybrid( - input_tensors, first_gather_devices, red_op, upper_level_f) +# pylint: disable=unused-import,wildcard-import +from tensorflow.python.distribute.all_reduce import * diff --git a/tensorflow/contrib/android/README.md b/tensorflow/contrib/android/README.md index db37bcf73d144eb81c32a461a276d10be7e2d193..27f8ac21323e6eb21a80dfab4d2239738c2fcf1e 100644 --- a/tensorflow/contrib/android/README.md +++ b/tensorflow/contrib/android/README.md @@ -52,6 +52,7 @@ Then, to build the native TF library: bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \ --crosstool_top=//external:android/crosstool \ --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --cxxopt=-std=c++11 \ --cpu=armeabi-v7a ``` diff --git a/tensorflow/contrib/android/cmake/build.gradle b/tensorflow/contrib/android/cmake/build.gradle index 
17a57b99fd6c9efc09bda0ce1249b1f51bd5af5c..ddec08894f34f96b080610f1d27a6a436f7ffa91 100644 --- a/tensorflow/contrib/android/cmake/build.gradle +++ b/tensorflow/contrib/android/cmake/build.gradle @@ -22,8 +22,8 @@ android { } externalNativeBuild { cmake { - arguments '-DANDROID_TOOLCHAIN=gcc', - '-DANDROID_STL=gnustl_static' + arguments '-DANDROID_TOOLCHAIN=clang', + '-DANDROID_STL=c++_static' } } } @@ -70,7 +70,7 @@ if (ndkDir == null || ndkDir == "") { ndkDir = System.getenv('ANDROID_NDK_HOME') } -if(! Os.isFamily(Os.FAMILY_WINDOWS)) { +if (!Os.isFamily(Os.FAMILY_WINDOWS)) { // This script is for non-Windows OS. For Windows OS, MANUALLY build // (or copy the built) libs/headers to the // ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen diff --git a/tensorflow/contrib/autograph/examples/benchmarks/BUILD b/tensorflow/contrib/autograph/examples/benchmarks/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..6d2d70c99b4cc804f2c8bf57afdc8c11f1f73516 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/benchmarks/BUILD @@ -0,0 +1,36 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow/tools/test:performance.bzl", "tf_py_logged_benchmark") + +py_library( + name = "benchmark_base", + srcs = [ + "benchmark_base.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "cartpole_benchmark", + size = "enormous", + srcs = ["cartpole_benchmark.py"], + tags = [ + "local", + "manual", + "no_oss", + "notap", + "nozapfhahn", + ], + deps = [ + ":benchmark_base", + # Note: required gym dependency may need to be added here. 
+ ], +) + +tf_py_logged_benchmark( + name = "cartpole_logged_benchmark", + target = "//tensorflow/contrib/autograph/examples/benchmarks:cartpole_benchmark", +) diff --git a/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py b/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py new file mode 100644 index 0000000000000000000000000000000000000000..93c694849c4dc3faca71e7f9d8614649a7784f99 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/benchmarks/benchmark_base.py @@ -0,0 +1,62 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Common benchmarking code. + +See https://www.tensorflow.org/community/benchmarks for usage. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np + +import tensorflow as tf + + +class ReportingBenchmark(tf.test.Benchmark): + """Base class for a benchmark that reports general performance metrics. + + Subclasses only need to call one of the _profile methods, and optionally + report_results. 
+ """ + + def time_execution(self, name, target, iters, warm_up_iters=5): + for _ in range(warm_up_iters): + target() + + all_times = [] + for _ in range(iters): + iter_time = time.time() + target() + all_times.append(time.time() - iter_time) + + avg_time = np.average(all_times) + + extras = dict() + extras['all_times'] = all_times + + if isinstance(name, tuple): + extras['name'] = name + name = '_'.join(str(piece) for piece in name) + + self.report_benchmark( + iters=iters, wall_time=avg_time, name=name, extras=extras) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/autograph/examples/benchmarks/cartpole_benchmark.py b/tensorflow/contrib/autograph/examples/benchmarks/cartpole_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..4f553be58e94f11e45f0697558348fbbd26bfb91 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/benchmarks/cartpole_benchmark.py @@ -0,0 +1,492 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A basic RL cartpole benchmark. + +The RL model uses the OpenAI Gym environment to train a simple network using +the policy gradients method. The training scales the gradients for each step +by the episode's cumulative discounted reward and averages these gradients over +a fixed number of games before applying the optimization step. 
+ +For benchmarking purposes, we replace the OpenAI Gym environment to a fake +that returns random actions and rewards and never ends the episode. This way +the benchmarks compare the same amount of computation at each step. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gym +import numpy as np +import tensorflow as tf + +from tensorflow.contrib import eager +from tensorflow.contrib.autograph.examples.benchmarks import benchmark_base +from tensorflow.python import autograph as ag +from tensorflow.python.eager import context + +# +# AutoGraph implementation +# + + +@ag.convert() +def graph_append_discounted_rewards(destination, rewards, discount_rate): + """Discounts episode rewards and appends them to destination.""" + ag.set_element_type(rewards, tf.float32) + + cdr = 0.0 + reverse_discounted = [] + ag.set_element_type(reverse_discounted, tf.float32) + + for i in range(len(rewards) - 1, -1, -1): + cdr = cdr * discount_rate + rewards[i] + cdr.set_shape(()) + reverse_discounted.append(cdr) + + retval = destination + # Note: AutoGraph doesn't yet support reversed() so we use a loop instead. + for i in range(len(reverse_discounted) - 1, -1, -1): + retval.append(reverse_discounted[i]) + + return retval + + +class GraphPolicyNetwork(tf.keras.Model): + """Policy network for the cart-pole reinforcement learning problem. + + The forward path of the network takes an observation from the cart-pole + environment (length-4 vector) and outputs an action. + """ + + def __init__(self, hidden_size): + super(GraphPolicyNetwork, self).__init__() + self._hidden_layer = tf.keras.layers.Dense( + hidden_size, activation=tf.nn.elu) + self._output_layer = tf.keras.layers.Dense(1) + + def call(self, inputs): + """Calculates logits and action. 
+ + Args: + inputs: Observations from a step in the cart-pole environment, of shape + `(batch_size, input_size)` + + Returns: + logits: the logits output by the output layer. This can be viewed as the + likelihood vales of choosing the left (0) action. Shape: + `(batch_size, 1)`. + actions: randomly selected actions ({0, 1}) based on the logits. Shape: + `(batch_size, 1)`. + """ + hidden = self._hidden_layer(inputs) + logits = self._output_layer(hidden) + + left_prob = tf.nn.sigmoid(logits) + action_probs = tf.concat([left_prob, 1.0 - left_prob], 1) + + actions = tf.multinomial(tf.log(action_probs), 1) + return logits, actions + + # TODO(mdan): Move this method out of the class. + @ag.convert() + def train(self, cart_pole_env, optimizer, discount_rate, num_games, + max_steps_per_game): + var_list = tf.trainable_variables() + grad_list = [ + tf.TensorArray(tf.float32, 0, dynamic_size=True) for _ in var_list + ] + + step_counts = [] + discounted_rewards = [] + ag.set_element_type(discounted_rewards, tf.float32) + ag.set_element_type(step_counts, tf.int32) + + # Note: we use a shared object, cart_pole_env here. Because calls to the + # object's method are made through py_func, TensorFlow cannot detect its + # data dependencies. Hence we must manually synchronize access to it + # and ensure the control dependencies are set in such a way that + # calls to reset(), take_one_step, etc. are made in the correct order. 
+ sync_counter = tf.constant(0) + + for _ in tf.range(num_games): + with tf.control_dependencies([sync_counter]): + obs = cart_pole_env.reset() + with tf.control_dependencies([obs]): + sync_counter += 1 + + game_rewards = [] + ag.set_element_type(game_rewards, tf.float32) + + for step in tf.range(max_steps_per_game): + logits, actions = self(obs) # pylint:disable=not-callable + logits = tf.reshape(logits, ()) + actions = tf.reshape(actions, ()) + + labels = 1.0 - tf.cast(actions, tf.float32) + loss = tf.nn.sigmoid_cross_entropy_with_logits( + labels=labels, logits=logits) + grads = tf.gradients(loss, var_list) + + for i in range(len(grads)): + grad_list[i].append(grads[i]) + + with tf.control_dependencies([sync_counter]): + obs, reward, done = cart_pole_env.step(actions) + with tf.control_dependencies([obs]): + sync_counter += 1 + obs = tf.reshape(obs, (1, 4)) + + game_rewards.append(reward) + if reward < 0.1 or done: + step_counts.append(step + 1) + break + + discounted_rewards = graph_append_discounted_rewards( + discounted_rewards, game_rewards, discount_rate) + + discounted_rewards = ag.stack(discounted_rewards) + discounted_rewards.set_shape((None,)) + mean, variance = tf.nn.moments(discounted_rewards, [0]) + normalized_rewards = (discounted_rewards - mean) / tf.sqrt(variance) + + for i in range(len(grad_list)): + g = ag.stack(grad_list[i]) + + # This block just adjusts the shapes to match for multiplication. 
+ r = normalized_rewards + if r.shape.ndims < g.shape.ndims: + r = tf.expand_dims(r, -1) + if r.shape.ndims < g.shape.ndims: + r = tf.expand_dims(r, -1) + + grad_list[i] = tf.reduce_mean(g * r, axis=0) + + optimizer.apply_gradients( + zip(grad_list, var_list), global_step=tf.train.get_global_step()) + + return ag.stack(step_counts) + + +@ag.convert() +def graph_train_model(policy_network, cart_pole_env, optimizer, iterations): + """Trains the policy network for a given number of iterations.""" + i = tf.constant(0) + mean_steps_per_iteration = [] + ag.set_element_type(mean_steps_per_iteration, tf.int32) + + while i < iterations: + steps_per_game = policy_network.train( + cart_pole_env, + optimizer, + discount_rate=0.95, + num_games=20, + max_steps_per_game=200) + mean_steps_per_iteration.append(tf.reduce_mean(steps_per_game)) + i += 1 + + return ag.stack(mean_steps_per_iteration) + + +class GraphGymCartpoleEnv(object): + """An env backed by OpenAI Gym's CartPole environment. + + Used to confirm a functional model only. + """ + + def __init__(self): + cart_pole_env = gym.make('CartPole-v1') + cart_pole_env.seed(0) + cart_pole_env.reset() + self.env = cart_pole_env + + def reset(self): + obs = ag.utils.wrap_py_func(self.env.reset, tf.float64, ()) + obs = tf.reshape(obs, (1, 4)) + obs = tf.cast(obs, tf.float32) + return obs + + def step(self, actions): + + def take_one_step(actions): + obs, reward, done, _ = self.env.step(actions) + obs = obs.astype(np.float32) + reward = np.float32(reward) + return obs, reward, done + + return ag.utils.wrap_py_func(take_one_step, + (tf.float32, tf.float32, tf.bool), (actions,)) + + +class GraphRandomCartpoleEnv(object): + """An environment that returns random actions and never finishes. + + Used during benchmarking, it will cause training to run a constant number of + steps. 
+ """ + + def reset(self): + return tf.random.normal((1, 4)) + + def step(self, actions): + with tf.control_dependencies([actions]): + random_obs = tf.random.normal((1, 4)) + fixed_reward = tf.constant(0.001) + done = tf.constant(False) + return random_obs, fixed_reward, done + + +# +# Eager implementation +# + + +def eager_append_discounted_rewards(discounted_rewards, rewards, discount_rate): + cdr = 0.0 + reverse_discounted = [] + + for i in range(len(rewards) - 1, -1, -1): + cdr = cdr * discount_rate + rewards[i] + reverse_discounted.append(cdr) + + discounted_rewards.extend(reversed(reverse_discounted)) + return discounted_rewards + + +class EagerPolicyNetwork(tf.keras.Model): + """Policy network for the cart-pole reinforcement learning problem. + + The forward path of the network takes an observation from the cart-pole + environment (length-4 vector) and outputs an action. + """ + + def __init__(self, hidden_size): + super(EagerPolicyNetwork, self).__init__() + self._hidden_layer = tf.keras.layers.Dense( + hidden_size, activation=tf.nn.elu) + self._output_layer = tf.keras.layers.Dense(1) + + def call(self, inputs): + """Calculates logits and action. + + Args: + inputs: Observations from a step in the cart-pole environment, of shape + `(batch_size, input_size)` + + Returns: + logits: the logits output by the output layer. This can be viewed as the + likelihood vales of choosing the left (0) action. Shape: + `(batch_size, 1)`. + actions: randomly selected actions ({0, 1}) based on the logits. Shape: + `(batch_size, 1)`. 
+ """ + hidden = self._hidden_layer(inputs) + logits = self._output_layer(hidden) + + left_prob = tf.nn.sigmoid(logits) + action_probs = tf.concat([left_prob, 1.0 - left_prob], 1) + + self._grad_fn = eager.implicit_gradients( + self._get_cross_entropy_and_save_actions) + + actions = tf.multinomial(tf.log(action_probs), 1) + return logits, actions + + def _get_cross_entropy_and_save_actions(self, inputs): + logits, actions = self(inputs) # pylint:disable=not-callable + self._current_actions = actions + labels = 1.0 - tf.cast(actions, tf.float32) + return tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) + + def train(self, cart_pole_env, optimizer, discount_rate, num_games, + max_steps_per_game): + grad_list = None + + step_counts = [] + discounted_rewards = [] + + for _ in range(num_games): + obs = cart_pole_env.reset() + + game_rewards = [] + + for step in range(max_steps_per_game): + grads_and_vars = self._grad_fn(tf.constant([obs], dtype=tf.float32)) + grads, var_list = zip(*grads_and_vars) + actions = self._current_actions.numpy()[0][0] + + if grad_list is None: + grad_list = [[g] for g in grads] + else: + for i in range(len(grads)): + grad_list[i].append(grads[i]) + + obs, reward, done = cart_pole_env.step(actions) + + game_rewards.append(reward) + if reward < 0.1 or done: + step_counts.append(step + 1) + break + + discounted_rewards = eager_append_discounted_rewards( + discounted_rewards, game_rewards, discount_rate) + + discounted_rewards = tf.stack(discounted_rewards) + mean, variance = tf.nn.moments(discounted_rewards, [0]) + normalized_rewards = (discounted_rewards - mean) / tf.sqrt(variance) + + for i in range(len(grad_list)): + g = tf.stack(grad_list[i]) + + r = normalized_rewards + while r.shape.ndims < g.shape.ndims: + r = tf.expand_dims(r, -1) + + grad_list[i] = tf.reduce_mean(g * r, axis=0) + + optimizer.apply_gradients( + zip(grad_list, var_list), global_step=tf.train.get_global_step()) + + return tf.stack(step_counts) + + +def 
eager_train_model(policy_network, cart_pole_env, optimizer, iterations): + """Trains the policy network for a given number of iterations.""" + mean_steps_per_iteration = [] + + for _ in range(iterations): + steps_per_game = policy_network.train( + cart_pole_env, + optimizer, + discount_rate=0.95, + num_games=20, + max_steps_per_game=200) + mean_steps_per_iteration.append(tf.reduce_mean(steps_per_game)) + + return mean_steps_per_iteration + + +class EagerGymCartpoleEnv(object): + """An env backed by OpenAI Gym's CartPole environment. + + Used to confirm a functional model only. + """ + + def __init__(self): + cart_pole_env = gym.make('CartPole-v1') + cart_pole_env.seed(0) + cart_pole_env.reset() + self.env = cart_pole_env + + def reset(self): + return self.env.reset() + + def step(self, actions): + obs, reward, done, _ = self.env.step(actions) + return obs, reward, done + + +class EagerRandomCartpoleEnv(object): + """An environment that returns random actions and never finishes. + + Used during benchmarking, it will cause training to run a constant number of + steps. + """ + + def reset(self): + return np.random.normal(size=(4,)) + + def step(self, actions): + with tf.control_dependencies([actions]): + random_obs = np.random.normal(size=(4,)) + fixed_reward = 0.001 + done = False + return random_obs, fixed_reward, done + + +def graph_demo_training(): + """Not used in the benchmark. 
Used to confirm a functional model.""" + with tf.Graph().as_default(): + tf.set_random_seed(0) + + network = GraphPolicyNetwork(hidden_size=5) + network.build((1, 4)) + env = GraphGymCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + + train_ops = graph_train_model(network, env, opt, iterations=5) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + steps_per_iteration = sess.run(train_ops) + for i, steps in enumerate(steps_per_iteration): + print('Step {} iterations: {}'.format(i, steps)) + + +def eager_demo_training(): + with context.eager_mode(): + network = EagerPolicyNetwork(hidden_size=5) + network.build((1, 4)) + env = EagerGymCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + + steps_per_iteration = eager_train_model(network, env, opt, iterations=5) + for i, steps in enumerate(steps_per_iteration): + print('Step {} iterations: {}'.format(i, steps)) + + +class RLCartPoleBenchmark(benchmark_base.ReportingBenchmark): + """Actual benchmark. + + Trains the RL agent a fixed number of times, on random environments that + result in constant number of steps. 
+ """ + + def benchmark_cartpole(self): + + def train_session(sess, ops): + return lambda: sess.run(ops) + + def train_eager(network, env, opt): + return lambda: eager_train_model(network, env, opt, iterations=10) + + for model_size in (10, 100, 1000): + with tf.Graph().as_default(): + network = GraphPolicyNetwork(hidden_size=model_size) + network.build((1, 4)) + env = GraphRandomCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + train_ops = graph_train_model(network, env, opt, iterations=10) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + + self.time_execution(('cartpole', 'autograph', model_size), + train_session(sess, train_ops), 20) + + with context.eager_mode(): + network = EagerPolicyNetwork(hidden_size=model_size) + network.build((1, 4)) + env = EagerRandomCartpoleEnv() + opt = tf.train.AdamOptimizer(0.05) + + self.time_execution(('cartpole', 'eager', model_size), + train_eager(network, env, opt), 20) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 55faad983f2bcf2f3fa633669bd371608e2e925b..3e4d0dc1cec76b068c1c846eb476eec615e4f613 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -18,8 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import function +from tensorflow.python.eager import function from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import gen_batch_ops # go/tf-wildcard-import # pylint: disable=wildcard-import @@ -101,12 +102,15 @@ def batch_function(num_batch_threads, def decorator(fn): # pylint: disable=missing-docstring def decorated(*args): # pylint: disable=missing-docstring - types = [arg.dtype 
for arg in args] - @function.Defun(*types) + @function.defun() def computation(*computation_args): return fn(*computation_args) + computation = computation.get_concrete_function( + *[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i)) + for i, x in enumerate(args)]) + with ops.name_scope("batch") as name: for a in args: if not isinstance(a, ops.Tensor): @@ -123,7 +127,7 @@ def batch_function(num_batch_threads, f=computation, in_tensors=list(args), captured_tensors=computation.captured_inputs, - Tout=[o.type for o in computation.definition.signature.output_arg]) + Tout=[o.dtype for o in computation.outputs]) return decorated diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index 01ee8703a93836d607ee9b765c51c79fe3bb974f..9109b9c1c91cefa4c52bad49de23336a6e05e1ef 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -219,6 +219,7 @@ class BatchOpsTest(test.TestCase): @batch_ops.batch_function(1, 10, 100000) def computation(in_t): + self.assertTrue(in_t.shape is not None) return in_t + 1 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index 41a8c920fc4e81af90f4c94a149d8c404c58b747..493046b39907971e92f91ecc60d375ea273ff1d2 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -14,6 +14,8 @@ # ============================================================================== """Ops for representing Bayesian computation. +Use [tfp](/probability/api_docs/python/tfp) instead. + ## This package provides classes for Bayesian computation with TensorFlow. 
""" from __future__ import absolute_import diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py index 13215ffabf3a956d3f83697f867457b2fa72e7c9..8b6ed9f041b89a0da02a505ec261bca82b094f74 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py @@ -81,7 +81,7 @@ class ExpectationImportanceSampleTest(test.TestCase): # Compute E_p[X_1 * X_2 > 0], with X_i the ith component of X ~ p(x). # Should equal 1/2 because p is a spherical Gaussian centered at (0, 0). def indicator(x): - x1_times_x2 = math_ops.reduce_prod(x, reduction_indices=[-1]) + x1_times_x2 = math_ops.reduce_prod(x, axis=[-1]) return 0.5 * (math_ops.sign(x1_times_x2) + 1.0) prob = mc.expectation_importance_sampler( diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py index 68fa415eeaf1d1ae7c2ecf1be1c300eddbfa4e69..28a829d87ddecc4a147c588b5b0536b44db8393f 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Monte Carlo integration and helpers.""" +"""Monte Carlo integration and helpers. + +Use [tfp.monte_carlo](/probability/api_docs/python/tfp/monte_carlo) instead. 
+""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index 18d40fc1dff8e7c9aefffbe3ceba770598a42096..e83a54851195708eb7e6412b7400236f4bc06e6b 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -353,12 +353,12 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, def _sample_mean(values): """Mean over sample indices. In this module this is always [0].""" - return math_ops.reduce_mean(values, reduction_indices=[0]) + return math_ops.reduce_mean(values, axis=[0]) def _sample_max(values): """Max over sample indices. In this module this is always [0].""" - return math_ops.reduce_max(values, reduction_indices=[0]) + return math_ops.reduce_max(values, axis=[0]) def _get_samples(dist, z, n, seed): diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md index 2c44abed5e1955cc666273e97e6b2378766f13d2..79052bee35c7895cb4048b10c1f73acb036d1587 100644 --- a/tensorflow/contrib/bigtable/README.md +++ b/tensorflow/contrib/bigtable/README.md @@ -51,25 +51,18 @@ BIGTABLE_TABLE_NAME = '' PREFIX = 'train-' def main(): + tf.enable_eager_execution() + client = tf.contrib.cloud.BigtableClient(GCP_PROJECT_ID, BIGTABLE_INSTANCE_ID) table = client.table(BIGTABLE_TABLE_NAME) dataset = table.keys_by_prefix_dataset(PREFIX) - iterator = dataset.make_initializable_iterator() - get_next_op = iterator.get_next() - with tf.Session() as sess: - print('Initializing the iterator.') - sess.run(iterator.initializer) - print('Retrieving rows:') - row_index = 0 - while True: - try: - row_key = sess.run(get_next_op) - print('Row key %d: %s' % (row_index, row_key)) - row_index += 1 - except tf.errors.OutOfRangeError: - print('Finished reading data!') - break + print('Retrieving rows:') + row_index = 0 + for 
row_key in dataset: + print('Row key %d: %s' % (row_index, row_key)) + row_index += 1 + print('Finished reading data!') if __name__ == '__main__': main() diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc index 67bf14c17646cff81af707405b66c9fba2ded0bd..98f906408c230a4382ffafe412ee9990d4384930 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc @@ -29,8 +29,7 @@ Status GrpcStatusToTfStatus(const ::grpc::Status& status) { } return Status(static_cast<::tensorflow::error::Code>(status.error_code()), strings::StrCat("Error reading from Cloud Bigtable: ", - status.error_message(), - " (Details: ", status.error_details(), ")")); + status.error_message())); } string RegexFromStringSet(const std::vector& strs) { diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc index 01608dc6bc07890c3a59577ef31c90c2694e6a87..f0c3ef4e2ecbf5f4d21e421be4fb527d4769200c 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc @@ -167,7 +167,7 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - if (index_ > keys_.size() - 2) { + if (index_ + 2 > keys_.size()) { *end_of_sequence = true; return Status::OK(); } diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc index f083ce6f44b3c2a83d9b5d3235056eb94c4be4a8..e95dc577184f7e81d942755b41065f52131ce9f6 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc @@ -366,6 +366,39 @@ 
BigtableTestClient::MutateRows( return MakeUnique(request.entries_size()); } +std::unique_ptr> +BigtableTestClient::AsyncMutateRow( + grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + grpc::CompletionQueue* cq) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + +std::unique_ptr<::grpc::ClientAsyncReaderInterface< + ::google::bigtable::v2::SampleRowKeysResponse>> +BigtableTestClient::AsyncSampleRowKeys( + ::grpc::ClientContext* context, + const ::google::bigtable::v2::SampleRowKeysRequest& request, + ::grpc::CompletionQueue* cq, void* tag) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + +std::unique_ptr<::grpc::ClientAsyncReaderInterface< + ::google::bigtable::v2::MutateRowsResponse>> +BigtableTestClient::AsyncMutateRows( + ::grpc::ClientContext* context, + const ::google::bigtable::v2::MutateRowsRequest& request, + ::grpc::CompletionQueue* cq, void* tag) { + LOG(WARNING) << "Call to InMemoryDataClient::" << __func__ + << "(); this will likely cause a crash!"; + return nullptr; +} + std::shared_ptr BigtableTestClient::Channel() { LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely " "cause a crash!"; diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h index dac2b16a216d26f02684c7401ed2ddaa4b7baddb..c4a1f06bc504c3565c7bb09b42e48e7fbddb9cc6 100644 --- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h +++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h @@ -61,6 +61,25 @@ class BigtableTestClient : public ::google::cloud::bigtable::DataClient { MutateRows(grpc::ClientContext* context, google::bigtable::v2::MutateRowsRequest const& request) override; + std::unique_ptr> + 
AsyncMutateRow(grpc::ClientContext* context, + google::bigtable::v2::MutateRowRequest const& request, + grpc::CompletionQueue* cq) override; + + std::unique_ptr<::grpc::ClientAsyncReaderInterface< + ::google::bigtable::v2::SampleRowKeysResponse>> + AsyncSampleRowKeys( + ::grpc::ClientContext* context, + const ::google::bigtable::v2::SampleRowKeysRequest& request, + ::grpc::CompletionQueue* cq, void* tag) override; + + std::unique_ptr<::grpc::ClientAsyncReaderInterface< + ::google::bigtable::v2::MutateRowsResponse>> + AsyncMutateRows(::grpc::ClientContext* context, + const ::google::bigtable::v2::MutateRowsRequest& request, + ::grpc::CompletionQueue* cq, void* tag) override; + std::shared_ptr Channel() override; private: diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py index 316da9ebe152ef52c7e7f846cf8c3eb1555ee8a6..197f5578eb010bee5a3aad7c05446393193f99e2 100644 --- a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py +++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py @@ -57,7 +57,7 @@ class BigtableOpsTest(test.TestCase): sess.run(write_op) def runReadKeyTest(self, read_ds): - itr = read_ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(read_ds) n = itr.get_next() expected = list(self.COMMON_ROW_KEYS) expected.reverse() @@ -78,7 +78,7 @@ class BigtableOpsTest(test.TestCase): self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4")) def runScanTest(self, read_ds): - itr = read_ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(read_ds) n = itr.get_next() expected_keys = list(self.COMMON_ROW_KEYS) expected_keys.reverse() @@ -120,7 +120,7 @@ class BigtableOpsTest(test.TestCase): def testLookup(self): ds = self._table.keys_by_prefix_dataset("r") ds = ds.apply(self._table.lookup_columns(cf1="c1")) - itr = ds.make_initializable_iterator() + itr = 
dataset_ops.make_initializable_iterator(ds) n = itr.get_next() expected_keys = list(self.COMMON_ROW_KEYS) expected_values = list(self.COMMON_VALUES) @@ -141,7 +141,7 @@ class BigtableOpsTest(test.TestCase): def testSampleKeys(self): ds = self._table.sample_keys() - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() expected_key = self.COMMON_ROW_KEYS[0] with self.cached_session() as sess: @@ -161,7 +161,7 @@ class BigtableOpsTest(test.TestCase): sess.run(n) def runSampleKeyPairsTest(self, ds, expected_key_pairs): - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) @@ -218,7 +218,7 @@ class BigtableOpsTest(test.TestCase): def testSampleKeyPairsPrefixAndStartKey(self): ds = bigtable_api._BigtableSampleKeyPairsDataset( self._table, prefix="r", start="r1", end="") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(itr.initializer) @@ -226,14 +226,14 @@ class BigtableOpsTest(test.TestCase): def testSampleKeyPairsPrefixAndEndKey(self): ds = bigtable_api._BigtableSampleKeyPairsDataset( self._table, prefix="r", start="", end="r3") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(itr.initializer) def testParallelScanPrefix(self): ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) @@ -251,7 +251,7 @@ class BigtableOpsTest(test.TestCase): def testParallelScanRange(self): ds = self._table.parallel_scan_range(start="r1", end="r4", 
cf1="c1") - itr = ds.make_initializable_iterator() + itr = dataset_ops.make_initializable_iterator(ds) n = itr.get_next() with self.cached_session() as sess: self._writeCommonValues(sess) diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index 7c87b0daeb09950cc44c51f49c16534d413f0376..b6cdc7aab0320fe5f457288ada03a46e18a694cc 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -35,8 +35,8 @@ from tensorflow.contrib.util import loader from tensorflow.python.data.experimental.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import resource_loader @@ -111,8 +111,7 @@ class BigtableClient(object): class BigtableTable(object): - """BigtableTable is the entrypoint for reading and writing data in Cloud - Bigtable. + """Entry point for reading and writing data in Cloud Bigtable. This BigtableTable class is the Python representation of the Cloud Bigtable table within TensorFlow. Methods on this class allow data to be read from and @@ -222,7 +221,7 @@ class BigtableTable(object): A `tf.data.Dataset`. containing `tf.string` Tensors corresponding to all of the row keys matching that prefix. """ - return _BigtablePrefixKeyDataset(self, prefix) + return dataset_ops.DatasetV1Adapter(_BigtablePrefixKeyDataset(self, prefix)) def sample_keys(self): """Retrieves a sampling of row keys from the Bigtable table. @@ -234,7 +233,7 @@ class BigtableTable(object): Returns: A `tf.data.Dataset` returning string row keys. 
""" - return _BigtableSampleKeysDataset(self) + return dataset_ops.DatasetV1Adapter(_BigtableSampleKeysDataset(self)) def scan_prefix(self, prefix, probability=None, columns=None, **kwargs): """Retrieves row (including values) from the Bigtable service. @@ -279,7 +278,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - return _BigtableScanDataset(self, prefix, "", "", normalized, probability) + return dataset_ops.DatasetV1Adapter( + _BigtableScanDataset(self, prefix, "", "", normalized, probability)) def scan_range(self, start, end, probability=None, columns=None, **kwargs): """Retrieves rows (including values) from the Bigtable service. @@ -324,7 +324,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - return _BigtableScanDataset(self, "", start, end, normalized, probability) + return dataset_ops.DatasetV1Adapter( + _BigtableScanDataset(self, "", start, end, normalized, probability)) def parallel_scan_prefix(self, prefix, @@ -380,7 +381,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - ds = _BigtableSampleKeyPairsDataset(self, prefix, "", "") + ds = dataset_ops.DatasetV1Adapter( + _BigtableSampleKeyPairsDataset(self, prefix, "", "")) return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, normalized) @@ -442,7 +444,8 @@ class BigtableTable(object): """ probability = _normalize_probability(probability) normalized = _normalize_columns(columns, kwargs) - ds = _BigtableSampleKeyPairsDataset(self, "", start, end) + ds = dataset_ops.DatasetV1Adapter( + _BigtableSampleKeyPairsDataset(self, "", start, end)) return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability, normalized) @@ -589,16 +592,8 @@ class _BigtableKeyDataset(dataset_ops.DatasetSource): self._table = table 
@property - def output_classes(self): - return ops.Tensor - - @property - def output_shapes(self): - return tensor_shape.TensorShape([]) - - @property - def output_types(self): - return dtypes.string + def _element_structure(self): + return structure.TensorStructure(dtypes.string, []) class _BigtablePrefixKeyDataset(_BigtableKeyDataset): @@ -658,16 +653,9 @@ class _BigtableLookupDataset(dataset_ops.DatasetSource): self._columns = [i[1] for i in normalized] @property - def output_classes(self): - return tuple([ops.Tensor] * self._num_outputs) - - @property - def output_shapes(self): - return tuple([tensor_shape.TensorShape([])] * self._num_outputs) - - @property - def output_types(self): - return tuple([dtypes.string] * self._num_outputs) + def _element_structure(self): + return structure.NestedStructure(tuple( + [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) def _as_variant_tensor(self): # pylint: disable=protected-access @@ -693,16 +681,9 @@ class _BigtableScanDataset(dataset_ops.DatasetSource): self._num_outputs = len(normalized) + 1 # 1 for row key @property - def output_classes(self): - return tuple([ops.Tensor] * self._num_outputs) - - @property - def output_shapes(self): - return tuple([tensor_shape.TensorShape([])] * self._num_outputs) - - @property - def output_types(self): - return tuple([dtypes.string] * self._num_outputs) + def _element_structure(self): + return structure.NestedStructure(tuple( + [structure.TensorStructure(dtypes.string, [])] * self._num_outputs)) def _as_variant_tensor(self): return gen_bigtable_ops.bigtable_scan_dataset( @@ -726,16 +707,10 @@ class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource): self._end = end @property - def output_classes(self): - return (ops.Tensor, ops.Tensor) - - @property - def output_shapes(self): - return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) - - @property - def output_types(self): - return (dtypes.string, dtypes.string) + def _element_structure(self): + 
return structure.NestedStructure( + (structure.TensorStructure(dtypes.string, []), + structure.TensorStructure(dtypes.string, []))) def _as_variant_tensor(self): # pylint: disable=protected-access diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index f03eab510c2f9010fc92eb1934ac77dc0626a44b..f7f15a302a00ee4187d57fc4d40727b84e6c587c 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -98,7 +98,6 @@ py_library( "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", - "//tensorflow/contrib/stateless", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", @@ -108,6 +107,7 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:stateless_random_ops", "//tensorflow/python:summary", "//tensorflow/python:tensor_shape", "//tensorflow/python:training", diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 14b6fc4ac26f74f54628ae37ad6437c7d3e8caba..d3b23d949ee2c7674c3918d39e8b71d76eefcfec 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -132,6 +132,7 @@ py_library( srcs = ["estimator.py"], srcs_version = "PY2AND3", deps = [ + ":custom_loss_head", ":estimator_utils", ":model", "//tensorflow/contrib/boosted_trees:losses", diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 48f12a64f94c7bd0531488ef537b199558e17e3e..b314b4d74df882a421d9a2ecce2629a63d5c5248 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ 
b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -41,7 +41,8 @@ def make_custom_export_strategy(name, convert_fn, feature_columns, export_input_fn, - use_core_columns=False): + use_core_columns=False, + feature_engineering_fn=None): """Makes custom exporter of GTFlow tree format. Args: @@ -52,6 +53,7 @@ def make_custom_export_strategy(name, export_input_fn: A function that takes no arguments and returns an `InputFnOps`. use_core_columns: A boolean, whether core feature columns were used. + feature_engineering_fn: Feature eng function to be called on the input. Returns: An `ExportStrategy`. @@ -59,9 +61,12 @@ def make_custom_export_strategy(name, base_strategy = saved_model_export_utils.make_export_strategy( serving_input_fn=export_input_fn, strip_default_attrs=True) input_fn = export_input_fn() + features = input_fn.features + if feature_engineering_fn is not None: + features, _ = feature_engineering_fn(features, labels=None) (sorted_feature_names, dense_floats, sparse_float_indices, _, _, sparse_int_indices, _, _) = gbdt_batch.extract_features( - input_fn.features, feature_columns, use_core_columns) + features, feature_columns, use_core_columns) def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None): """A wrapper to export to SavedModel, and convert it to other formats.""" @@ -196,6 +201,10 @@ def convert_to_universal_format(dtec, sorted_feature_names, matching_id = categorical_test.value.add() matching_id.int64_value = split.feature_id node.custom_left_child_test.Pack(categorical_test) + elif (node_type == "oblivious_dense_float_binary_split" or + node_type == "oblivious_categorical_id_binary_split"): + raise ValueError("Universal tree format doesn't support oblivious " + "trees") else: raise ValueError("Unexpected node type %s" % node_type) node.left_child_id.value = split.left_id @@ -229,6 +238,13 @@ def _get_feature_importances(dtec, feature_names, num_dense_floats, split = 
tree_node.categorical_id_binary_split split_column = feature_names[split.feature_column + num_dense_floats + num_sparse_float] + elif node_type == "oblivious_dense_float_binary_split": + split = tree_node.oblivious_dense_float_binary_split + split_column = feature_names[split.feature_column] + elif node_type == "oblivious_categorical_id_binary_split": + split = tree_node.oblivious_categorical_id_binary_split + split_column = feature_names[split.feature_column + num_dense_floats + + num_sparse_float] elif node_type == "categorical_id_set_membership_binary_split": split = tree_node.categorical_id_set_membership_binary_split split_column = feature_names[split.feature_column + num_dense_floats + diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index 194a5c8754cb0ab2db299e3fb5c998c0f27f8435..358404cd946bbc56d2f7228be8fe4223749c850b 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -28,7 +28,6 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.boosted_trees.estimator_batch import model from tensorflow.contrib.boosted_trees.estimator_batch import distillation_loss -from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch @@ -37,7 +36,7 @@ from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.python.estimator import estimator as core_estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn -from 
tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.feature_column import feature_column_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -170,6 +169,7 @@ def _dnn_tree_combined_model_fn( if (output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC and not use_core_versions): raise ValueError("You must use core versions with Estimator Spec") + global_step = training_util.get_global_step() with variable_scope.variable_scope( dnn_parent_scope, @@ -191,46 +191,58 @@ def _dnn_tree_combined_model_fn( feature_columns=dnn_feature_columns, weight_collections=[dnn_parent_scope], scope=input_layer_scope) - previous_layer = input_layer - for layer_id, num_hidden_units in enumerate(dnn_hidden_units): + def dnn_logits_fn(): + """Builds the logits from the input layer.""" + previous_layer = input_layer + for layer_id, num_hidden_units in enumerate(dnn_hidden_units): + with variable_scope.variable_scope( + "hiddenlayer_%d" % layer_id, + values=(previous_layer,)) as hidden_layer_scope: + net = layers.fully_connected( + previous_layer, + num_hidden_units, + activation_fn=dnn_activation_fn, + variables_collections=[dnn_parent_scope], + scope=hidden_layer_scope) + if dnn_dropout is not None and mode == model_fn.ModeKeys.TRAIN: + net = layers.dropout(net, keep_prob=(1.0 - dnn_dropout)) + _add_hidden_layer_summary(net, hidden_layer_scope.name) + previous_layer = net with variable_scope.variable_scope( - "hiddenlayer_%d" % layer_id, - values=(previous_layer,)) as hidden_layer_scope: - net = layers.fully_connected( + "logits", values=(previous_layer,)) as logits_scope: + dnn_logits = layers.fully_connected( previous_layer, - num_hidden_units, - activation_fn=dnn_activation_fn, + head.logits_dimension, + activation_fn=None, variables_collections=[dnn_parent_scope], - scope=hidden_layer_scope) - if dnn_dropout is not None and mode == 
model_fn.ModeKeys.TRAIN: - net = layers.dropout(net, keep_prob=(1.0 - dnn_dropout)) - _add_hidden_layer_summary(net, hidden_layer_scope.name) - previous_layer = net - with variable_scope.variable_scope( - "logits", values=(previous_layer,)) as logits_scope: - dnn_logits = layers.fully_connected( - previous_layer, - head.logits_dimension, - activation_fn=None, - variables_collections=[dnn_parent_scope], - scope=logits_scope) - _add_hidden_layer_summary(dnn_logits, logits_scope.name) - - def _dnn_train_op_fn(loss): - """Returns the op to optimize the loss.""" - return optimizers.optimize_loss( - loss=loss, - global_step=training_util.get_global_step(), - learning_rate=_DNN_LEARNING_RATE, - optimizer=_get_optimizer(dnn_optimizer), - name=dnn_parent_scope, - variables=ops.get_collection( - ops.GraphKeys.TRAINABLE_VARIABLES, scope=dnn_parent_scope), - # Empty summaries to prevent optimizers from logging training_loss. - summaries=[]) + scope=logits_scope) + _add_hidden_layer_summary(dnn_logits, logits_scope.name) + return dnn_logits + if predict_with_tree_only and mode == model_fn.ModeKeys.INFER: + dnn_logits = array_ops.constant(0.0) + dnn_train_op_fn = control_flow_ops.no_op + elif predict_with_tree_only and mode == model_fn.ModeKeys.EVAL: + dnn_logits = control_flow_ops.cond( + global_step > dnn_steps_to_train, + lambda: array_ops.constant(0.0), + dnn_logits_fn) + dnn_train_op_fn = control_flow_ops.no_op + else: + dnn_logits = dnn_logits_fn() + def dnn_train_op_fn(loss): + """Returns the op to optimize the loss.""" + return optimizers.optimize_loss( + loss=loss, + global_step=training_util.get_global_step(), + learning_rate=_DNN_LEARNING_RATE, + optimizer=_get_optimizer(dnn_optimizer), + name=dnn_parent_scope, + variables=ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES, scope=dnn_parent_scope), + # Empty summaries to prevent optimizers from logging training_loss. + summaries=[]) # Build Tree Logits. 
- global_step = training_util.get_global_step() with ops.device(global_step.device): ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=0, @@ -261,8 +273,13 @@ def _dnn_tree_combined_model_fn( """Returns the op to optimize the loss.""" if dnn_to_tree_distillation_param: loss_weight, loss_fn = dnn_to_tree_distillation_param - weight_tensor = head_lib._weight_tensor( # pylint: disable=protected-access - features, head.weight_column_name) + # pylint: disable=protected-access + if use_core_versions: + weight_tensor = head_lib._weight_tensor(features, head._weight_column) + else: + weight_tensor = head_lib._weight_tensor( + features, head.weight_column_name) + # pylint: enable=protected-access dnn_logits_fixed = array_ops.stop_gradient(dnn_logits) if loss_fn is None: @@ -305,52 +322,26 @@ def _dnn_tree_combined_model_fn( finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() if output_type == model.ModelBuilderOutputType.MODEL_FN_OPS: - if use_core_versions: - model_fn_ops = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_no_train_op_fn, - logits=tree_train_logits) - dnn_train_op = head.create_estimator_spec( - features=features, - mode=mode, - labels=labels, - train_op_fn=_dnn_train_op_fn, - logits=dnn_logits) - dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops( - dnn_train_op).train_op - - tree_train_op = head.create_estimator_spec( - features=tree_features, - mode=mode, - labels=labels, - train_op_fn=_tree_train_op_fn, - logits=tree_train_logits) - tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops( - tree_train_op).train_op - - model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops( - model_fn_ops) - else: - model_fn_ops = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_no_train_op_fn, - logits=tree_train_logits) - dnn_train_op = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - 
train_op_fn=_dnn_train_op_fn, - logits=dnn_logits).train_op - tree_train_op = head.create_model_fn_ops( - features=tree_features, - mode=mode, - labels=labels, - train_op_fn=_tree_train_op_fn, - logits=tree_train_logits).train_op + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + if mode != model_fn.ModeKeys.TRAIN: + return model_fn_ops + dnn_train_op = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=dnn_train_op_fn, + logits=dnn_logits).train_op + tree_train_op = head.create_model_fn_ops( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits).train_op # Add the hooks model_fn_ops.training_hooks.extend([ @@ -369,11 +360,13 @@ def _dnn_tree_combined_model_fn( labels=labels, train_op_fn=_no_train_op_fn, logits=tree_train_logits) + if mode != model_fn.ModeKeys.TRAIN: + return fusion_spec dnn_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, - train_op_fn=_dnn_train_op_fn, + train_op_fn=dnn_train_op_fn, logits=dnn_logits) tree_spec = head.create_estimator_spec( features=tree_features, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py index 839eedd3a87ccaa1faecd1966fe5907d682cac02..dea19b7c62649679f944809b44c51ba0cd361904 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -18,13 +18,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import tempfile from tensorflow.contrib.boosted_trees.estimator_batch import dnn_tree_combined_estimator as estimator from tensorflow.contrib.boosted_trees.proto 
import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils from tensorflow.contrib.learn.python.learn.estimators import run_config +from tensorflow.python.estimator import exporter from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.export import export +from tensorflow.python.ops import parsing_ops from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -34,6 +38,7 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest from tensorflow.python.training import checkpoint_utils + def _train_input_fn(): features = { "x": constant_op.constant([[2.], [1.], [1.]]) @@ -103,35 +108,6 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): classifier.fit(input_fn=_train_input_fn, steps=15) classifier.evaluate(input_fn=_eval_input_fn, steps=1) - def testFitAndEvaluateDontThrowExceptionWithCore(self): - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - learner_config.constraints.max_tree_depth = 1 - model_dir = tempfile.mkdtemp() - config = run_config.RunConfig() - - # Use core head - head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( - loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) - - classifier = estimator.DNNBoostedTreeCombinedEstimator( - head=head_fn, - dnn_hidden_units=[1], - # Use core feature columns - dnn_feature_columns=[core_feature_column.numeric_column("x")], - tree_learner_config=learner_config, - num_trees=1, - tree_examples_per_layer=3, - model_dir=model_dir, - config=config, - dnn_steps_to_train=10, - dnn_input_layer_to_tree=True, - tree_feature_columns=[], - use_core_versions=True) - - classifier.fit(input_fn=_train_input_fn, steps=15) - 
classifier.evaluate(input_fn=_eval_input_fn, steps=1) - def testFitAndEvaluateWithDistillation(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 @@ -223,6 +199,51 @@ class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): self.assertLess(0.5, res["auc"]) est.predict(input_fn=_eval_input_fn) + def testTrainEvaluateWithDnnForInputAndTreeForPredict(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + est = estimator.CoreDNNBoostedTreeCombinedEstimator( + head=head_fn, + dnn_hidden_units=[1], + dnn_feature_columns=[core_feature_column.numeric_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=True, + predict_with_tree_only=True, + dnn_to_tree_distillation_param=(0.5, None), + tree_feature_columns=[]) + + # Train for a few steps. 
+ est.train(input_fn=_train_input_fn, steps=1000) + res = est.evaluate(input_fn=_eval_input_fn, steps=1) + self.assertLess(0.5, res["auc"]) + est.predict(input_fn=_eval_input_fn) + serving_input_fn = ( + export.build_parsing_serving_input_receiver_fn( + feature_spec={"x": parsing_ops.FixedLenFeature( + [1], dtype=dtypes.float32)})) + base_exporter = exporter.FinalExporter( + name="Servo", + serving_input_receiver_fn=serving_input_fn, + assets_extra=None) + export_path = os.path.join(model_dir, "export") + base_exporter.export( + est, + export_path=export_path, + checkpoint_path=None, + eval_result={}, + is_the_final_export=True) if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 4c7a538b385ec19f520bff79bab20a121221c60f..a178820841c4c8bcb7f5742babdb6d0f4825de31 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.contrib.boosted_trees.estimator_batch import model from tensorflow.contrib.boosted_trees.python.utils import losses from tensorflow.contrib.learn.python.learn.estimators import estimator @@ -26,7 +28,8 @@ from tensorflow.python.estimator.canned import head as core_head_lib from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.ops import math_ops from tensorflow.python.ops.losses import losses as core_losses - +from tensorflow.contrib.boosted_trees.estimator_batch import custom_loss_head +from tensorflow.python.ops import array_ops # ================== Old estimator interface=================================== # The estimators below were designed for old feature columns and old estimator @@ -414,30 +417,167 @@ class 
GradientBoostedDecisionTreeRanker(estimator.Estimator): config=config, feature_engineering_fn=feature_engineering_fn) +# When using this estimator, make sure to regularize the hessian (at least l2, +# min_node_weight)! +# TODO(nponomareva): extend to take multiple quantiles in one go. +class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator): + """An estimator that does quantile regression and returns quantile estimates. + """ + + def __init__(self, + learner_config, + examples_per_layer, + quantiles, + label_dimension=1, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=True, + use_core_libs=False, + output_leaf_index=False, + override_global_step_value=None, + num_quantiles=100): + """Initializes a GradientBoostedDecisionTreeQuantileRegressor instance. + + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + quantiles: a list of quantiles for the loss, each between 0 and 1. + label_dimension: Dimension of regression label. This is the size + of the last dimension of the labels `Tensor` (typically, this has shape + `[batch_size, label_dimension]`). When label_dimension>1, it is + recommended to use multiclass strategy diagonal hessian or full hessian. + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. 
+ logits_modifier_function: A modifier function for the logits. + center_bias: Whether a separate tree should be created for first fitting + the bias. + use_core_libs: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree + override_global_step_value: If after the training is done, global step + value must be reset to this value. This should be used to reset global + step to a number > number of steps used to train the current ensemble. + For example, the usual way is to train a number of trees and set a very + large number of training steps. When the training is done (number of + trees were trained), this parameter can be used to set the global step + to a large value, making it look like that number of training steps ran. + If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. + """ + + if len(quantiles) > 1: + raise ValueError('For now, just one quantile per estimator is supported') + + def _quantile_regression_head(quantile): + # Use quantile regression. 
+ head = custom_loss_head.CustomLossHead( + loss_fn=functools.partial( + losses.per_example_quantile_regression_loss, quantile=quantile), + link_fn=array_ops.identity, + logit_dimension=label_dimension) + return head + + learner_config.num_classes = max(2, label_dimension) + + super(GradientBoostedDecisionTreeQuantileRegressor, self).__init__( + model_fn=model.model_builder, + params={ + 'head': _quantile_regression_head(quantiles[0]), + 'feature_columns': feature_columns, + 'learner_config': learner_config, + 'num_trees': num_trees, + 'weight_column_name': weight_column_name, + 'examples_per_layer': examples_per_layer, + 'logits_modifier_function': logits_modifier_function, + 'center_bias': center_bias, + 'use_core_libs': use_core_libs, + 'output_leaf_index': False, + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, + }, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) + # ================== New Estimator interface=================================== # The estimators below use new core Estimator interface and must be used with # new feature columns and heads. + # For multiclass classification, use the following head since it uses loss # that is twice differentiable. 
-def core_multiclass_head(n_classes): +def core_multiclass_head( + n_classes, + weight_column=None, + loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS): """Core head for multiclass problems.""" def loss_fn(labels, logits): result = losses.per_example_maxent_loss( - labels=labels, logits=logits, weights=None, num_classes=n_classes) + labels=labels, + logits=logits, + weights=weight_column, + num_classes=n_classes) return result[0] # pylint:disable=protected-access head_fn = core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( n_classes=n_classes, loss_fn=loss_fn, - loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + loss_reduction=loss_reduction, + weight_column=weight_column) # pylint:enable=protected-access return head_fn +# For quantile regression, use this head with Core..Estimator, or use +# Core..QuantileRegressor directly, +def core_quantile_regression_head( + quantiles, + label_dimension=1, + weight_column=None, + loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS): + """Core head for quantile regression problems.""" + + def loss_fn(labels, logits): + result = losses.per_example_quantile_regression_loss( + labels=labels, + predictions=logits, + weights=weight_column, + quantile=quantiles) + return result[0] + + # pylint:disable=protected-access + head_fn = core_head_lib._regression_head( + label_dimension=label_dimension, + loss_fn=loss_fn, + loss_reduction=loss_reduction, + weight_column=weight_column) + # pylint:enable=protected-access + return head_fn + + class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): """An estimator using gradient boosted decision trees. @@ -601,3 +741,104 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): super(CoreGradientBoostedDecisionTreeRanker, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) + + +# When using this estimator, make sure to regularize the hessian (at least l2, +# min_node_weight)! 
+# TODO(nponomareva): extend to take multiple quantiles in one go. +class CoreGradientBoostedDecisionTreeQuantileRegressor( + core_estimator.Estimator): + """An estimator that does quantile regression and returns quantile estimates. + """ + + def __init__(self, + learner_config, + examples_per_layer, + quantiles, + label_dimension=1, + num_trees=None, + feature_columns=None, + weight_column_name=None, + model_dir=None, + config=None, + label_keys=None, + feature_engineering_fn=None, + logits_modifier_function=None, + center_bias=True, + output_leaf_index=False, + num_quantiles=100): + """Initializes a core version of GradientBoostedDecisionTreeEstimator. + + Args: + learner_config: A config for the learner. + examples_per_layer: Number of examples to accumulate before growing a + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. + quantiles: a list of quantiles for the loss, each between 0 and 1. + label_dimension: Dimension of regression label. This is the size + of the last dimension of the labels `Tensor` (typically, this has shape + `[batch_size, label_dimension]`). When label_dimension>1, it is + recommended to use multiclass strategy diagonal hessian or full hessian. + num_trees: An int, number of trees to build. + feature_columns: A list of feature columns. + weight_column_name: Name of the column for weights, or None if not + weighted. + model_dir: Directory for model exports, etc. + config: `RunConfig` object to configure the runtime settings. + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + logits_modifier_function: A modifier function for the logits. 
+ center_bias: Whether a separate tree should be created for first fitting + the bias. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree + num_quantiles: Number of quantiles to build for numeric feature values. + """ + if len(quantiles) > 1: + raise ValueError('For now, just one quantile per estimator is supported') + + def _model_fn(features, labels, mode, config): + return model.model_builder( + features=features, + labels=labels, + mode=mode, + config=config, + params={ + 'head': + core_quantile_regression_head( + quantiles[0], label_dimension=label_dimension), + 'feature_columns': + feature_columns, + 'learner_config': + learner_config, + 'num_trees': + num_trees, + 'weight_column_name': + weight_column_name, + 'examples_per_layer': + examples_per_layer, + 'center_bias': + center_bias, + 'logits_modifier_function': + logits_modifier_function, + 'use_core_libs': + True, + 'output_leaf_index': + output_leaf_index, + 'override_global_step_value': + None, + 'num_quantiles': + num_quantiles, + }, + output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) + + super(CoreGradientBoostedDecisionTreeQuantileRegressor, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index c155128c0e4ccf928349ee6453baff4384222096..ee052ac60387d8f993e4942dd7dff39e191dd3a4 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -25,6 +25,7 @@ from tensorflow.contrib.boosted_trees.proto 
import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -47,8 +48,8 @@ def _multiclass_train_input_fn(): features = { "x": constant_op.constant([[2.], [1.], [1.], [5.], [3.5], [4.6], [3.5]]) } - label = constant_op.constant( - [[1], [0], [0], [2], [2], [0], [1]], dtype=dtypes.int32) + label = constant_op.constant([[1], [0], [0], [2], [2], [0], [1]], + dtype=dtypes.int32) return features, label @@ -77,6 +78,59 @@ def _infer_ranking_train_input_fn(): return features, None +_QUANTILE_REGRESSION_SIZE = 1000 + + +def _quantile_regression_input_fns(two_dimension=False): + # The data generation is taken from + # http://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_quantile.html + np.random.seed(1) + + def f(x): + """The function to predict.""" + return x * np.sin(x) + + def g(x): + """The function to predict.""" + return x * np.cos(x) + + # Training data. + x = np.atleast_2d(np.random.uniform(0, 10.0, + size=_QUANTILE_REGRESSION_SIZE)).T + x = x.astype(np.float32) + + # Labels. + if not two_dimension: + y = f(x).ravel() + else: + y = np.column_stack((f(x).ravel(), g(x).ravel())) + + # Add random noise. + dy = 1.5 + 1.0 * np.random.random(y.shape) + noise = np.random.normal(0, dy) + y += noise + y_original = y.astype(np.float32) + if not two_dimension: + y = y.reshape(_QUANTILE_REGRESSION_SIZE, 1) + + train_input_fn = numpy_io.numpy_input_fn( + x=x, + y=y, + batch_size=_QUANTILE_REGRESSION_SIZE, + num_epochs=None, + shuffle=True) + + # Test on the training data to make sure the predictions are calibrated. 
+ test_input_fn = numpy_io.numpy_input_fn( + x=x, + y=y, + batch_size=_QUANTILE_REGRESSION_SIZE, + num_epochs=1, + shuffle=False) + + return train_input_fn, test_input_fn, y_original + + class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): def setUp(self): @@ -341,6 +395,130 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): for prediction_dict in result_iter: self.assertTrue("classes" in prediction_dict) + # One dimensional quantile regression. + def testQuantileRegression(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns() + + # 95% percentile. + model_upper = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["scores"]) + + frac_below_upper = round(1. 
* np.count_nonzero(upper > y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper >= 0.92) + self.assertTrue(frac_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() + model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["scores"]) + + frac_above_lower = round(1. * np.count_nonzero(lower < y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower >= 0.92) + self.assertTrue(frac_above_lower <= 0.98) + + # Multi-dimensional quantile regression. + def testQuantileRegressionMultiDimLabel(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns( + two_dimension=True) + + # 95% percentile. 
+ model_upper = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + label_dimension=2, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["scores"]) + + count_below_upper = np.count_nonzero(upper > y, axis=0) + count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1)) + frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3) + frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3) + frac_both_below_upper = round(1. * count_both_below_upper / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper_0 >= 0.92) + self.assertTrue(frac_below_upper_0 <= 0.98) + self.assertTrue(frac_below_upper_1 >= 0.92) + self.assertTrue(frac_below_upper_1 <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.92) + self.assertTrue(frac_both_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( + two_dimension=True) + model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + label_dimension=2, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.fit(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["scores"]) + + count_above_lower = np.count_nonzero(lower < y, axis=0) + count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) + frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) + frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) + frac_both_above_lower = round(1. 
* count_both_aboce_lower / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower_0 >= 0.92) + self.assertTrue(frac_above_lower_0 <= 0.98) + self.assertTrue(frac_above_lower_1 >= 0.92) + self.assertTrue(frac_above_lower_1 <= 0.98) + self.assertTrue(frac_both_above_lower >= 0.92) + self.assertTrue(frac_both_above_lower <= 0.98) + class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): @@ -489,8 +667,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): feature_columns = [ core_feature_column.weighted_categorical_column( - categorical_column=core_feature_column. - categorical_column_with_vocabulary_list( + categorical_column=core_feature_column + .categorical_column_with_vocabulary_list( key="word", vocabulary_list=["the", "cat", "dog"]), weight_feature_key="weight") ] @@ -509,8 +687,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): # Weights for the words are 5 - cat, 6- dog and 1 -the. features_dict["word"] = sparse_tensor.SparseTensor( indices=[[0, 0], [0, 1], [1, 0], [3, 0]], - values=constant_op.constant( - ["the", "cat", "dog", "the"], dtype=dtypes.string), + values=constant_op.constant(["the", "cat", "dog", "the"], + dtype=dtypes.string), dense_shape=[4, 3]) features_dict["weight"] = sparse_tensor.SparseTensor( indices=[[0, 0], [0, 1], [1, 0], [3, 0]], @@ -534,6 +712,132 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): est.evaluate(input_fn=input_fn, steps=1) est.predict(input_fn=input_fn) + # One dimensional quantile regression. 
+ def testQuantileRegression(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns() + y = y.reshape(_QUANTILE_REGRESSION_SIZE, 1) + + # 95% percentile. + model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.train(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["predictions"]) + + frac_below_upper = round(1. * np.count_nonzero(upper > y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper >= 0.92) + self.assertTrue(frac_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns() + model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + num_trees=100, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.train(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["predictions"]) + + frac_above_lower = round(1. 
* np.count_nonzero(lower < y) / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower >= 0.92) + self.assertTrue(frac_above_lower <= 0.98) + + # Multi-dimensional quantile regression. + def testQuantileRegressionMultiDimLabel(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE + learner_config.regularization.tree_complexity = ( + 1.0 / _QUANTILE_REGRESSION_SIZE) + + train_input_fn, test_input_fn, y = _quantile_regression_input_fns( + two_dimension=True) + y = y.reshape(_QUANTILE_REGRESSION_SIZE, 2) + + # 95% percentile. + model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.95], + learner_config=learner_config, + num_trees=100, + label_dimension=2, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_upper.train(input_fn=train_input_fn, steps=1000) + result_iter = model_upper.predict(input_fn=test_input_fn) + upper = [] + for prediction_dict in result_iter: + upper.append(prediction_dict["predictions"]) + + count_below_upper = np.count_nonzero(upper > y, axis=0) + count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1)) + frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3) + frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3) + frac_both_below_upper = round(1. 
* count_both_below_upper / len(y), 3) + # +/- 3% + self.assertTrue(frac_below_upper_0 >= 0.92) + self.assertTrue(frac_below_upper_0 <= 0.98) + self.assertTrue(frac_below_upper_1 >= 0.92) + self.assertTrue(frac_below_upper_1 <= 0.98) + self.assertTrue(frac_both_below_upper >= 0.92) + self.assertTrue(frac_both_below_upper <= 0.98) + + train_input_fn, test_input_fn, _ = _quantile_regression_input_fns( + two_dimension=True) + model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor( + quantiles=[0.05], + learner_config=learner_config, + num_trees=100, + label_dimension=2, + examples_per_layer=_QUANTILE_REGRESSION_SIZE, + center_bias=False) + + model_lower.train(input_fn=train_input_fn, steps=1000) + result_iter = model_lower.predict(input_fn=test_input_fn) + lower = [] + for prediction_dict in result_iter: + lower.append(prediction_dict["predictions"]) + + count_above_lower = np.count_nonzero(lower < y, axis=0) + count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1)) + frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3) + frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3) + frac_both_above_lower = round(1. 
* count_both_aboce_lower / len(y), 3) + # +/- 3% + self.assertTrue(frac_above_lower_0 >= 0.92) + self.assertTrue(frac_above_lower_0 <= 0.98) + self.assertTrue(frac_above_lower_1 >= 0.92) + self.assertTrue(frac_above_lower_1 <= 0.98) + self.assertTrue(frac_both_above_lower >= 0.92) + self.assertTrue(frac_both_above_lower <= 0.98) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/examples/boston.py b/tensorflow/contrib/boosted_trees/examples/boston.py index 54c4ff059e3408d2cb8fc689a9ae877f57485f58..09b240a7006a8ef53eb95108b3adbfae728cf8fc 100644 --- a/tensorflow/contrib/boosted_trees/examples/boston.py +++ b/tensorflow/contrib/boosted_trees/examples/boston.py @@ -90,13 +90,13 @@ def _make_experiment_fn(output_dir): (x_train, y_train), (x_test, y_test) = tf.keras.datasets.boston_housing.load_data() - train_input_fn = tf.estimator.inputs.numpy_input_fn( + train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( x={"x": x_train}, y=y_train, batch_size=FLAGS.batch_size, num_epochs=None, shuffle=True) - eval_input_fn = tf.estimator.inputs.numpy_input_fn( + eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) feature_columns = [ diff --git a/tensorflow/contrib/boosted_trees/examples/boston_combined.py b/tensorflow/contrib/boosted_trees/examples/boston_combined.py index e04b56afbfd266dc13a5b0d78d171ea273415ee3..d640af354f55423b7c9706900359f5e64c459f39 100644 --- a/tensorflow/contrib/boosted_trees/examples/boston_combined.py +++ b/tensorflow/contrib/boosted_trees/examples/boston_combined.py @@ -80,13 +80,13 @@ def _make_experiment_fn(output_dir): (x_train, y_train), (x_test, y_test) = tf.keras.datasets.boston_housing.load_data() - train_input_fn = tf.estimator.inputs.numpy_input_fn( + train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( x={"x": x_train}, y=y_train, batch_size=FLAGS.batch_size, num_epochs=None, shuffle=True) - eval_input_fn = 
tf.estimator.inputs.numpy_input_fn( + eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) feature_columns = [ diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 8edb5d6c640611bbb90d7731b2fea4354e125563..6d78e27e8f69ea289b686af8402bd91967f997f4 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -834,8 +834,13 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { root_gradient_stats *= normalizer_ratio; NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats); int32 best_feature_idx = 0; + bool best_feature_updated = false; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); + CHECK(end_index - start_index >= 2) + << "Partition should have a non bias feature. Start index " + << start_index << " and end index " << end_index; + for (int64 feature_idx = start_index + 1; feature_idx < end_index; ++feature_idx) { GradientStats left_gradient_stats(*gradients_t, *hessians_t, @@ -845,11 +850,13 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { root_gradient_stats - left_gradient_stats; NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats); NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats); - if (left_stats.gain + right_stats.gain > best_gain) { + if (!best_feature_updated || + left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; best_right_node_stats = right_stats; best_feature_idx = feature_idx; + best_feature_updated = true; } } SplitInfo split_info; @@ -864,7 +871,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { << feature_ids(best_feature_idx, 0) << ", " << feature_ids(best_feature_idx, 1) << "\nPartition IDS: " << partition_ids(start_index) << 
" " - << partition_ids(best_feature_idx); + << partition_ids(best_feature_idx) << " and best gain " << best_gain; equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index ab2853352a70073648f47e9835f8a66852ff584f..a30cfa663f4a4954f83224a7fd6448b369ad93b4 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -382,8 +382,7 @@ class GrowTreeEnsembleOp : public OpKernel { break; } case LearnerConfig::OBLIVIOUS_DECISION_TREE: { - FindBestSplitsPerPartitionOblivious(context, gains_list, splits_list, - &best_splits); + FindBestSplitOblivious(context, gains_list, splits_list, &best_splits); break; } } @@ -475,10 +474,10 @@ class GrowTreeEnsembleOp : public OpKernel { } } - void FindBestSplitsPerPartitionOblivious( - OpKernelContext* const context, const OpInputList& gains_list, - const OpInputList& splits_list, - std::map* best_splits) { + void FindBestSplitOblivious(OpKernelContext* const context, + const OpInputList& gains_list, + const OpInputList& splits_list, + std::map* best_splits) { // Find best split per partition going through every feature candidate. for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) { const auto& gains = gains_list[handler_id].vec(); @@ -654,6 +653,12 @@ class GrowTreeEnsembleOp : public OpKernel { return dest; } + if (dest->leaf_case() == boosted_trees::trees::Leaf::LEAF_NOT_SET) { + // No merging is required. Just copy the source weights; + *dest = source; + return dest; + } + // Handle leaf merging based on type. 
switch (source.leaf_case()) { case boosted_trees::trees::Leaf::kVector: { diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD index 3028c2281705bd7e34b212332160d25386559d4e..fd832de982a4a7a2bd39e450ad495e60c284ace7 100644 --- a/tensorflow/contrib/boosted_trees/lib/BUILD +++ b/tensorflow/contrib/boosted_trees/lib/BUILD @@ -67,6 +67,7 @@ tf_cc_test( "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/algorithm:container", ], ) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py index 5d4819b0f1cb598cfbe146f569aecd7883186339..efa2ab1dad8df9815c983afaa2e43982a49c5787 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py @@ -19,15 +19,17 @@ from __future__ import division from __future__ import print_function import abc + +import six + from tensorflow.contrib.boosted_trees.python.ops import batch_ops_utils from tensorflow.python.ops import control_flow_ops +@six.add_metaclass(abc.ABCMeta) class BaseSplitHandler(object): """Abstract Base class defining split handlers interface.""" - __metaclass__ = abc.ABCMeta - def __init__(self, l1_regularization, l2_regularization, diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index 4da25298cb82093ac501997cc21c48265df06860..d26af58419752170bbc58bba757ac43349fc2cff 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -119,7 +119,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): def not_active_inputs(): return 
(constant_op.constant([], dtype=dtypes.int32), - constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]), + constant_op.constant_v1([], dtype=dtypes.int64, shape=[1, 2]), empty_gradients, empty_hessians) def active_inputs(): diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py index a2f708081a4b484d649b5d09b172c2c60db69aeb..386dc19fc7b9529993a9625fb1298f6eb9a70d87 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py @@ -36,9 +36,9 @@ def get_empty_tensors(gradient_shape, hessian_shape): empty_hess_shape = [1] + hessian_shape.as_list() empty_grad_shape = [1] + gradient_shape.as_list() - empty_gradients = constant_op.constant( + empty_gradients = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_grad_shape) - empty_hessians = constant_op.constant( + empty_hessians = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_hess_shape) return empty_gradients, empty_hessians @@ -486,8 +486,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) partition_ids = [0, 0, 0, 1] - indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) - values = array_ops.constant([], dtype=dtypes.int64) + indices = constant_op.constant_v1([], dtype=dtypes.int64, shape=[0, 2]) + values = constant_op.constant_v1([], dtype=dtypes.int64) gradient_shape = tensor_shape.scalar() hessian_shape = tensor_shape.scalar() diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index f45010ec26ed25127ca78b97f4d6fd7ebd6467ae..0476bed2cd3f3ea5b47b10c51a819f17d6e37c74 100644 --- 
a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -142,7 +142,7 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): name="StatsAccumulator/{}".format(self._name)) # Allocate both stats accumulator and quantile accumulator on the same # device so that we can build splits with fewer RPCs. - with ops.colocate_with(self._stats_accumulator.resource()): + with ops.colocate_with(self._stats_accumulator.resource_handle): self._quantile_accumulator = quantile_ops.QuantileAccumulator( init_stamp_token, epsilon=epsilon, @@ -268,8 +268,8 @@ class DenseSplitHandler(InequalitySplitHandler): handler = make_dense_split_tensor are_splits_ready, partition_ids, gains, split_infos = ( - handler(self._quantile_accumulator.resource(), - self._stats_accumulator.resource(), stamp_token, + handler(self._quantile_accumulator.resource_handle, + self._stats_accumulator.resource_handle, stamp_token, next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, @@ -447,8 +447,8 @@ class SparseSplitHandler(InequalitySplitHandler): handler = make_sparse_split_tensor are_splits_ready, partition_ids, gains, split_infos = ( - handler(self._quantile_accumulator.resource(), - self._stats_accumulator.resource(), stamp_token, + handler(self._quantile_accumulator.resource_handle, + self._stats_accumulator.resource_handle, stamp_token, next_stamp_token, self._multiclass_strategy, class_id, self._feature_column_group_id, self._l1_regularization, self._l2_regularization, self._tree_complexity_regularization, @@ -605,7 +605,7 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, quantile_buckets, example_partition_ids, gradients, hessians, weights, empty_gradients, empty_hessians): """Updates the state for dense split handler.""" - empty_float = 
constant_op.constant([], dtype=dtypes.float32) + empty_float = constant_op.constant_v1([], dtype=dtypes.float32) quantile_values, quantile_weights = control_flow_ops.cond( is_active[1], # For the next layer, this handler is inactive. @@ -621,8 +621,8 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, return (example_partition_ids, quantized_feature, gradients, hessians) def not_ready_inputs_fn(): - return (constant_op.constant([], dtype=dtypes.int32), - constant_op.constant([[]], dtype=dtypes.int64, shape=[1, 2]), + return (constant_op.constant_v1([], dtype=dtypes.int32), + constant_op.constant_v1([[]], dtype=dtypes.int64, shape=[1, 2]), empty_gradients, empty_hessians) example_partition_ids, feature_ids, gradients, hessians = ( @@ -708,11 +708,11 @@ def sparse_make_stats_update( def quantiles_not_ready(): """The subgraph for when the quantiles are not ready.""" - return (constant_op.constant([], dtype=dtypes.int32), - constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]), + return (constant_op.constant_v1([], dtype=dtypes.int32), + constant_op.constant_v1([], dtype=dtypes.int64, shape=[1, 2]), empty_gradients, empty_hessians) - empty_float = constant_op.constant([], dtype=dtypes.float32) + empty_float = constant_op.constant_v1([], dtype=dtypes.float32) handler_not_active = (constant_op.constant( [], dtype=dtypes.int64, shape=[0, 2]), empty_float, constant_op.constant([0, 1], dtype=dtypes.int64), diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 74b0ea6989c65e83e7a466107d624712a0e72d1b..4a1b528646e7d2139d7eabb0264b8d280f8da133 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -39,9 +39,9 @@ def get_empty_tensors(gradient_shape, hessian_shape): empty_hess_shape = [1] + 
hessian_shape.as_list() empty_grad_shape = [1] + gradient_shape.as_list() - empty_gradients = constant_op.constant( + empty_gradients = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_grad_shape) - empty_hessians = constant_op.constant( + empty_hessians = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_hess_shape) return empty_gradients, empty_hessians @@ -1476,9 +1476,9 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testEmpty(self): with self.cached_session() as sess: - indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) + indices = constant_op.constant_v1([], dtype=dtypes.int64, shape=[0, 2]) # No values in this feature column in this mini-batch. - values = array_ops.constant([], dtype=dtypes.float32) + values = constant_op.constant_v1([], dtype=dtypes.float32) sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1]) gradient_shape = tensor_shape.scalar() @@ -1549,8 +1549,9 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): sparse_column = array_ops.sparse_placeholder(dtypes.float32) # We have two batches - at first, a sparse feature is empty. 
- empty_indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) - empty_values = array_ops.constant([], dtype=dtypes.float32) + empty_indices = constant_op.constant_v1([], dtype=dtypes.int64, + shape=[0, 2]) + empty_values = constant_op.constant_v1([], dtype=dtypes.float32) empty_sparse_column = sparse_tensor.SparseTensor(empty_indices, empty_values, [4, 2]) empty_sparse_column = empty_sparse_column.eval(session=sess) diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc index 64921faf81c0ea8ae7fb1bbec71396ef3408e6ca..de30a7bde792e727ceab7798458566d4527f5867 100644 --- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc +++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc @@ -81,9 +81,10 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config, const auto& split = current_node.categorical_id_binary_split(); const auto& features = example.sparse_int_features[split.feature_column()]; - node_id = features.find(split.feature_id()) != features.end() - ? split.left_id() - : split.right_id(); + node_id = (std::find(features.begin(), features.end(), + split.feature_id()) == features.end()) + ? 
? split.right_id() + : split.left_id(); break; } case TreeNode::kCategoricalIdSetMembershipBinarySplit: { @@ -117,7 +118,8 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config, oblivious_leaf_idx <<= 1; const auto& features = example.sparse_int_features[split.feature_column()]; - if (features.find(split.feature_id()) == features.end()) { + if (std::find(features.begin(), features.end(), split.feature_id()) == + features.end()) { oblivious_leaf_idx++; } node_id++; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/example.h b/tensorflow/contrib/boosted_trees/lib/utils/example.h index 1371ff337f78dd1c38f2bd0ba86911642f3aeb3e..445ffaaa714c4a69710f9a21d5f2775b8b0f6e22 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/example.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/example.h @@ -20,6 +20,7 @@ #include #include #include "tensorflow/contrib/boosted_trees/lib/utils/optional_value.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" namespace tensorflow { namespace boosted_trees { @@ -124,7 +125,9 @@ struct Example { // Sparse integer features indexed by feature column. // Note that all integer features are assumed to be categorical, i.e. will // never be compared by order. Also these features can be multivalent. - std::vector> sparse_int_features; + // By default we allocate an InlinedVector of length 1 though, since that is + // the most common case.
+ std::vector> sparse_int_features; }; } // namespace utils diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h index 1b654e1c44e545fb97216ad950f3cd2d3240ffd0..3c5e0fbbb40a916e6a3c4197007fb2b562682aae 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h @@ -148,7 +148,7 @@ class ExamplesIterable { row_range.start); for (int64 row_idx = row_range.start; row_idx < row_range.end; ++row_idx) { - sparse_int_features[sparse_int_idx].insert( + sparse_int_features[sparse_int_idx].push_back( iter_->sparse_int_column_values_[sparse_int_idx](row_idx)); } } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc index 30c37435fe16ef29a9e29202850501098e9ac7f8..2f4f2495eaf799a35fb78e183e545f6a1e2d7790 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc @@ -13,6 +13,7 @@ // limitations under the License. 
// ============================================================================= #include "tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h" +#include "absl/algorithm/container.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -90,8 +91,8 @@ TEST_F(ExamplesIterableTest, Iterate) { EXPECT_EQ(1.0f, example.sparse_float_features[1][1].get_value()); EXPECT_EQ(2, example.sparse_int_features[0].size()); - EXPECT_EQ(1, example.sparse_int_features[0].count(1)); - EXPECT_EQ(1, example.sparse_int_features[0].count(8)); + EXPECT_EQ(1, absl::c_count(example.sparse_int_features[0], 1)); + EXPECT_EQ(1, absl::c_count(example.sparse_int_features[0], 8)); EXPECT_EQ(0, example.sparse_int_features[1].size()); } break; case 1: { @@ -105,9 +106,9 @@ TEST_F(ExamplesIterableTest, Iterate) { EXPECT_FALSE(example.sparse_float_features[1][1].has_value()); EXPECT_EQ(1, example.sparse_int_features[0].size()); - EXPECT_EQ(1, example.sparse_int_features[0].count(0)); + EXPECT_EQ(1, absl::c_count(example.sparse_int_features[0], 0)); EXPECT_EQ(1, example.sparse_int_features[1].size()); - EXPECT_EQ(1, example.sparse_int_features[1].count(7)); + EXPECT_EQ(1, absl::c_count(example.sparse_int_features[1], 7)); } break; case 2: { EXPECT_EQ(2, example.example_idx); @@ -122,7 +123,7 @@ TEST_F(ExamplesIterableTest, Iterate) { EXPECT_EQ(0, example.sparse_int_features[0].size()); EXPECT_EQ(1, example.sparse_int_features[1].size()); - EXPECT_EQ(1, example.sparse_int_features[1].count(13)); + EXPECT_EQ(1, absl::c_count(example.sparse_int_features[1], 13)); } break; case 3: { EXPECT_EQ(3, example.example_idx); @@ -136,10 +137,10 @@ TEST_F(ExamplesIterableTest, Iterate) { EXPECT_FALSE(example.sparse_float_features[1][1].has_value()); EXPECT_EQ(2, example.sparse_int_features[0].size()); - EXPECT_EQ(1, example.sparse_int_features[0].count(2)); - EXPECT_EQ(1, 
example.sparse_int_features[0].count(0)); + EXPECT_EQ(1, absl::c_count(example.sparse_int_features[0], 2)); + EXPECT_EQ(1, absl::c_count(example.sparse_int_features[0], 0)); EXPECT_EQ(1, example.sparse_int_features[1].size()); - EXPECT_EQ(1, example.sparse_int_features[1].count(4)); + EXPECT_EQ(1, absl::c_count(example.sparse_int_features[1], 4)); } break; case 4: { EXPECT_EQ(4, example.example_idx); @@ -154,7 +155,7 @@ TEST_F(ExamplesIterableTest, Iterate) { EXPECT_EQ(0, example.sparse_int_features[0].size()); EXPECT_EQ(1, example.sparse_int_features[1].size()); - EXPECT_EQ(1, example.sparse_int_features[1].count(0)); + EXPECT_EQ(1, absl::c_count(example.sparse_int_features[1], 0)); } break; case 5: { EXPECT_EQ(5, example.example_idx); @@ -191,7 +192,7 @@ TEST_F(ExamplesIterableTest, Iterate) { EXPECT_FALSE(example.sparse_float_features[1][1].has_value()); EXPECT_EQ(1, example.sparse_int_features[0].size()); - EXPECT_EQ(1, example.sparse_int_features[0].count(5)); + EXPECT_EQ(1, absl::c_count(example.sparse_int_features[0], 5)); } break; default: { LOG(QFATAL) << "Invalid example index."; } break; } diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py index 05ce0884ccfff53484fdc0c26e596e7fb6fcdfd6..356ae337685d580319da16a20bbab27ccaa73255 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py @@ -34,7 +34,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -62,7 +62,7 @@ class 
StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2, 1], @@ -91,7 +91,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -123,7 +123,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -133,7 +133,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): with ops.control_dependencies([op1]): (stamp_token, num_updates, partition_1, feature_1, grads_1, - hessians_1) = accumulator.serialize() + hessians_1) = accumulator.saveable.serialize() # Make sure that the accumulator hasn't changed during serialization. with ops.control_dependencies([stamp_token]): num_updates_2, partition_2, feature_2, grads_2, hessians_2 = ( @@ -164,7 +164,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.scalar(), hessian_shape=tensor_shape.scalar()) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): # These will be deleted due to deserialize call. 
op1 = accumulator.add( stamp_token=0, @@ -175,7 +175,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): with ops.control_dependencies([op1]): deserialize = ( - accumulator.deserialize( + accumulator.saveable.deserialize( stamp_token=2, num_updates=3, partition_ids=[3, 4], @@ -223,7 +223,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -261,7 +261,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -299,7 +299,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -336,7 +336,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], @@ -349,7 +349,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): with ops.control_dependencies([op1]): (stamp_token, num_updates_1, partition_1, feature_1, grads_1, - hessians_1) = 
accumulator.serialize() + hessians_1) = accumulator.saveable.serialize() # Make sure that the accumulator hasn't changed during serialization. with ops.control_dependencies([stamp_token]): num_updates_2, partition_2, feature_2, grads_2, hessians_2 = ( @@ -386,7 +386,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): stamp_token=0, gradient_shape=tensor_shape.TensorShape([2]), hessian_shape=tensor_shape.TensorShape([2, 2])) - with ops.control_dependencies([accumulator._create_op]): + with ops.control_dependencies([accumulator.initializer]): # These will be deleted due to deserialize call. op1 = accumulator.add( stamp_token=0, @@ -399,7 +399,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): 0.08]]]) with ops.control_dependencies([op1]): - deserialize = accumulator.deserialize( + deserialize = accumulator.saveable.deserialize( stamp_token=2, num_updates=3, partition_ids=[3, 4], diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py index 843420968ac6a6716fdf6b4967146e131139f67c..4dc764f95713ab788c282c2f3e7fb278a24f4822 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py +++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py @@ -20,6 +20,8 @@ from __future__ import print_function import abc import collections +import six + from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -27,11 +29,10 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +@six.add_metaclass(abc.ABCMeta) class ScheduledOp(object): """Represents a scheduled remote operation.""" - __metaclass__ = abc.ABCMeta - @abc.abstractmethod def batching_key(self): """Returns the key for batching operations.""" diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py 
b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index 25b2c9e2fd72bd018717e8a87fce726f26bad968..fca22c71a83459cb290eaebcf107cf1c14c222b7 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + # pylint: disable=unused-import from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader # pylint: enable=unused-import @@ -31,6 +33,7 @@ from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensem from tensorflow.python.framework import ops from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking ops.NotDifferentiable("TreeEnsembleVariable") ops.NotDifferentiable("TreeEnsembleSerialize") @@ -82,6 +85,44 @@ class TreeEnsembleVariableSavable(saver.BaseSaverBuilder.SaveableObject): tree_ensemble_config=restored_tensors[1]) +class TreeEnsembleVariable(tracking.TrackableResource): + """A Tree ensemble model.""" + + def __init__(self, stamp_token, tree_ensemble_config, name, container=None): + self._stamp_token = stamp_token + self._tree_ensemble_config = tree_ensemble_config + self._name = name + self._container = container + self._init_op = None + super(TreeEnsembleVariable, self).__init__() + + def create_resource(self): + return gen_model_ops.decision_tree_ensemble_resource_handle_op( + self._container, shared_name=self._name, name=self._name) + + def initialize(self): + return gen_model_ops.create_tree_ensemble_variable( + self.resource_handle, self._stamp_token, self._tree_ensemble_config) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return 
gen_model_ops.tree_ensemble_is_initialized_op(self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + return { + "tree_ensemble_variable": + functools.partial( + TreeEnsembleVariableSavable, + tree_ensemble_handle=self.resource_handle, + create_op=self.initializer) + } + + def tree_ensemble_variable(stamp_token, tree_ensemble_config, name, @@ -99,12 +140,11 @@ def tree_ensemble_variable(stamp_token, A `Tensor` of type mutable `string`. The handle to the tree ensemble. """ with ops.name_scope(name, "TreeEnsembleVariable") as name: - resource_handle = gen_model_ops.decision_tree_ensemble_resource_handle_op( - container, shared_name=name, name=name) - create_op = gen_model_ops.create_tree_ensemble_variable( - resource_handle, stamp_token, tree_ensemble_config) - is_initialized_op = gen_model_ops.tree_ensemble_is_initialized_op( - resource_handle) + tree_ensemble_var = TreeEnsembleVariable(stamp_token, tree_ensemble_config, + name, container) + resource_handle = tree_ensemble_var.resource_handle + create_op = tree_ensemble_var.initializer + is_initialized_op = tree_ensemble_var.is_initialized() # Adds the variable to the savable list. saveable = TreeEnsembleVariableSavable(resource_handle, create_op, resource_handle.name) diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 19b6b3296db394b07f57a25dbde187eb9195af38..0c319cc9bd1f720eb404a9da05227c5807ec874f 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -33,59 +33,20 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking # Pattern to remove all non alpha numeric from a string. 
_PATTERN = re.compile(r"[\W_]+") -class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): - """A resource that allows distributed quantile computation.""" - - def __init__(self, - init_stamp_token, - epsilon, - num_quantiles, - max_elements=None, - name=None, - container=None, - generate_quantiles=False): - """Creates a QuantileAccumulator object. - - Args: - init_stamp_token: The initial value for the stamp token. - epsilon: Error bound on the quantile computation. - num_quantiles: Number of quantiles to produce from the final summary. - max_elements: Maximum number of elements added to the accumulator. - name: the name to save the accumulator under. - container: An optional `string`. Defaults to `""` - generate_quantiles: Generate quantiles instead of approximate boundaries. - If true, exactly `num_quantiles` will be produced in the final summary. - """ - self._epsilon = epsilon - self._generate_quantiles = generate_quantiles +class QuantileAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject implementation for QuantileAccumulator.""" - name = _PATTERN.sub("", name) - with ops.name_scope(name, "QuantileAccumulator") as name: - self._quantile_accumulator_handle = ( - gen_quantile_ops.quantile_stream_resource_handle_op( - container=container, shared_name=name, name=name)) - self._create_op = gen_quantile_ops.create_quantile_accumulator( - self._quantile_accumulator_handle, - init_stamp_token, - epsilon=epsilon, - max_elements=max_elements, - num_quantiles=num_quantiles, - generate_quantiles=generate_quantiles) - is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized( - self._quantile_accumulator_handle) - resources.register_resource(self._quantile_accumulator_handle, - self._create_op, is_initialized_op) - self._make_savable(name) - - def _make_savable(self, name): + def __init__(self, resource_handle, create_op, name): + self._resource_handle = resource_handle + self._create_op = create_op stamp_token, state, 
are_buckets_ready, buckets = ( - gen_quantile_ops.quantile_accumulator_serialize( - self._quantile_accumulator_handle)) + gen_quantile_ops.quantile_accumulator_serialize(resource_handle)) # slice_spec is useful for saving a slice from a variable. # It's not meaningful in quantile accumulator. slice_spec = "" @@ -96,9 +57,8 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): specs += [make_save_spec(state, "_state")] specs += [make_save_spec(are_buckets_ready, "_are_buckets_ready")] specs += [make_save_spec(buckets, "buckets")] - super(QuantileAccumulator, - self).__init__(self._quantile_accumulator_handle, specs, name) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self) + super(QuantileAccumulatorSaveable, self).__init__(self._resource_handle, + specs, name) def restore(self, restored_tensors, unused_restored_shapes): """Restores the associated quantile accumulator from 'restored_tensors'. @@ -119,24 +79,94 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): buckets = restored_tensors[3] with ops.control_dependencies([self._create_op]): return gen_quantile_ops.quantile_accumulator_deserialize( - self._quantile_accumulator_handle, + self._resource_handle, stamp_token=stamp_token, stream_state=state, are_buckets_ready=are_buckets_ready, buckets=buckets) + +class QuantileAccumulator(tracking.TrackableResource): + """A resource that allows distributed quantile computation.""" + + def __init__(self, + init_stamp_token, + epsilon, + num_quantiles, + max_elements=None, + name=None, + container=None, + generate_quantiles=False): + """Creates a QuantileAccumulator object. + + Args: + init_stamp_token: The initial value for the stamp token. + epsilon: Error bound on the quantile computation. + num_quantiles: Number of quantiles to produce from the final summary. + max_elements: Maximum number of elements added to the accumulator. + name: the name to save the accumulator under. + container: An optional `string`. 
Defaults to `""` + generate_quantiles: Generate quantiles instead of approximate boundaries. + If true, exactly `num_quantiles` will be produced in the final summary. + """ + self._init_stamp_token = init_stamp_token + self._epsilon = epsilon + self._num_quantiles = num_quantiles + self._max_elements = max_elements + self._container = container + self._generate_quantiles = generate_quantiles + super(QuantileAccumulator, self).__init__() + + name = _PATTERN.sub("", name) + with ops.name_scope(name, "QuantileAccumulator") as name: + self._name = name + self._resource_handle = self.create_resource() + self._init_op = self.initialize() + is_initialized_op = self.is_initialized() + resources.register_resource(self.resource_handle, self._init_op, + is_initialized_op) + self._saveable = QuantileAccumulatorSaveable(self.resource_handle, + self._init_op, name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) + + def create_resource(self): + return gen_quantile_ops.quantile_stream_resource_handle_op( + container=self._container, shared_name=self._name, name=self._name) + + def initialize(self): + return gen_quantile_ops.create_quantile_accumulator( + self.resource_handle, + self._init_stamp_token, + epsilon=self._epsilon, + max_elements=self._max_elements, + num_quantiles=self._num_quantiles, + generate_quantiles=self._generate_quantiles) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + return gen_quantile_ops.quantile_accumulator_is_initialized( + self.resource_handle) + + def _gather_saveables_for_checkpoint(self): + return {"quantile_accumulator", self.saveable} + def get_buckets(self, stamp_token): """Returns quantile buckets created during previous flush.""" are_buckets_ready, buckets = ( gen_quantile_ops.quantile_accumulator_get_buckets( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + 
quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token)) return are_buckets_ready[0], buckets[0] def schedule_get_buckets(self): """Returns a scheduled read of buckets created during previous flush.""" return batch_ops_utils.ScheduledStampedResourceOp( - resource_handle=self._quantile_accumulator_handle, + resource_handle=self.resource_handle, op=gen_quantile_ops.quantile_accumulator_get_buckets) def _make_summary(self, column, example_weights): @@ -161,14 +191,14 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): """Adds quantile summary to its stream in resource.""" summary = self._make_summary(column, example_weights) return gen_quantile_ops.quantile_accumulator_add_summaries( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token, summaries=[summary]) def add_prebuilt_summary(self, stamp_token, summary): """Adds quantile summary to its stream in resource.""" return gen_quantile_ops.quantile_accumulator_add_summaries( - quantile_accumulator_handles=[self._quantile_accumulator_handle], + quantile_accumulator_handles=[self.resource_handle], stamp_token=stamp_token, summaries=[summary]) @@ -177,7 +207,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): summary = self._make_summary(column, example_weights) return batch_ops_utils.ScheduledStampedResourceOp( op=gen_quantile_ops.quantile_accumulator_add_summaries, - resource_handle=self._quantile_accumulator_handle, + resource_handle=self.resource_handle, summaries=summary) def flush(self, stamp_token, next_stamp_token): @@ -190,17 +220,14 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): The flush operation. 
""" return gen_quantile_ops.quantile_accumulator_flush( - quantile_accumulator_handle=self._quantile_accumulator_handle, + quantile_accumulator_handle=self.resource_handle, stamp_token=stamp_token, next_stamp_token=next_stamp_token) def flush_summary(self, stamp_token, next_stamp_token): """Finalizes quantile summary stream and resets it for next iteration.""" result = gen_quantile_ops.quantile_accumulator_flush_summary( - quantile_accumulator_handle=self._quantile_accumulator_handle, + quantile_accumulator_handle=self.resource_handle, stamp_token=stamp_token, next_stamp_token=next_stamp_token) return result - - def resource(self): - return self._quantile_accumulator_handle diff --git a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py index 2e94e353f325f06eed2d290d3a7a461861820c39..ad1191d41236e71008bff8c8a7fbd42c16e3f9c5 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py @@ -26,12 +26,83 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import resources from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import tracking # Pattern to remove all non alpha numeric from a string. 
_PATTERN = re.compile(r"[\W_]+") -class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): +class StatsAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject implementation for StatsAccumulator.""" + + def __init__(self, resource_handle, create_op, is_scalar, name): + self._create_op = create_op + self._resource_handle = resource_handle + self._is_scalar = is_scalar + slice_spec = "" + saver_name = self._resource_handle.name + (stamp_token, num_updates, partition_ids, feature_ids, gradients, + hessians) = self.serialize() + specs = [ + saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec, + saver_name + "_stamp"), + saver.BaseSaverBuilder.SaveSpec(num_updates, slice_spec, + saver_name + "_num_updates"), + saver.BaseSaverBuilder.SaveSpec(partition_ids, slice_spec, + saver_name + "_partition_ids"), + saver.BaseSaverBuilder.SaveSpec(feature_ids, slice_spec, + saver_name + "_feature_ids"), + saver.BaseSaverBuilder.SaveSpec(gradients, slice_spec, + saver_name + "_gradients"), + saver.BaseSaverBuilder.SaveSpec(hessians, slice_spec, + saver_name + "hessians"), + ] + super(StatsAccumulatorSaveable, self).__init__(self._resource_handle, specs, + name) + + def serialize(self): + """Serializes the stats accumulator state.""" + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_serialize( + self._resource_handle) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_serialize( + self._resource_handle) + + def deserialize(self, stamp_token, num_updates, partition_ids, feature_ids, + gradients, hessians): + """Resets the stats accumulator with the serialized state.""" + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_deserialize( + self._resource_handle, stamp_token, num_updates, partition_ids, + feature_ids, gradients, hessians) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_deserialize( + self._resource_handle, stamp_token, num_updates, 
partition_ids, + feature_ids, gradients, hessians) + + def restore(self, restored_tensors, unused_restored_shapes): + """Restores the associated tree ensemble from 'restored_tensors'. + + Args: + restored_tensors: the tensors that were loaded from a checkpoint. + unused_restored_shapes: the shapes this object should conform to after + restore. Not meaningful for trees. + + Returns: + The operation that restores the state of the tree ensemble variable. + """ + with ops.control_dependencies([self._create_op]): + return self.deserialize( + stamp_token=restored_tensors[0], + num_updates=restored_tensors[1], + partition_ids=restored_tensors[2], + feature_ids=restored_tensors[3], + gradients=restored_tensors[4], + hessians=restored_tensors[5]) + + +class StatsAccumulator(tracking.TrackableResource): """A resource that allows to accumulate gradients and hessians. For consistency guarantees, we use read and write stamp tokens. @@ -58,58 +129,69 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): Returns: A `Tensor` of type mutable `string`. The handle to the stats accumulator. """ + self._stamp_token = stamp_token + self._gradient_shape = gradient_shape + self._hessian_shape = hessian_shape + self._container = container + + if (gradient_shape == tensor_shape.scalar() and + hessian_shape == tensor_shape.scalar()): + self._is_scalar = True + else: + self._is_scalar = False + if name is not None: name = _PATTERN.sub("", name) with ops.name_scope(name, "StatsAccumulator") as name: - # Both values are scalars. - if (gradient_shape == tensor_shape.scalar() and - hessian_shape == tensor_shape.scalar()): - self._is_scalar = True - self._resource_handle = (gen_stats_accumulator_ops. 
- stats_accumulator_scalar_resource_handle_op( - container, name, name=name)) - - create_op = gen_stats_accumulator_ops.create_stats_accumulator_scalar( - self._resource_handle, stamp_token) - is_initialized_op = ( - gen_stats_accumulator_ops.stats_accumulator_scalar_is_initialized( - self._resource_handle)) - else: - self._is_scalar = False - self._resource_handle = (gen_stats_accumulator_ops. - stats_accumulator_tensor_resource_handle_op( - container, name, name=name)) - create_op = gen_stats_accumulator_ops.create_stats_accumulator_tensor( - self._resource_handle, stamp_token, gradient_shape.as_list(), - hessian_shape.as_list()) - is_initialized_op = ( - gen_stats_accumulator_ops.stats_accumulator_tensor_is_initialized( - self._resource_handle)) + self._name = name + self._resource_handle = self.create_resource() + self._init_op = self.initialize() + is_initialized_op = self.is_initialized() + resources.register_resource(self.resource_handle, self.initializer, + is_initialized_op) + self._saveable = StatsAccumulatorSaveable( + self.resource_handle, self.initializer, self._is_scalar, name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) - self._create_op = create_op - slice_spec = "" - saver_name = self._resource_handle.name - (stamp_token, num_updates, partition_ids, feature_ids, gradients, - hessians) = self.serialize() - specs = [ - saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec, - saver_name + "_stamp"), - saver.BaseSaverBuilder.SaveSpec(num_updates, slice_spec, - saver_name + "_num_updates"), - saver.BaseSaverBuilder.SaveSpec(partition_ids, slice_spec, - saver_name + "_partition_ids"), - saver.BaseSaverBuilder.SaveSpec(feature_ids, slice_spec, - saver_name + "_feature_ids"), - saver.BaseSaverBuilder.SaveSpec(gradients, slice_spec, - saver_name + "_gradients"), - saver.BaseSaverBuilder.SaveSpec(hessians, slice_spec, - saver_name + "hessians"), - ] + def create_resource(self): + if self._is_scalar: + return ( + 
gen_stats_accumulator_ops.stats_accumulator_scalar_resource_handle_op( + self._container, self._name, name=self._name)) + else: + return ( + gen_stats_accumulator_ops.stats_accumulator_tensor_resource_handle_op( + self._container, self._name, name=self._name)) - super(StatsAccumulator, self).__init__(self._resource_handle, specs, name) - resources.register_resource(self._resource_handle, create_op, - is_initialized_op) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self) + def initialize(self): + if self._is_scalar: + return gen_stats_accumulator_ops.create_stats_accumulator_scalar( + self.resource_handle, self._stamp_token) + else: + return gen_stats_accumulator_ops.create_stats_accumulator_tensor( + self.resource_handle, self._stamp_token, + self._gradient_shape.as_list(), self._hessian_shape.as_list()) + + @property + def initializer(self): + if self._init_op is None: + self._init_op = self.initialize() + return self._init_op + + def is_initialized(self): + if self._is_scalar: + return gen_stats_accumulator_ops.stats_accumulator_scalar_is_initialized( + self.resource_handle) + else: + return gen_stats_accumulator_ops.stats_accumulator_tensor_is_initialized( + self.resource_handle) + + @property + def saveable(self): + return self._saveable + + def _gather_saveables_for_checkpoint(self): + return {"stats_accumulator", self.saveable} def add(self, stamp_token, partition_ids, feature_ids, gradients, hessians): """Updates the stats accumulator.""" @@ -117,11 +199,11 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): partition_ids, feature_ids, gradients, hessians)) if self._is_scalar: return gen_stats_accumulator_ops.stats_accumulator_scalar_add( - [self._resource_handle], stamp_token, [partition_ids], [feature_ids], + [self.resource_handle], stamp_token, [partition_ids], [feature_ids], [gradients], [hessians]) else: return gen_stats_accumulator_ops.stats_accumulator_tensor_add( - [self._resource_handle], stamp_token, [partition_ids], 
[feature_ids], + [self.resource_handle], stamp_token, [partition_ids], [feature_ids], [gradients], [hessians]) def schedule_add(self, partition_ids, feature_ids, gradients, hessians): @@ -131,7 +213,7 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): if self._is_scalar: return batch_ops_utils.ScheduledStampedResourceOp( op=gen_stats_accumulator_ops.stats_accumulator_scalar_add, - resource_handle=self._resource_handle, + resource_handle=self.resource_handle, partition_ids=partition_ids, feature_ids=feature_ids, gradients=gradients, @@ -139,7 +221,7 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): else: return batch_ops_utils.ScheduledStampedResourceOp( op=gen_stats_accumulator_ops.stats_accumulator_tensor_add, - resource_handle=self._resource_handle, + resource_handle=self.resource_handle, partition_ids=partition_ids, feature_ids=feature_ids, gradients=gradients, @@ -153,55 +235,11 @@ class StatsAccumulator(saver.BaseSaverBuilder.SaveableObject): return gen_stats_accumulator_ops.stats_accumulator_tensor_make_summary( partition_ids, feature_ids, gradients, hessians) - def deserialize(self, stamp_token, num_updates, partition_ids, feature_ids, - gradients, hessians): - """Resets the stats accumulator with the serialized state.""" - if self._is_scalar: - return gen_stats_accumulator_ops.stats_accumulator_scalar_deserialize( - self._resource_handle, stamp_token, num_updates, partition_ids, - feature_ids, gradients, hessians) - else: - return gen_stats_accumulator_ops.stats_accumulator_tensor_deserialize( - self._resource_handle, stamp_token, num_updates, partition_ids, - feature_ids, gradients, hessians) - def flush(self, stamp_token, next_stamp_token): """Flushes the stats accumulator.""" if self._is_scalar: return gen_stats_accumulator_ops.stats_accumulator_scalar_flush( - self._resource_handle, stamp_token, next_stamp_token) + self.resource_handle, stamp_token, next_stamp_token) else: return 
gen_stats_accumulator_ops.stats_accumulator_tensor_flush( - self._resource_handle, stamp_token, next_stamp_token) - - def serialize(self): - """Serializes the stats accumulator state.""" - if self._is_scalar: - return gen_stats_accumulator_ops.stats_accumulator_scalar_serialize( - self._resource_handle) - else: - return gen_stats_accumulator_ops.stats_accumulator_tensor_serialize( - self._resource_handle) - - def restore(self, restored_tensors, unused_restored_shapes): - """Restores the associated tree ensemble from 'restored_tensors'. - - Args: - restored_tensors: the tensors that were loaded from a checkpoint. - unused_restored_shapes: the shapes this object should conform to after - restore. Not meaningful for trees. - - Returns: - The operation that restores the state of the tree ensemble variable. - """ - with ops.control_dependencies([self._create_op]): - return self.deserialize( - stamp_token=restored_tensors[0], - num_updates=restored_tensors[1], - partition_ids=restored_tensors[2], - feature_ids=restored_tensors[3], - gradients=restored_tensors[4], - hessians=restored_tensors[5]) - - def resource(self): - return self._resource_handle + self.resource_handle, stamp_token, next_stamp_token) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 8531e97f90236b8e5eb64bc0f4c9bb3b674f35cd..9fdc2fc0c2c7b85502f7a3f9ae7c85cf05d5916c 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -22,7 +22,6 @@ import collections import copy from tensorflow.contrib import learn -from tensorflow.contrib import stateless from tensorflow.contrib.boosted_trees.lib.learner.batch import categorical_split_handler from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler from tensorflow.contrib.boosted_trees.proto import learner_pb2 @@ 
-44,6 +43,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops +from tensorflow.python.ops import stateless_random_ops as stateless from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses @@ -386,10 +386,21 @@ class GradientBoostedDecisionTreeModel(object): learner_pb2.LearnerConfig.GROWING_MODE_UNSPECIFIED): learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER + if (learner_config.weak_learner_type == learner_pb2.LearnerConfig + .OBLIVIOUS_DECISION_TREE and learner_config.pruning_mode == learner_pb2 + .LearnerConfig.PRUNING_MODE_UNSPECIFIED): + learner_config.pruning_mode = learner_pb2.LearnerConfig.PRE_PRUNE + if (learner_config.pruning_mode == learner_pb2.LearnerConfig.PRUNING_MODE_UNSPECIFIED): learner_config.pruning_mode = learner_pb2.LearnerConfig.POST_PRUNE + if (learner_config.weak_learner_type == learner_pb2.LearnerConfig + .OBLIVIOUS_DECISION_TREE and + learner_config.pruning_mode == learner_pb2.LearnerConfig.POST_PRUNE): + raise ValueError( + "Post pruning is not implmented for oblivious decision trees.") + if learner_config.constraints.max_tree_depth == 0: # Use 6 as the default maximum depth. learner_config.constraints.max_tree_depth = 6 @@ -418,6 +429,11 @@ class GradientBoostedDecisionTreeModel(object): sparse_float_shapes, sparse_int_indices, sparse_int_values, sparse_int_shapes) = extract_features( features, self._feature_columns, use_core_columns) + if (learner_config.weak_learner_type == learner_pb2.LearnerConfig + .OBLIVIOUS_DECISION_TREE and sparse_float_indices): + raise ValueError("Oblivious trees don't handle sparse float features yet." 
+ ) + logging.info("Active Feature Columns: " + str(fc_names)) logging.info("Learner config: " + str(learner_config)) self._fc_names = fc_names @@ -881,9 +897,9 @@ class GradientBoostedDecisionTreeModel(object): empty_hess_shape = [1] + self._hessian_shape.as_list() empty_grad_shape = [1] + self._gradient_shape.as_list() - empty_gradients = constant_op.constant( + empty_gradients = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_grad_shape) - empty_hessians = constant_op.constant( + empty_hessians = constant_op.constant_v1( [], dtype=dtypes.float32, shape=empty_hess_shape) active_handlers = array_ops.unstack(active_handlers, axis=0) @@ -976,7 +992,7 @@ class GradientBoostedDecisionTreeModel(object): # Get accumulated steps and examples for the current layer. _, _, _, _, acc_examples, acc_steps = ( - steps_accumulator.serialize()) + steps_accumulator.saveable.serialize()) acc_examples = math_ops.cast(acc_examples[0], dtypes.int64) acc_steps = math_ops.cast(acc_steps[0], dtypes.int64) ensemble_update_ops.append( @@ -1241,13 +1257,12 @@ class GradientBoostedDecisionTreeModel(object): def _get_replica_device_setter(self, worker_device): """Creates a replica device setter.""" ps_tasks = self._num_ps_replicas - ps_ops = [ - "Variable", - "VariableV2", + ps_ops = list(device_setter.STANDARD_PS_OPS) + ps_ops.extend([ "DecisionTreeEnsembleResourceHandleOp", "StatsAccumulatorScalarResourceHandleOp", "StatsAccumulatorTensorResourceHandleOp", - ] + ]) ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks) return device_setter.replica_device_setter( worker_device=worker_device, diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 6d20a2e7f482953481fb1effe4c6e2e5a300786f..92068e88a76cb8bfdd394c1093347a8fb8a63449 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ 
b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -1257,6 +1257,96 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertArrayNear(expected_leaf_2, output.trees[0].nodes[2].leaf.vector.value, 1e-3) + def testTrainFnMulticlassDiagonalHessianOblivious(self): + """Tests the GBDT train for multiclass diagonal hessian.""" + with self.cached_session(): + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 1 + # Use full hessian multiclass strategy. + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + learner_config.num_classes = 5 + learner_config.regularization.l1 = 0 + # To make matrix inversible. + learner_config.regularization.l2 = 1e-5 + learner_config.weak_learner_type = ( + learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE) + learner_config.pruning_mode = learner_pb2.LearnerConfig.PRE_PRUNE + learner_config.constraints.max_tree_depth = 5 + learner_config.constraints.min_node_weight = 0 + batch_size = 3 + features = {} + features["sparse_int"] = sparse_tensor.SparseTensor( + array_ops.constant([[0, 0], [1, 0]], dtypes.int64), + array_ops.constant([1, 2], dtypes.int64), + array_ops.constant([3, 1], dtypes.int64)) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=5, + features=features) + + labels = array_ops.constant([[2], [2], [3]], dtype=dtypes.float32) + weights = array_ops.ones([batch_size, 1], dtypes.float32) + + predictions_dict = gbdt_model.predict(learn.ModeKeys.TRAIN) + predictions = predictions_dict["predictions"] + + # Create train op. 
+ train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + losses.per_example_maxent_loss( + labels, + weights, + predictions, + num_classes=learner_config.num_classes)[0]), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + + # Grow 2 layers. + train_op.run() + train_op.run() + + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output.ParseFromString(serialized.eval()) + self.assertEqual(len(output.trees), 1) + # We got 6 nodes: one parent and 4 leafs. + self.assertEqual(len(output.trees[0].nodes), 6) + self.assertAllClose(output.tree_weights, [1]) + self.assertEqual(stamp_token.eval(), 2) + + print(output.trees[0]) + # Leafs should have a dense vector of size 5. 
+ expected_leaf_1 = [-1.2497, -1.24976, 4.999, -1.24976, -1.2497] + expected_leaf_2 = [-2.2362, -2.2362, 6.0028, -2.2362, -2.2362] + expected_leaf_3 = [-2.2694, -2.2694, 4.0064, -0.0084, -2.2694] + expected_leaf_4 = [-2.2694, -2.2694, -0.0084, 4.0064, -2.2694] + self.assertArrayNear(expected_leaf_1, + output.trees[0].nodes[2].leaf.vector.value, 1e-3) + self.assertArrayNear(expected_leaf_2, + output.trees[0].nodes[3].leaf.vector.value, 1e-3) + self.assertArrayNear(expected_leaf_3, + output.trees[0].nodes[4].leaf.vector.value, 1e-3) + self.assertArrayNear(expected_leaf_4, + output.trees[0].nodes[5].leaf.vector.value, 1e-3) + def testTrainFnMulticlassTreePerClass(self): """Tests the GBDT train for multiclass tree per class strategy.""" with self.cached_session() as sess: diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index b5ebaf1999519f65110e8164fa20bace5ecc3ef6..220e981618b7c0bfb1e4e98c087d83b451b9b3cf 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -48,6 +48,47 @@ def per_example_logistic_loss(labels, weights, predictions): labels=labels, logits=predictions) return unweighted_loss * weights, control_flow_ops.no_op() +# MUST USE WITH HESSIAN REGULARIZATION, +# This loss can have zero hessian, so it must be used with l2 or min_node_weight +# regularization. +# An example config is +# learner_config.constraints.min_node_weight = 1 / num_examples_per_layer +# learner_config.regularization.l2 = 1.0 / num_examples_per_layer +# TODO(nponomareva): make it multidimensional so we can estimate several +# quantiles at once. +def per_example_quantile_regression_loss(labels, weights, predictions, + quantile): + """Smoothed loss for quantile regression. + + The standard quantile regression loss is quantile*(y-y') when y>y' and + (quantile-1)*(y-y') otherwise, y' is a prediction, y is a label. 
The impl + below is this loss but squared in the region where the loss value < 1. + + Args: + labels: Rank 2 (N, D) tensor of per-example labels. + weights: Rank 2 (N, 1) tensor of per-example weights. + predictions: Rank 2 (N, D) tensor of per-example predictions. + quantile: The quantile to use. + + Returns: + loss: A Rank 2 (N, 1) tensor of per-example quantile loss. + update_op: An update operation to update the loss's internal state. + """ + labels = math_ops.to_float(labels) + error = labels - predictions + square_loss_right = array_ops.where(error * quantile < 1.0, + math_ops.square(quantile * error), + quantile * error) + square_loss_left = array_ops.where(error * (quantile - 1) < 1, + math_ops.square((quantile - 1) * error), + (quantile - 1) * error) + + unweighted_loss = array_ops.where(error > 0, square_loss_right, + square_loss_left) + if weights is None: + return unweighted_loss, control_flow_ops.no_op() + else: + return unweighted_loss * weights, control_flow_ops.no_op() # This is classical form of Maximum entropy loss, that is twice differentiable # (sparse_softmax_cross_entropy which is what we go for is not twice @@ -78,8 +119,7 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15): labels = array_ops.expand_dims(labels, 1) # Labels are indices of classes, convert them to one hot encodings. target_one_hot = array_ops.one_hot(indices=labels, depth=num_classes) - labels = math_ops.reduce_sum( - input_tensor=target_one_hot, reduction_indices=[1]) + labels = math_ops.reduce_sum(input_tensor=target_one_hot, axis=[1]) labels = math_ops.to_float(labels) # Calculate softmax probabilities for each class. 
diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 242c1e8ba45e0b2f6f9a1a51695b824546382666..5418e2605b724edb60878e250d2c50fcc6ff5633 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -46,6 +46,10 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure): self._maybe_initialize_checkpointable() self._name_counts = {} + @property + def _values(self): + return [dep.ref for dep in self._checkpoint_dependencies] + def track(self, checkpointable, base_name): """Add a dependency on `checkpointable`. diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index 707f6211846ca0310bde297603928e9ec5bb471c..f944b7f88438ff257a44581170ead16640540e69 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -21,91 +21,25 @@ py_library( py_library( name = "cluster_resolver_py", - srcs = [ + srcs = glob([ "__init__.py", - "python/training/__init__.py", - ], + "python/training/*.py", + ]), srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [ - ":base_cluster_resolver_py", - ":gce_cluster_resolver_py", - ":tpu_cluster_resolver_py", - "//tensorflow/python:util", - ], -) - -py_library( - name = "base_cluster_resolver_py", - srcs = ["python/training/cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:training", - ], -) - -py_library( - name = "gce_cluster_resolver_py", - srcs = ["python/training/gce_cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ - ":base_cluster_resolver_py", - "//tensorflow/python:training", - ], -) - -py_library( - name = "tpu_cluster_resolver_py", - srcs = ["python/training/tpu_cluster_resolver.py"], - srcs_version = "PY2AND3", - deps = [ - ":base_cluster_resolver_py", - "//tensorflow/python:training", - ], -) - -tf_py_test( - name = 
"base_cluster_resolver_py_test", - srcs = ["python/training/cluster_resolver_test.py"], - additional_deps = [ - ":cluster_resolver_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", - ], - main = "python/training/cluster_resolver_test.py", + deps = ["//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib"], ) tf_py_test( - name = "gce_cluster_resolver_py_test", - size = "small", - srcs = ["python/training/gce_cluster_resolver_test.py"], + name = "cluster_resolver_initialization_test", + srcs = ["cluster_resolver_initialization_test.py"], additional_deps = [ ":cluster_resolver_py", - ":gce_cluster_resolver_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//tensorflow/python:training", - ], - main = "python/training/gce_cluster_resolver_test.py", -) - -tf_py_test( - name = "tpu_cluster_resolver_py_test", - size = "small", - srcs = ["python/training/tpu_cluster_resolver_test.py"], - additional_deps = [ - ":tpu_cluster_resolver_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", "//tensorflow/python:training", ], - grpc_enabled = True, - main = "python/training/tpu_cluster_resolver_test.py", + main = "cluster_resolver_initialization_test.py", ) diff --git a/tensorflow/contrib/cluster_resolver/__init__.py b/tensorflow/contrib/cluster_resolver/__init__.py index b4d8cd4a7cf42e910e7506dbeec8656a2cef62eb..390b3e7550b3d991269bb84707c3500f2fa33290 100644 --- a/tensorflow/contrib/cluster_resolver/__init__.py +++ b/tensorflow/contrib/cluster_resolver/__init__.py @@ -20,11 +20,14 @@ from __future__ import division from __future__ import 
print_function # pylint: disable=wildcard-import,unused-import -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver +from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver +from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver +from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver # pylint: enable=wildcard-import,unused-import from tensorflow.python.util.all_util import remove_undocumented @@ -34,7 +37,10 @@ _allowed_symbols = [ 'SimpleClusterResolver', 'UnionClusterResolver', 'GceClusterResolver', + 'KubernetesClusterResolver', + 'TFConfigClusterResolver', 'TPUClusterResolver', + 'SlurmClusterResolver', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cluster_resolver/cluster_resolver_initialization_test.py b/tensorflow/contrib/cluster_resolver/cluster_resolver_initialization_test.py new file mode 100644 index 
0000000000000000000000000000000000000000..01ff1478c694cf0901aeed48b6e0f873d8abe65e --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/cluster_resolver_initialization_test.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests to ensure ClusterResolvers are usable via the old contrib path.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cluster_resolver import SimpleClusterResolver +from tensorflow.contrib.cluster_resolver.python.training import cluster_resolver +from tensorflow.contrib.cluster_resolver.python.training import UnionClusterResolver +from tensorflow.python.platform import test +from tensorflow.python.training import server_lib + + +class ClusterResolverInitializationTest(test.TestCase): + + def testCreateSimpleClusterResolverFromLib(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + cluster_resolver.SimpleClusterResolver(base_cluster_spec) + + def testCreateSimpleClusterResolver(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + SimpleClusterResolver(base_cluster_spec) + + def 
testCreateUnionClusterResolver(self): + base_cluster_spec = server_lib.ClusterSpec({ + "ps": ["ps0:2222", "ps1:2222"], + "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] + }) + simple_cr = SimpleClusterResolver(base_cluster_spec) + UnionClusterResolver(simple_cr) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/cluster_resolver/python/training/__init__.py b/tensorflow/contrib/cluster_resolver/python/training/__init__.py index 0b0464b7d2ddbd26b588bafc9624d412de326f6a..10d93549ebbd4f7e900796d0516b0af1744224af 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/__init__.py +++ b/tensorflow/contrib/cluster_resolver/python/training/__init__.py @@ -18,8 +18,36 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. 
+ +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver +from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver +from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver +from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'cluster_resolver', + 'gce_cluster_resolver', + 'kubernetes_cluster_resolver', + 'slurm_cluster_resolver', + 'tfconfig_cluster_resolver', + 'tpu_cluster_resolver', + 'ClusterResolver', + 'SimpleClusterResolver', + 'UnionClusterResolver', + 'GceClusterResolver', + 'KubernetesClusterResolver', + 'TFConfigClusterResolver', + 'TPUClusterResolver', + 'SlurmClusterResolver', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py index 1c480b25134b1e54200e0ddb780bd7bb0f122341..99840fb5166dd739b3bee06a926e06b534011d1f 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,181 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution.""" +"""Stub file for ClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import abc +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. -from tensorflow.python.training.server_lib import ClusterSpec +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver +from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver +# pylint: enable=unused-import +from tensorflow.python.util.all_util import remove_undocumented -class ClusterResolver(object): - """Abstract class for all implementations of ClusterResolvers. +_allowed_symbols = [ + 'ClusterResolver', + 'SimpleClusterResolver', + 'UnionClusterResolver', +] - This defines the skeleton for all implementations of ClusterResolvers. - ClusterResolvers are a way for TensorFlow to communicate with various cluster - management systems (e.g. GCE, AWS, etc...). +remove_undocumented(__name__, _allowed_symbols) - By letting TensorFlow communicate with these systems, we will be able to - automatically discover and resolve IP addresses for various TensorFlow - workers. 
This will eventually allow us to automatically recover from - underlying machine failures and scale TensorFlow worker clusters up and down. - """ - - @abc.abstractmethod - def cluster_spec(self): - """Retrieve the current state of the cluster and returns a ClusterSpec. - - Returns: - A ClusterSpec representing the state of the cluster at the moment this - function is called. - - Implementors of this function must take care in ensuring that the - ClusterSpec returned is up-to-date at the time of calling this function. - This usually means retrieving the information from the underlying cluster - management system every time this function is invoked and reconstructing - a cluster_spec, rather than attempting to cache anything. - """ - raise NotImplementedError( - 'cluster_spec is not implemented for {}.'.format(self)) - - @abc.abstractmethod - def master(self): - """...""" - raise NotImplementedError('master is not implemented for {}.'.format(self)) - - -class SimpleClusterResolver(ClusterResolver): - """Simple implementation of ClusterResolver that accepts a ClusterSpec.""" - - def __init__(self, cluster_spec, master=''): - """Creates a SimpleClusterResolver from a ClusterSpec.""" - super(SimpleClusterResolver, self).__init__() - - if not isinstance(cluster_spec, ClusterSpec): - raise TypeError('cluster_spec must be a ClusterSpec.') - self._cluster_spec = cluster_spec - - if not isinstance(master, str): - raise TypeError('master must be a string.') - self._master = master - - def cluster_spec(self): - """Returns the ClusterSpec passed into the constructor.""" - return self._cluster_spec - - def master(self): - """Returns the master address to use when creating a session.""" - return self._master - - -class UnionClusterResolver(ClusterResolver): - """Performs a union on underlying ClusterResolvers. - - This class performs a union given two or more existing ClusterResolvers. 
It - merges the underlying ClusterResolvers, and returns one unified ClusterSpec - when cluster_spec is called. The details of the merge function is - documented in the cluster_spec function. - """ - - def __init__(self, *args): - """Initializes a UnionClusterResolver with other ClusterResolvers. - - Args: - *args: `ClusterResolver` objects to be unionized. - - Raises: - TypeError: If any argument is not a subclass of `ClusterResolvers`. - ValueError: If there are no arguments passed. - """ - super(UnionClusterResolver, self).__init__() - - if not args: - raise ValueError('At least one ClusterResolver is required.') - - for cluster_resolver in args: - if not isinstance(cluster_resolver, ClusterResolver): - raise TypeError('All arguments must be a sub-class of ' - '`ClusterResolver.`') - self._cluster_resolvers = args - - def cluster_spec(self): - """Returns a union of all the ClusterSpecs from the ClusterResolvers. - - Returns: - A ClusterSpec containing host information merged from all the underlying - ClusterResolvers. - - Raises: - KeyError: If there are conflicting keys detected when merging two or - more dictionaries, this exception is raised. - - Note: If there are multiple ClusterResolvers exposing ClusterSpecs with the - same job name, we will merge the list/dict of workers. - - If *all* underlying ClusterSpecs expose the set of workers as lists, we will - concatenate the lists of workers, starting with the list of workers from - the first ClusterResolver passed into the constructor. - - If *any* of the ClusterSpecs expose the set of workers as a dict, we will - treat all the sets of workers as dicts (even if they are returned as lists) - and will only merge them into a dict if there is no conflicting keys. If - there is a conflicting key, we will raise a `KeyError`. - """ - - merged_cluster = {} - - # We figure out whether it is all lists for a particular job, or whether - # there are dicts inside. 
- for cluster_resolver in self._cluster_resolvers: - cluster_spec = cluster_resolver.cluster_spec() - cluster_dict = cluster_spec.as_dict() - - for job_name, tasks in cluster_dict.items(): - if job_name in merged_cluster: - # If we see a dict, then we write a dict out regardless. - if isinstance(tasks, dict): - merged_cluster[job_name] = {} - else: - # We take whichever type is present. - if isinstance(tasks, list): - merged_cluster[job_name] = [] - else: - merged_cluster[job_name] = {} - - # We then do the merge as appropriate in merged_cluster[job]. - for cluster_resolver in self._cluster_resolvers: - cluster_spec = cluster_resolver.cluster_spec() - cluster_dict = cluster_spec.as_dict() - - for job_name, tasks in cluster_dict.items(): - if isinstance(merged_cluster[job_name], list): - # We all have lists, we can just concatenate and be done. - merged_cluster[job_name].extend(tasks) - else: - if isinstance(tasks, list): - # We convert to a dictionary if the type is a list. - task_dict = dict(zip(range(0, len(tasks)), tasks)) - else: - # We can simply make a copy (for update) and be done. - task_dict = tasks.copy() - - # We detect if there are duplicates, and raise an error if so. - task_keys = set(task_dict) - merged_keys = set(merged_cluster[job_name].keys()) - intersected_keys = task_keys.intersection(merged_keys) - if intersected_keys: - raise KeyError('Duplicate keys detected when merging two ' - 'ClusterSpecs: %s' % repr(intersected_keys)) - - # We do the merge after all the processing. 
- merged_cluster[job_name].update(task_dict) - - return ClusterSpec(merged_cluster) - - def master(self): - """master returns the master address from the first cluster resolver.""" - return self._cluster_resolvers[0].master() diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py deleted file mode 100644 index d9c97d53eb3663f6ab2f7b40395592dc7638b896..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Cluster Resolvers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver -from tensorflow.python.platform import test -from tensorflow.python.training import server_lib - - -class UnionClusterResolverTest(test.TestCase): - # TODO(frankchn): Transform to parameterized test after it is included in the - # TF open source codebase. 
- - def _verifyClusterSpecEquality(self, cluster_spec, expected_proto): - self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def()) - self.assertProtoEquals( - expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def()) - self.assertProtoEquals( - expected_proto, - server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def()) - self.assertProtoEquals( - expected_proto, - server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def()) - - def testSingleClusterResolver(self): - base_cluster_spec = server_lib.ClusterSpec({ - "ps": ["ps0:2222", "ps1:2222"], - "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] - }) - simple_resolver = SimpleClusterResolver(base_cluster_spec) - union_resolver = UnionClusterResolver(simple_resolver) - - expected_proto = """ - job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } - tasks { key: 1 value: 'ps1:2222' } } - job { name: 'worker' tasks { key: 0 value: 'worker0:2222' } - tasks { key: 1 value: 'worker1:2222' } - tasks { key: 2 value: 'worker2:2222' } } - """ - actual_cluster_spec = union_resolver.cluster_spec() - self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - - def testTwoNonOverlappingJobMergedClusterResolver(self): - cluster_spec_1 = server_lib.ClusterSpec({ - "ps": [ - "ps0:2222", - "ps1:2222" - ] - }) - cluster_spec_2 = server_lib.ClusterSpec({ - "worker": [ - "worker0:2222", - "worker1:2222", - "worker2:2222" - ] - }) - cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1) - cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2) - - union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) - cluster_spec = union_cluster.cluster_spec() - - expected_proto = """ - job { name: 'ps' tasks { key: 0 value: 'ps0:2222' } - tasks { key: 1 value: 'ps1:2222' } } - job { name: 'worker' tasks { key: 0 value: 'worker0:2222' } - tasks { key: 1 value: 'worker1:2222' } - tasks { key: 2 value: 'worker2:2222' } } - """ - 
self._verifyClusterSpecEquality(cluster_spec, expected_proto) - - def testOverlappingJobMergedClusterResolver(self): - cluster_spec_1 = server_lib.ClusterSpec({ - "worker": [ - "worker4:2222", - "worker5:2222" - ] - }) - cluster_spec_2 = server_lib.ClusterSpec({ - "worker": [ - "worker0:2222", - "worker1:2222", - "worker2:2222" - ] - }) - cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1) - cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2) - - union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) - cluster_spec = union_cluster.cluster_spec() - - expected_proto = """ - job { name: 'worker' tasks { key: 0 value: 'worker4:2222' } - tasks { key: 1 value: 'worker5:2222' } - tasks { key: 2 value: 'worker0:2222' } - tasks { key: 3 value: 'worker1:2222' } - tasks { key: 4 value: 'worker2:2222' } } - """ - self._verifyClusterSpecEquality(cluster_spec, expected_proto) - - def testOverlappingSparseJobMergedClusterResolverThrowError(self): - cluster_spec_1 = server_lib.ClusterSpec({ - "worker": { - 7: "worker4:2222", - 9: "worker5:2222" - } - }) - cluster_spec_2 = server_lib.ClusterSpec({ - "worker": { - 3: "worker0:2222", - 6: "worker1:2222", - 7: "worker2:2222" - } - }) - cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1) - cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2) - - union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) - self.assertRaises(KeyError, union_cluster.cluster_spec) - - def testOverlappingDictAndListThrowError(self): - cluster_spec_1 = server_lib.ClusterSpec({ - "worker": [ - "worker4:2222", - "worker5:2222" - ] - }) - cluster_spec_2 = server_lib.ClusterSpec({ - "worker": { - 1: "worker0:2222", - 2: "worker1:2222", - 3: "worker2:2222" - } - }) - cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1) - cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2) - - union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) - self.assertRaises(KeyError, 
union_cluster.cluster_spec) - - def testOverlappingJobNonOverlappingKey(self): - cluster_spec_1 = server_lib.ClusterSpec({ - "worker": { - 5: "worker4:2222", - 9: "worker5:2222" - } - }) - cluster_spec_2 = server_lib.ClusterSpec({ - "worker": { - 3: "worker0:2222", - 6: "worker1:2222", - 7: "worker2:2222" - } - }) - cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1) - cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2) - - union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) - cluster_spec = union_cluster.cluster_spec() - - expected_proto = """ - job { name: 'worker' tasks { key: 3 value: 'worker0:2222' } - tasks { key: 5 value: 'worker4:2222' } - tasks { key: 6 value: 'worker1:2222' } - tasks { key: 7 value: 'worker2:2222' } - tasks { key: 9 value: 'worker5:2222' }} - """ - self._verifyClusterSpecEquality(cluster_spec, expected_proto) - - def testMixedModeNonOverlappingKey(self): - cluster_spec_1 = server_lib.ClusterSpec({ - "worker": [ - "worker4:2222", - "worker5:2222" - ] - }) - cluster_spec_2 = server_lib.ClusterSpec({ - "worker": { - 3: "worker0:2222", - 6: "worker1:2222", - 7: "worker2:2222" - } - }) - cluster_resolver_1 = SimpleClusterResolver(cluster_spec_1) - cluster_resolver_2 = SimpleClusterResolver(cluster_spec_2) - - union_cluster = UnionClusterResolver(cluster_resolver_1, cluster_resolver_2) - cluster_spec = union_cluster.cluster_spec() - - expected_proto = """ - job { name: 'worker' tasks { key: 0 value: 'worker4:2222' } - tasks { key: 1 value: 'worker5:2222' } - tasks { key: 3 value: 'worker0:2222' } - tasks { key: 6 value: 'worker1:2222' } - tasks { key: 7 value: 'worker2:2222' }} - """ - self._verifyClusterSpecEquality(cluster_spec, expected_proto) - - def testRetainSparseJobWithNoMerging(self): - base_cluster_spec = server_lib.ClusterSpec({ - "worker": { - 1: "worker0:2222", - 3: "worker1:2222", - 5: "worker2:2222" - } - }) - - base_cluster_resolver = SimpleClusterResolver(base_cluster_spec) - union_cluster 
= UnionClusterResolver(base_cluster_resolver) - cluster_spec = union_cluster.cluster_spec() - - expected_proto = """ - job { name: 'worker' tasks { key: 1 value: 'worker0:2222' } - tasks { key: 3 value: 'worker1:2222' } - tasks { key: 5 value: 'worker2:2222' } } - """ - self._verifyClusterSpecEquality(cluster_spec, expected_proto) - - -# TODO(saeta): Include tests for master resolution - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py index 3f5824128948453634bc5e5a7d6fdeedae60f5bd..55e61155c683c928efab9bb018868faec3e3df8c 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,128 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of Cluster Resolvers for GCE Instance Groups.""" +"""Stub file for GceClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. 
-from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training.server_lib import ClusterSpec +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GceClusterResolver +# pylint: enable=unused-import -_GOOGLE_API_CLIENT_INSTALLED = True -try: - from googleapiclient import discovery # pylint: disable=g-import-not-at-top - from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top -except ImportError: - _GOOGLE_API_CLIENT_INSTALLED = False +from tensorflow.python.util.all_util import remove_undocumented +_allowed_symbols = [ + 'GceClusterResolver', +] -class GceClusterResolver(ClusterResolver): - """Cluster Resolver for Google Compute Engine. - - This is an implementation of cluster resolvers for the Google Compute Engine - instance group platform. By specifying a project, zone, and instance group, - this will retrieve the IP address of all the instances within the instance - group and return a Cluster Resolver object suitable for use for distributed - TensorFlow. - """ - - def __init__(self, - project, - zone, - instance_group, - port, - job_name='worker', - credentials='default', - service=None): - """Creates a new GceClusterResolver object. - - This takes in a few parameters and creates a GceClusterResolver project. It - will then use these parameters to query the GCE API for the IP addresses of - each instance in the instance group. - - Args: - project: Name of the GCE project - zone: Zone of the GCE instance group - instance_group: Name of the GCE instance group - port: Port of the listening TensorFlow server (default: 8470) - job_name: Name of the TensorFlow job this set of instances belongs to - credentials: GCE Credentials. If nothing is specified, this defaults to - GoogleCredentials.get_application_default() - service: The GCE API object returned by the googleapiclient.discovery - function. 
(Default: discovery.build('compute', 'v1')). If you specify a - custom service object, then the credentials parameter will be ignored. - - Raises: - ImportError: If the googleapiclient is not installed. - """ - self._project = project - self._zone = zone - self._instance_group = instance_group - self._job_name = job_name - self._port = port - self._credentials = credentials - - if credentials == 'default': - if _GOOGLE_API_CLIENT_INSTALLED: - self._credentials = GoogleCredentials.get_application_default() - - if service is None: - if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient must be installed before using the ' - 'GCE cluster resolver') - self._service = discovery.build( - 'compute', 'v1', - credentials=self._credentials) - else: - self._service = service - - def cluster_spec(self): - """Returns a ClusterSpec object based on the latest instance group info. - - This returns a ClusterSpec object for use based on information from the - specified instance group. We will retrieve the information from the GCE APIs - every time this method is called. - - Returns: - A ClusterSpec containing host information retrieved from GCE. 
- """ - request_body = {'instanceState': 'RUNNING'} - request = self._service.instanceGroups().listInstances( - project=self._project, - zone=self._zone, - instanceGroups=self._instance_group, - body=request_body, - orderBy='name') - - worker_list = [] - - while request is not None: - response = request.execute() - - items = response['items'] - for instance in items: - instance_name = instance['instance'].split('/')[-1] - - instance_request = self._service.instances().get( - project=self._project, - zone=self._zone, - instance=instance_name) - - if instance_request is not None: - instance_details = instance_request.execute() - ip_address = instance_details['networkInterfaces'][0]['networkIP'] - instance_url = '%s:%s' % (ip_address, self._port) - worker_list.append(instance_url) - - request = self._service.instanceGroups().listInstances_next( - previous_request=request, - previous_response=response) - - worker_list.sort() - return ClusterSpec({self._job_name: worker_list}) - - def master(self): - return '' +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py new file mode 100644 index 0000000000000000000000000000000000000000..a8eaf33629a6299d5da5f8a930e0cad7d07044e8 --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/python/training/kubernetes_cluster_resolver.py @@ -0,0 +1,36 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Stub file for KubernetesClusterResolver for backwards compatibility.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. + +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.kubernetes_cluster_resolver import KubernetesClusterResolver +# pylint: enable=unused-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'KubernetesClusterResolver', +] + +remove_undocumented(__name__, _allowed_symbols) + diff --git a/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd2a846eeb1be7ad4b5a98b067a125afbbebc7d --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/python/training/slurm_cluster_resolver.py @@ -0,0 +1,35 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Stub file for SlurmClusterResolver to maintain backwards compatibility.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. + +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.slurm_cluster_resolver import SlurmClusterResolver +# pylint: enable=unused-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'SlurmClusterResolver', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py new file mode 100644 index 0000000000000000000000000000000000000000..9db7f47dcb49c499719b9002b1d2d6c4837a7bd2 --- /dev/null +++ b/tensorflow/contrib/cluster_resolver/python/training/tfconfig_cluster_resolver.py @@ -0,0 +1,36 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Stub file for TFConfigClusterResolver to maintain backwards compatibility.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. + +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.tfconfig_cluster_resolver import TFConfigClusterResolver +# pylint: enable=unused-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'TFConfigClusterResolver', +] + +remove_undocumented(__name__, _allowed_symbols) + diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index f4a8e16c99f464b813a98e981579bd0ff53bd464..3a1eaccd06e574babbe9a3232dacd1d66f3a4648 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,311 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Implementation of Cluster Resolvers for Cloud TPUs.""" +"""Stub file for TPUClusterResolver to maintain backwards compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os +# This file (and all files in this directory in general) is a backwards +# compatibility shim that exists to re-export ClusterResolvers such that +# existing OSS code will not be broken. -from six.moves.urllib.request import Request -from six.moves.urllib.request import urlopen +# pylint: disable=unused-import +from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver +# pylint: enable=unused-import -from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training import server_lib -from tensorflow.python.util import compat +from tensorflow.python.util.all_util import remove_undocumented -_GOOGLE_API_CLIENT_INSTALLED = True -try: - from googleapiclient import discovery # pylint: disable=g-import-not-at-top - from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top -except ImportError: - _GOOGLE_API_CLIENT_INSTALLED = False +_allowed_symbols = [ + 'TPUClusterResolver', +] - -_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' -_ENDPOINTS_SEPARATOR = ',' -_DEFAULT_ENV_VARIABLE = 'TPU_NAME' -_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' - - -class TPUClusterResolver(ClusterResolver): - """Cluster Resolver for Google Cloud TPUs. - - This is an implementation of cluster resolvers for the Google Cloud TPU - service. As Cloud TPUs are in alpha, you will need to specify a API definition - file for this to consume, in addition to a list of Cloud TPUs in your Google - Cloud Platform project. 
- """ - - def _requestComputeMetadata(self, path): - req = Request('http://metadata/computeMetadata/v1/%s' % path, - headers={'Metadata-Flavor': 'Google'}) - resp = urlopen(req) - return compat.as_bytes(resp.read()) - - def _shouldResolve(self): - if (self._tpu == compat.as_bytes('') or - self._tpu == compat.as_bytes('local') or - self._tpu.startswith(compat.as_bytes('/bns')) or - self._tpu.startswith(compat.as_bytes('localhost:')) or - self._tpu.startswith(compat.as_bytes('grpc://'))): - return False - return True - - @staticmethod - def _inGke(): - """When running in GKE, the environment variable will be set.""" - return _GKE_ENV_VARIABLE in os.environ - - @staticmethod - def _gkeEndpoints(): - return os.environ[_GKE_ENV_VARIABLE] - - @staticmethod - def _envVarFallback(): - if _DEFAULT_ENV_VARIABLE in os.environ: - return os.environ[_DEFAULT_ENV_VARIABLE] - return None - - @staticmethod - def _discoveryUrl(): - return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) - - def __init__(self, - tpu=None, - zone=None, - project=None, - job_name='worker', - coordinator_name=None, - coordinator_address=None, - credentials='default', - service=None, - discovery_url=None): - """Creates a new TPUClusterResolver object. - - The ClusterResolver will then use the parameters to query the Cloud TPU APIs - for the IP addresses and ports of each Cloud TPU listed. - - Args: - tpu: Either a string, or a list of strings corresponding to the TPUs to - use. If the single string is the empty string, the string 'local', or a - string that begins with 'grpc://' or '/bns', then it is assumed to not - correspond with a Cloud TPU and will instead be passed as the session - master and no ClusterSpec propagation will be done. - zone: Zone where the TPUs are located. If omitted or empty, we will assume - that the zone of the TPU is the same as the zone of the GCE VM, which we - will try to discover from the GCE metadata service. - project: Name of the GCP project containing Cloud TPUs. 
If omitted or - empty, we will try to discover the project name of the GCE VM from the - GCE metadata service. - job_name: Name of the TensorFlow job the TPUs belong to. - coordinator_name: The name to use for the coordinator. Set to None if the - coordinator should not be included in the computed ClusterSpec. - coordinator_address: The address of the coordinator (typically an ip:port - pair). If set to None, a TF server will be started. If coordinator_name - is None, a TF server will not be started even if coordinator_address is - None. - credentials: GCE Credentials. If None, then we use default credentials - from the oauth2client - service: The GCE API object returned by the googleapiclient.discovery - function. If you specify a custom service object, then the credentials - parameter will be ignored. - discovery_url: A URL template that points to the location of - the discovery service. It should have two parameters {api} and - {apiVersion} that when filled in produce an absolute URL to the - discovery document for that service. The environment variable - 'TPU_API_DISCOVERY_URL' will override this. - - Raises: - ImportError: If the googleapiclient is not installed. - ValueError: If no TPUs are specified. - """ - if isinstance(tpu, list): - if not tpu: - raise ValueError('At least one TPU must be specified.') - if len(tpu) != 1: - raise NotImplementedError( - 'Using multiple TPUs in a single session is not yet implemented') - tpu = tpu[0] - - in_gke = self._inGke() - # When using GKE with Cloud TPUs, the env variable will be set. 
- if tpu is None: - if in_gke: - tpu = self._gkeEndpoints() - else: - tpu = self._envVarFallback() - - if tpu is None: - raise ValueError('Please provide a TPU Name to connect to.') - - self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes - self._job_name = job_name - self._credentials = credentials - - should_resolve = self._shouldResolve() - - if not project and should_resolve: - project = compat.as_str( - self._requestComputeMetadata('project/project-id')) - - if not zone and should_resolve: - zone_path = compat.as_str(self._requestComputeMetadata('instance/zone')) - zone = zone_path.split('/')[-1] - - self._project = project - self._zone = zone - - if credentials == 'default' and should_resolve: - if _GOOGLE_API_CLIENT_INSTALLED: - self._credentials = GoogleCredentials.get_application_default() - - if service is None and should_resolve: - if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient and oauth2client must be installed ' - 'before using the TPU cluster resolver. Execute: ' - '`pip install --upgrade google-api-python-client` ' - 'and `pip install --upgrade oauth2client` to ' - 'install with pip.') - - final_discovery_url = self._discoveryUrl() or discovery_url - if final_discovery_url: - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials, - discoveryServiceUrl=final_discovery_url) - else: - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials) - else: - self._service = service - - self._coordinator_name = coordinator_name - if coordinator_name and not coordinator_address and (should_resolve or - in_gke): - self._start_local_server() - else: - self._coordinator_address = coordinator_address - - def master(self): - """Get the Master string to be used for the session. - - In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of - first instance in the ClusterSpec returned by the cluster_spec function. 
- - If a non-TPU name is used when constructing a TPUClusterResolver, that will - be returned instead (e.g. If the tpus argument's value when constructing - this TPUClusterResolver was 'grpc://10.240.1.2:8470', - 'grpc://10.240.1.2:8470' will be returned). - - Returns: - string, the connection string to use when creating a session. - - Raises: - ValueError: If none of the TPUs specified exists. - """ - if not self._shouldResolve(): - return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0] - - job_tasks = self.cluster_spec().job_tasks(self._job_name) - if not job_tasks: - raise ValueError('No TPUs exists with the specified names exist.') - - return 'grpc://' + job_tasks[0] - - def get_master(self): - return self.master() - - def get_job_name(self): - if self._shouldResolve(): - return self._job_name - - def cluster_spec(self): - """Returns a ClusterSpec object based on the latest TPU information. - - We retrieve the information from the GCE APIs every time this method is - called. - - Returns: - A ClusterSpec containing host information returned from Cloud TPUs. - - Raises: - RuntimeError: If the provided TPU is not healthy. - """ - ############################################################################ - # There are 5 potential cases this code must handle: - # 1. [Normal case.] We should resolve the TPU name to a set of tasks, and - # a. Create a ClusterSpec that includes the coordinator job - # b. Create a ClusterSpec without the coordinator job. - # 2. [GKE / No API Access.] We should not resolve the TPU name to a set of - # tasks and - # a. Create a ClusterSpec with the coordinator - # b. Create a ClusterSpec without the coordinator - # 3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec. - ############################################################################ - - if self._shouldResolve(): - # Case 1. 
- full_name = 'projects/%s/locations/%s/nodes/%s' % ( - self._project, self._zone, compat.as_text(self._tpu)) - request = self._service.projects().locations().nodes().get(name=full_name) - response = request.execute() - - if 'state' in response and response['state'] != 'READY': - raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % - (compat.as_text(self._tpu), response['state'])) - - if 'health' in response and response['health'] != 'HEALTHY': - raise RuntimeError('TPU "%s" is unhealthy: "%s"' % - (compat.as_text(self._tpu), response['health'])) - - if 'networkEndpoints' in response: - worker_list = [ - '%s:%s' % (endpoint['ipAddress'], endpoint['port']) - for endpoint in response['networkEndpoints'] - ] - else: - # Fall back to the deprecated response format - instance_url = '%s:%s' % (response['ipAddress'], response['port']) - worker_list = [instance_url] - - cluster_spec = {self._job_name: worker_list} - else: - if not self._tpu.startswith(compat.as_bytes('grpc://')): - # Case 3. - return None - # Case 2. 
- cluster_spec = { - self._job_name: [ - x[len(compat.as_bytes('grpc://')):] - for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR)) - ] - } - - if self._coordinator_address: - # {1, 2}.a - cluster_spec[self._coordinator_name] = [self._coordinator_address] - - return server_lib.ClusterSpec(cluster_spec) - - def _start_local_server(self): - address = self._requestComputeMetadata('instance/network-interfaces/0/ip') - self._server = server_lib.Server( - { - 'local': ['0.0.0.0:0'] - }, protocol='grpc', config=None, start=True) - # self._server.target is of the form: grpc://ipaddress:port - target = compat.as_bytes(self._server.target) - splits = target.split(compat.as_bytes(':')) - assert len(splits) == 3, self._server.target - assert splits[0] == compat.as_bytes('grpc'), self._server.target - self._coordinator_port = compat.as_text(splits[2]) - self._coordinator_address = '%s:%s' % ( - address, compat.as_text(self._coordinator_port)) - - def __deepcopy__(self, memo): - # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy. - return self +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py deleted file mode 100644 index ad4f6432630be44a7de6e778f55f1fb7fd66f307..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ /dev/null @@ -1,468 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for TPUClusterResolver.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver -from tensorflow.python.platform import test -from tensorflow.python.training import server_lib -from tensorflow.python.util import compat - -mock = test.mock - - -class MockRequestClass(object): - - def __init__(self, name, tpu_map): - self._name = name - self._tpu_map = tpu_map - - def execute(self): - if self._name in self._tpu_map: - return self._tpu_map[self._name] - else: - raise KeyError('Resource %s was not found' % self._name) - - -class MockNodeClass(object): - - def __init__(self, tpu_map): - self._tpu_map = tpu_map - - def get(self, name): - return MockRequestClass(name, self._tpu_map) - - -def mock_request_compute_metadata(cls, *args, **kwargs): - del cls, kwargs # Unused. - if args[0] == 'project/project-id': - return 'test-project' - elif args[0] == 'instance/zone': - return 'projects/test-project/locations/us-central1-c' - elif args[0] == 'instance/network-interfaces/0/ip': - return '10.128.1.2' - return '' - - -class TPUClusterResolverTest(test.TestCase): - - def _verifyClusterSpecEquality(self, cluster_spec, expected_proto): - """Verifies that the ClusterSpec generates the correct proto. 
- - We are testing this four different ways to ensure that the ClusterSpec - returned by the TPUClusterResolver behaves identically to a normal - ClusterSpec when passed into the generic ClusterSpec libraries. - - Args: - cluster_spec: ClusterSpec returned by the TPUClusterResolver - expected_proto: Expected protobuf - """ - self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def()) - self.assertProtoEquals( - expected_proto, - server_lib.ClusterSpec(cluster_spec).as_cluster_def()) - self.assertProtoEquals(expected_proto, - server_lib.ClusterSpec( - cluster_spec.as_cluster_def()).as_cluster_def()) - self.assertProtoEquals(expected_proto, - server_lib.ClusterSpec( - cluster_spec.as_dict()).as_cluster_def()) - - def mock_service_client(self, tpu_map=None): - - if tpu_map is None: - tpu_map = {} - - mock_locations = mock.MagicMock() - mock_locations.nodes.return_value = MockNodeClass(tpu_map) - - mock_project = mock.MagicMock() - mock_project.locations.return_value = mock_locations - - mock_client = mock.MagicMock() - mock_client.projects.return_value = mock_project - - return mock_client - - @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', - mock_request_compute_metadata) - def testRetrieveProjectAndZoneFromMetadata(self): - tpu_map = { - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'ipAddress': '10.1.2.3', - 'port': '8470', - 'health': 'HEALTHY' - } - } - - tpu_cluster_resolver = TPUClusterResolver( - project=None, - zone=None, - tpu=['test-tpu-1'], - credentials=None, - service=self.mock_service_client(tpu_map=tpu_map), - coordinator_name='coordinator') - - actual_cluster_spec = tpu_cluster_resolver.cluster_spec() - expected_proto = """ - job { - name: 'coordinator' - tasks { key: 0 value: '10.128.1.2:%s' } - } - job { - name: 'worker' - tasks { key: 0 value: '10.1.2.3:8470' } - } - """ % tpu_cluster_resolver._coordinator_port - self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) - - 
@mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', - mock_request_compute_metadata) - def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self): - tpu_map = { - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'ipAddress': '10.1.2.3', - 'port': '8470', - 'health': 'HEALTHY' - } - } - - tpu_cluster_resolver = TPUClusterResolver( - project=None, - zone=None, - tpu=['test-tpu-1'], - coordinator_name=None, - credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) - - actual_cluster_spec = tpu_cluster_resolver.cluster_spec() - expected_proto = """ - job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } } - """ - self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - - @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', - mock_request_compute_metadata) - def testUnhealthyCloudTpu(self): - tpu_map = { - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'ipAddress': '10.1.2.3', - 'port': '8470', - 'health': 'UNHEALTHY' - } - } - - tpu_cluster_resolver = TPUClusterResolver( - project=None, - zone=None, - tpu='test-tpu-1', - coordinator_name=None, - credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) - - with self.assertRaises(RuntimeError): - tpu_cluster_resolver.cluster_spec() - - @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', - mock_request_compute_metadata) - def testNotReadyCloudTpu(self): - tpu_map = { - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'ipAddress': '10.1.2.3', - 'port': '8470', - 'state': 'CREATING' - } - } - - tpu_cluster_resolver = TPUClusterResolver( - project=None, - zone=None, - tpu='test-tpu-1', - coordinator_name=None, - credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) - - with self.assertRaises(RuntimeError): - tpu_cluster_resolver.cluster_spec() - - def testSimpleSuccessfulRetrieval(self): - tpu_map = { - 
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'ipAddress': '10.1.2.3', - 'port': '8470', - 'health': 'HEALTHY' - } - } - - tpu_cluster_resolver = TPUClusterResolver( - project='test-project', - zone='us-central1-c', - tpu=['test-tpu-1'], - coordinator_name='coordinator', - coordinator_address='10.128.1.5:10203', - credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) - - actual_cluster_spec = tpu_cluster_resolver.cluster_spec() - expected_proto = """ - job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } } - job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } } - """ - self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - - def testNewNetworkEndpointFormat(self): - tpu_map = { - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'health': 'HEALTHY', - 'networkEndpoints': [{ - 'ipAddress': '10.2.3.4', - 'port': 8470, - }] - } - } - - tpu_cluster_resolver = TPUClusterResolver( - project='test-project', - zone='us-central1-c', - tpu='test-tpu-1', - coordinator_name='coordinator', - coordinator_address='10.128.1.5:10203', - credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) - - actual_cluster_spec = tpu_cluster_resolver.cluster_spec() - expected_proto = """ - job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } } - job { name: 'worker' tasks { key: 0 value: '10.2.3.4:8470' } } - """ - self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - self.assertEqual('grpc://10.2.3.4:8470', tpu_cluster_resolver.master()) - - @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', - mock_request_compute_metadata) - def testPodResolution(self): - tpu_map = { - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'health': - 'HEALTHY', - 'networkEndpoints': [ - { - 'ipAddress': '10.2.3.4', - 'port': 8470, - }, - { - 'ipAddress': '10.2.3.5', - 'port': 8470, - }, - { - 'ipAddress': '10.2.3.6', - 'port': 
8470, - }, - { - 'ipAddress': '10.2.3.7', - 'port': 8470, - }, - ] - } - } - - tpu_cluster_resolver = TPUClusterResolver( - tpu='test-tpu-1', - credentials=None, - service=self.mock_service_client(tpu_map=tpu_map), - coordinator_name='coordinator') - - actual_cluster_spec = tpu_cluster_resolver.cluster_spec() - expected_proto = """ - job { - name: 'coordinator', - tasks { key: 0 value: '10.128.1.2:%s'} - } - job { - name: 'worker' - tasks { key: 0 value: '10.2.3.4:8470' } - tasks { key: 1 value: '10.2.3.5:8470' } - tasks { key: 2 value: '10.2.3.6:8470' } - tasks { key: 3 value: '10.2.3.7:8470' } - } - """ % tpu_cluster_resolver._coordinator_port - self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) - - def testPodResolutionNoCoordinator(self): - tpu_map = { - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'health': - 'HEALTHY', - 'networkEndpoints': [ - { - 'ipAddress': '10.2.3.4', - 'port': 8470, - }, - { - 'ipAddress': '10.2.3.5', - 'port': 8470, - }, - { - 'ipAddress': '10.2.3.6', - 'port': 8470, - }, - { - 'ipAddress': '10.2.3.7', - 'port': 8470, - }, - ] - } - } - - tpu_cluster_resolver = TPUClusterResolver( - project='test-project', - zone='us-central1-c', - tpu='test-tpu-1', - coordinator_name=None, - credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) - - actual_cluster_spec = tpu_cluster_resolver.cluster_spec() - expected_proto = """ - job { - name: 'worker' - tasks { key: 0 value: '10.2.3.4:8470' } - tasks { key: 1 value: '10.2.3.5:8470' } - tasks { key: 2 value: '10.2.3.6:8470' } - tasks { key: 3 value: '10.2.3.7:8470' } - } - """ - self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - - def testGetMasterNoEntries(self): - tpu_map = {} - - with self.assertRaises(ValueError): - TPUClusterResolver( - project='test-project', - zone='us-central1-c', - tpu=[], - coordinator_name=None, - credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) - - # TODO(saeta): 
Convert to parameterized test when included in OSS TF. - def verifyShouldResolve(self, tpu, should_resolve): - tpu_cluster_resolver = TPUClusterResolver( - project='test-project', - zone='us-central1-c', - tpu=tpu, - coordinator_name=None, - credentials=None, - service=self.mock_service_client(tpu_map={})) - self.assertEqual(should_resolve, tpu_cluster_resolver._shouldResolve(), - "TPU: '%s'" % tpu) - - def testShouldResolveNoName(self): - self.verifyShouldResolve('', False) - - def testShouldResolveLocal(self): - self.verifyShouldResolve('local', False) - - def testShouldResolveGrpc(self): - self.verifyShouldResolve('grpc://10.1.2.3:8470', False) - - def testShouldResolveBns(self): - self.verifyShouldResolve('/bns/foo/bar', False) - - def testShouldResolveName(self): - self.verifyShouldResolve('mytpu', True) - - def testShouldResolveList(self): - self.verifyShouldResolve(['myothertpu'], True) - - def testShouldResolveGrpcPrefix(self): - self.verifyShouldResolve('grpctpu', True) - - def testNoCallComputeMetadata(self): - tpu_cluster_resolver = TPUClusterResolver(tpu='/bns/foo/bar') - self.assertEqual( - compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master()) - self.assertEqual(None, tpu_cluster_resolver.cluster_spec()) - - def testGkeEnvironmentForDonut(self): - os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470' - - self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ) - self.assertTrue(TPUClusterResolver._inGke()) - self.assertEqual( - compat.as_bytes('grpc://10.120.27.5:8470'), - compat.as_bytes(TPUClusterResolver._gkeEndpoints())) - - tpu_cluster_resolver = TPUClusterResolver() - self.assertEqual( - compat.as_bytes('grpc://10.120.27.5:8470'), - compat.as_bytes(tpu_cluster_resolver.master())) - actual_cluster_spec = tpu_cluster_resolver.cluster_spec() - expected_proto = """ - job { - name: 'worker' - tasks { key: 0 value: '10.120.27.5:8470' } - } - """ - self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - - 
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] - - def testGkeEnvironmentForPod(self): - os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = ('grpc://10.120.27.5:8470,' - 'grpc://10.120.27.6:8470,' - 'grpc://10.120.27.7:8470,' - 'grpc://10.120.27.8:8470') - - self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ) - self.assertTrue(TPUClusterResolver._inGke()) - self.assertEqual( - compat.as_bytes('grpc://10.120.27.5:8470,' - 'grpc://10.120.27.6:8470,' - 'grpc://10.120.27.7:8470,' - 'grpc://10.120.27.8:8470'), - compat.as_bytes(TPUClusterResolver._gkeEndpoints())) - - tpu_cluster_resolver = TPUClusterResolver() - self.assertEqual( - compat.as_bytes('grpc://10.120.27.5:8470'), - compat.as_bytes(tpu_cluster_resolver.master())) - actual_cluster_spec = tpu_cluster_resolver.cluster_spec() - expected_proto = """ - job { - name: 'worker' - tasks { key: 0 value: '10.120.27.5:8470' } - tasks { key: 1 value: '10.120.27.6:8470' } - tasks { key: 2 value: '10.120.27.7:8470' } - tasks { key: 3 value: '10.120.27.8:8470' } - } - """ - self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - - del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] - - def testDiscoveryUrl(self): - os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}' - self.assertEqual('https://{api}.internal/{apiVersion}', - TPUClusterResolver._discoveryUrl()) - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index f675c135f4fc362ea620ea5b04d6b7fd536fceaf..2ad9ae42a16f690d38b8e2652e853012ec1dd267 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -1,8 +1,18 @@ # Minimum CMake required cmake_minimum_required(VERSION 3.5) +if(WIN32) + if(${CMAKE_VERSION} VERSION_LESS "3.8") + message(WARNING "Your current cmake version is ${CMAKE_VERSION} which does not support setting the toolset architecture to x64. 
This may cause \"compiler out of heap space\" errors when building. Consider upgrading your cmake to > 3.8 and using the flag -Thost=x64 when running cmake. Ignore this if you are on CMake GUI.") + else() + if(NOT CMAKE_VS_PLATFORM_TOOLSET_HOST_ARCHITECTURE OR NOT "${CMAKE_VS_PLATFORM_TOOLSET_HOST_ARCHITECTURE}" STREQUAL "x64") + message(WARNING "Your current cmake generator is set to use 32 bit toolset architecture. This may cause \"compiler out of heap space\" errors when building. Consider using the flag -Thost=x64 when running cmake. Ignore this if you are on CMake GUI.") + endif() + endif() +endif() + # Project -project(tensorflow C CXX) +project(tensorflow VERSION 1.12.0 LANGUAGES C CXX) # Set C++14 as standard for the whole project set(CMAKE_CXX_STANDARD 14) @@ -42,15 +52,19 @@ option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for th option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON) option(tensorflow_DISABLE_EIGEN_FORCEINLINE "Disable forceinline, to speed up build on windows." 
OFF) +if (WIN32) +SET(tensorflow_WIN_CPU_SIMD_OPTIONS "/arch:AVX" CACHE STRING "Enables CPU SIMD instructions") +SET_PROPERTY(CACHE tensorflow_WIN_CPU_SIMD_OPTIONS PROPERTY STRINGS /arch:AVX) +endif() + # SIMD, MKL and MKLDNN options option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions" OFF) option(tensorflow_ENABLE_MKL_SUPPORT "Enable Intel MKL support" OFF) option(tensorflow_ENABLE_MKLDNN_SUPPORT "Enable Intel MKLDNN support, requires MKL enabled" OFF) + # GPU, CUDA and cuDNN options option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) -set(tensorflow_CUDA_VERSION "9.0" CACHE STRING "CUDA version to build against") -set(tensorflow_CUDNN_VERSION "7" CACHE STRING "cuDNN version to build against") if(HAIKU) option(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE "Enable PIE support" OFF) @@ -62,25 +76,30 @@ endif() if (NOT WIN32) # Threads: defines CMAKE_THREAD_LIBS_INIT and adds -pthread compile option # for targets that link ${CMAKE_THREAD_LIBS_INIT}. - find_package (Threads) + find_package (Threads REQUIRED) # Options for linking CUDA/CUDNN libraries - option(tensorflow_PATH_STATIC_LIB "Additional library search path for libcudnn_static.a, libnccl_static.a, libculibos.a" /usr/local/cuda/lib64/) + option(tensorflow_PATH_CUDA_LIB "Additional library search path for cudnn, nccl, culibos" /usr/local/cuda/lib64/) option(tensorflow_CUDNN_INCLUDE "cudnn.h header install path" /usr/include/) if (NOT tensorflow_CUDNN_INCLUDE) # option's default value is OFF. Fill it with real default values set(tensorflow_CUDNN_INCLUDE /usr/include) endif (NOT tensorflow_CUDNN_INCLUDE) - option(tensorflow_PATH_CUDNN_STATIC_LIB "Override PATH_STATIC_LIB for libcudnn_static.a" ${tensorflow_PATH_STATIC_LIB}) - if (NOT tensorflow_PATH_CUDNN_STATIC_LIB) + option(tensorflow_NCCL_INCLUDE "nccl.h header install path" /usr/include/) + if (NOT tensorflow_NCCL_INCLUDE) + # option's default value is OFF. 
Fill it with real default values + set(tensorflow_NCCL_INCLUDE /usr/include) + endif (NOT tensorflow_NCCL_INCLUDE) + option(tensorflow_PATH_CUDNN_LIB "Override PATH_CUDA_LIB for cudnn" ${tensorflow_PATH_CUDA_LIB}) + if (NOT tensorflow_PATH_CUDNN_LIB) # option's default value is OFF. Fill it with real default values - set (tensorflow_PATH_CUDNN_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) - endif (NOT tensorflow_PATH_CUDNN_STATIC_LIB) - option(tensorflow_PATH_NCCL_STATIC_LIB "Override PATH_STATIC_LIB for libnccl_static.a" ${tensorflow_PATH_STATIC_LIB}) - if (NOT tensorflow_PATH_NCCL_STATIC_LIB) + set (tensorflow_PATH_CUDNN_LIB ${tensorflow_PATH_CUDA_LIB}) + endif (NOT tensorflow_PATH_CUDNN_LIB) + option(tensorflow_PATH_NCCL_LIB "Override PATH_CUDA_LIB for nccl" ${tensorflow_PATH_CUDA_LIB}) + if (NOT tensorflow_PATH_NCCL_LIB) # option's default value is OFF. Fill it with real default values - set (tensorflow_PATH_NCCL_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) - endif (NOT tensorflow_PATH_NCCL_STATIC_LIB) + set (tensorflow_PATH_NCCL_LIB ${tensorflow_PATH_CUDA_LIB}) + endif (NOT tensorflow_PATH_NCCL_LIB) option(tensorflow_CUDA_LIBRARY_PATH "Designate the default CUDA library paths" /usr/local/cuda/lib64) if (NOT tensorflow_CUDA_LIBRARY_PATH) # option's default value is OFF. 
Fill it with real default values @@ -89,10 +108,12 @@ if (NOT WIN32) # Options for linking other libraries option(systemlib_ZLIB "Use the system installed library as shared objects instead of downloading ZLIB and statically linking to it: ZLIB" OFF) + option(systemlib_ABSEIL_CPP "Use the system installed library as shared objects instead of downloading ABSEIL_CPP and statically linking to it: ABSEIL_CPP" OFF) option(systemlib_ALL "Turn on every possible systemlib_* options" OFF) if (systemlib_ALL) set (systemlib_ZLIB ON) + set (systemlib_ABSEIL_CPP ON) endif (systemlib_ALL) endif() @@ -114,7 +135,7 @@ function(SHOW_VARIABLES) endfunction() # External dependencies -set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/external) +set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/external ${PROJECT_SOURCE_DIR}/modules) # Location where external projects will be downloaded set (DOWNLOAD_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/downloads" @@ -183,6 +204,7 @@ if(WIN32) set(CMAKE_SUPPRESS_REGENERATION ON) endif() + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions -std=c++11") endif() @@ -198,14 +220,17 @@ endif() include(CheckCXXCompilerFlag) # OpenMP Support -CHECK_CXX_COMPILER_FLAG("-fopenmp" GCC_OPENMP_SUPPORT) -if (GCC_OPENMP_SUPPORT) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") -endif() -CHECK_CXX_COMPILER_FLAG("/openmp" MSVC_OPENMP_SUPPORT) -if (MSVC_OPENMP_SUPPORT) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp") -endif() +if (WIN32) + CHECK_CXX_COMPILER_FLAG("/openmp" MSVC_OPENMP_SUPPORT) + if (MSVC_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp") + endif() +else (WIN32) + CHECK_CXX_COMPILER_FLAG("-fopenmp" GCC_OPENMP_SUPPORT) + if (GCC_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") + endif() +endif (WIN32) # MSVC SIMD instructions if (tensorflow_WIN_CPU_SIMD_OPTIONS) @@ -235,6 +260,7 @@ include(re2) include(cub) include(sqlite) include(double_conversion) +include(abseil_cpp) if 
(tensorflow_BUILD_CC_TESTS) include(googletest) endif() @@ -243,6 +269,7 @@ add_definitions(${ADD_CFLAGS}) link_directories(${ADD_LINK_DIRECTORY}) set(tensorflow_EXTERNAL_LIBRARIES + ${tensorflow_EXTERNAL_LIBRARIES} ${gif_STATIC_LIBRARIES} ${png_STATIC_LIBRARIES} ${jpeg_STATIC_LIBRARIES} @@ -266,6 +293,14 @@ else (systemlib_ZLIB) ${zlib_STATIC_LIBRARIES}) endif (systemlib_ZLIB) +if (systemlib_ABSEIL_CPP) + set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES} + ${abseil_cpp_LIBRARIES}) +else (systemlib_ABSEIL_CPP) + set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES} + ${abseil_cpp_STATIC_LIBRARIES}) +endif (systemlib_ABSEIL_CPP) + set(tensorflow_EXTERNAL_DEPENDENCIES zlib_copy_headers_to_destination gif_copy_headers_to_destination @@ -352,9 +387,7 @@ if (tensorflow_ENABLE_MKL_SUPPORT) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES}) list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination) include_directories(${mkldnn_INCLUDE_DIRS}) - else (tensorflow_ENABLE_MKLDNN_SUPPORT) - add_definitions(-DINTEL_MKL_ML_ONLY) - endif() + endif(tensorflow_ENABLE_MKLDNN_SUPPORT) endif (tensorflow_ENABLE_MKL_SUPPORT) if (tensorflow_ENABLE_GPU) @@ -365,32 +398,23 @@ if (tensorflow_ENABLE_GPU) list(APPEND CMAKE_LIBRARY_PATH "${tensorflow_CUDA_LIBRARY_PATH}/stubs") endif (NOT WIN32) - # later command will make use of the value in tensorflow_CUDA_VERSION - find_package(CUDA ${tensorflow_CUDA_VERSION} REQUIRED EXACT) - - # Test compatibility of compiler on CUDA - try_compile(CUDA_TEST_COMPILE_C - ${CMAKE_CURRENT_BINARY_DIR}/tests/cuda - ${CMAKE_CURRENT_SOURCE_DIR}/tests/cuda/compatibility_test.c - CMAKE_FLAGS -DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}) - try_compile(CUDA_TEST_COMPILE_CXX - ${CMAKE_CURRENT_BINARY_DIR}/tests/cuda - ${CMAKE_CURRENT_SOURCE_DIR}/tests/cuda/compatibility_test.cc - CMAKE_FLAGS -DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}) - if(NOT (CUDA_TEST_COMPILE_C AND CUDA_TEST_COMPILE_CXX)) - 
message(FATAL_ERROR "Selected compiler (or version) is not supported for CUDA") + # minimum 9.0 in cuda version + find_package(CUDA 9.0 REQUIRED) + if(NOT CUDA_FOUND) + message(FATAL_ERROR "CUDA not found.") endif() - # by default we assume compute cabability 3.5 and 5.2. If you change this change it in - # CUDA_NVCC_FLAGS and cuda_config.h below - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_37,code=\"sm_37,compute_37\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_52,code=\"sm_52,compute_52\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_60,code=\"sm_60,compute_60\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_61,code=\"sm_61,compute_61\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_70,code=\"sm_70,compute_70\") + # use cmake internal CUDA_ARCH_NAME switch + # e.g. CUDA_ARCH_NAME="Auto" will autodetect + # CUDA_ARCH_NAME="All" will use all arches + cuda_select_nvcc_arch_flags(NVCC_ARCH_FLAGS ${CUDA_ARCH_NAME}) + list(APPEND CUDA_NVCC_FLAGS ${NVCC_ARCH_FLAGS}) + message(STATUS "Using CUDA arch flags: ${NVCC_ARCH_FLAGS_readable}") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};--include-path ${PROJECT_BINARY_DIR}/$\{build_configuration\};--expt-relaxed-constexpr) set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-ftz=true) # Flush denormals to zero set(CUDA_INCLUDE ${CUDA_TOOLKIT_TARGET_DIR} ${CUDA_TOOLKIT_TARGET_DIR}/extras/CUPTI/include) + include_directories(${CUDA_INCLUDE}) if (WIN32) add_definitions(-DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=3.7,5.2,6.0,6.1,7.0) @@ -411,43 +435,94 @@ if (tensorflow_ENABLE_GPU) else (WIN32) set(CUDNN_INCLUDE "${tensorflow_CUDNN_INCLUDE}") - find_library(nccl_STATIC_LIBRARY NAMES libnccl_static.a PATHS ${tensorflow_PATH_NCCL_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) - if (NOT nccl_STATIC_LIBRARY) + if (tensorflow_BUILD_SHARED_LIB) + find_library(nccl_LIBRARY NAMES libnccl.so PATHS ${tensorflow_PATH_NCCL_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + else 
(tensorflow_BUILD_SHARED_LIB) + find_library(nccl_LIBRARY NAMES libnccl_static.a PATHS ${tensorflow_PATH_NCCL_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + endif (tensorflow_BUILD_SHARED_LIB) + if (NOT nccl_LIBRARY) message(FATAL_ERROR "NCCL is required for GPU-build") - else (NOT nccl_STATIC_LIBRARY) - message("nccl-static: ${nccl_STATIC_LIBRARY}") + else (NOT nccl_LIBRARY) + message("nccl: ${nccl_LIBRARY}") # something like /usr/lib64/libnccl_static.a - endif (NOT nccl_STATIC_LIBRARY) - - find_library(cudnn_STATIC_LIBRARY NAMES libcudnn_static.a PATHS ${tensorflow_PATH_CUDNN_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) - if (NOT cudnn_STATIC_LIBRARY) + endif (NOT nccl_LIBRARY) + + if (tensorflow_BUILD_SHARED_LIB) + find_library(cudnn_LIBRARY NAMES libcudnn.so PATHS ${tensorflow_PATH_CUDNN_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + else (tensorflow_BUILD_SHARED_LIB) + find_library(cudnn_LIBRARY NAMES libcudnn_static.a PATHS ${tensorflow_PATH_CUDNN_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + endif (tensorflow_BUILD_SHARED_LIB) + if (NOT cudnn_LIBRARY) message(FATAL_ERROR "CUDNN is required for GPU-build") - else (NOT cudnn_STATIC_LIBRARY) - message("cudnn-static: ${cudnn_STATIC_LIBRARY}") - endif (NOT cudnn_STATIC_LIBRARY) - - find_library(culibos_STATIC_LIBRARY NAMES libculibos.a PATHS ${tensorflow_PATH_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) - if (NOT culibos_STATIC_LIBRARY) + else (NOT cudnn_LIBRARY) + file(READ ${CUDNN_INCLUDE}/cudnn.h CUDNN_VERSION_FILE_CONTENTS) + # fetch cudnn version + string(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)" + CUDNN_VERSION_MAJOR "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1" + CUDNN_VERSION_MAJOR "${CUDNN_VERSION_MAJOR}") + string(REGEX MATCH "define CUDNN_MINOR * +([0-9]+)" + CUDNN_VERSION_MINOR "${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_MINOR * +([0-9]+)" "\\1" + CUDNN_VERSION_MINOR "${CUDNN_VERSION_MINOR}") + string(REGEX MATCH "define CUDNN_PATCHLEVEL * +([0-9]+)" + CUDNN_VERSION_PATCH 
"${CUDNN_VERSION_FILE_CONTENTS}") + string(REGEX REPLACE "define CUDNN_PATCHLEVEL * +([0-9]+)" "\\1" + CUDNN_VERSION_PATCH "${CUDNN_VERSION_PATCH}") + if(NOT CUDNN_VERSION_MAJOR) + set(CUDNN_VERSION "???") + else() + set(CUDNN_VERSION "${CUDNN_VERSION_MAJOR}.${CUDNN_VERSION_MINOR}.${CUDNN_VERSION_PATCH}") + endif() + message(STATUS "cudnn library: ${cudnn_LIBRARY} (found version: \"${CUDNN_VERSION}\")") + endif (NOT cudnn_LIBRARY) + + if (tensorflow_BUILD_SHARED_LIB) + # shared first (if exists) else static one + find_library(culibos_LIBRARY NAMES libculibos.so libculibos.a PATHS ${tensorflow_PATH_CUDA_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + else (tensorflow_BUILD_SHARED_LIB) + # only static version + find_library(culibos_LIBRARY NAMES libculibos.a PATHS ${tensorflow_PATH_CUDA_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + endif (tensorflow_BUILD_SHARED_LIB) + if (NOT culibos_LIBRARY) message(FATAL_ERROR "CULIBOS is required for GPU-build") - else (NOT culibos_STATIC_LIBRARY) - message("culibos-static: ${culibos_STATIC_LIBRARY}") - endif (NOT culibos_STATIC_LIBRARY) + else (NOT culibos_LIBRARY) + message("culibos: ${culibos_LIBRARY}") + endif (NOT culibos_LIBRARY) set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES} - ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_STATIC_LIBRARY} ${culibos_STATIC_LIBRARY} ${nccl_STATIC_LIBRARY}) + ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_LIBRARY} ${culibos_LIBRARY} ${nccl_LIBRARY}) endif (WIN32) include_directories(${CUDNN_INCLUDE}) # Remove "." from CUDA version variable. - string(REPLACE "." "" short_CUDA_VER ${tensorflow_CUDA_VERSION}) + string(REPLACE "." 
"" short_CUDA_VER ${CUDA_VERSION}) + + # List of enumerated CUDA caps + string(REPLACE " " ";" NVCC_ARCH_LIST "${NVCC_ARCH_FLAGS_readable}") + set(list ${NVCC_ARCH_LIST}) + + # Construct capability string + foreach(NVCC_ARCH ${NVCC_ARCH_LIST}) + if (NVCC_ARCH MATCHES "sm_") + string(REGEX REPLACE "^.sm*" "" NVCC_ARCH ${NVCC_ARCH}) + math(EXPR NVCC_ARCH_MAJOR "${NVCC_ARCH} / 10") + math(EXPR NVCC_ARCH_MINOR "(${NVCC_ARCH} - (${NVCC_ARCH_MAJOR}*10))") + if (TF_CUDA_CAP) + set(TF_CUDA_CAP "${TF_CUDA_CAP},CudaVersion(\"${NVCC_ARCH_MAJOR}.${NVCC_ARCH_MINOR}\")") + else (TF_CUDA_CAP) + set(TF_CUDA_CAP "CudaVersion(\"${NVCC_ARCH_MAJOR}.${NVCC_ARCH_MINOR}\")") + endif (TF_CUDA_CAP) + endif() + endforeach() # create cuda_config.h FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h "#ifndef CUDA_CUDA_CONFIG_H_\n" "#define CUDA_CUDA_CONFIG_H_\n" - "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.7\"),CudaVersion(\"5.2\"),CudaVersion(\"6.0\"),CudaVersion(\"6.1\"),CudaVersion(\"7.0\")\n" + "#define TF_CUDA_CAPABILITIES ${TF_CUDA_CAP}\n" "#define TF_CUDA_VERSION \"64_${short_CUDA_VER}\"\n" - "#define TF_CUDNN_VERSION \"64_${tensorflow_CUDNN_VERSION}\"\n" + "#define TF_CUDNN_VERSION \"64_${CUDNN_VERSION}\"\n" "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" "#endif // CUDA_CUDA_CONFIG_H_\n" ) @@ -482,24 +557,30 @@ if (tensorflow_ENABLE_GPU) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value msvcp_dll_name=msvcp140.dll cudart_dll_name=cudart64_${short_CUDA_VER}.dll - cuda_version_number=${tensorflow_CUDA_VERSION} + cuda_version_number=${CUDA_VERSION} nvcuda_dll_name=nvcuda.dll cudnn_dll_name=cudnn64_${tensorflow_CUDNN_VERSION}.dll cudnn_version_number=${tensorflow_CUDNN_VERSION}) else(WIN32) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value - cuda_version_number=${tensorflow_CUDA_VERSION} - cudnn_version_number=${tensorflow_CUDNN_VERSION}) + cuda_version_number=${CUDA_VERSION} + 
cudnn_version_number=${tensorflow_CUDNN_VERSION}) endif(WIN32) else(tensorflow_ENABLE_GPU) - set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value - msvcp_dll_name=msvcp140.dll) + if(WIN32) + set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value + msvcp_dll_name=msvcp140.dll) + else() + set(tensorflow_BUILD_INFO_FLAGS --build_config cpu) + endif() endif(tensorflow_ENABLE_GPU) -# Find python executable -include(FindPythonInterp) -if(NOT ${PYTHONINTERP_FOUND}) - message(FATAL_ERROR "CMake was unable to find a python interpreter.") +if(tensorflow_BUILD_PYTHON_BINDINGS) + # Find python executable + include(FindPythonInterp) + if(NOT ${PYTHONINTERP_FOUND}) + message(FATAL_ERROR "CMake was unable to find a python interpreter.") + endif() endif() # Let's get to work! @@ -520,6 +601,7 @@ include(tf_cc_ops.cmake) include(tf_c.cmake) include(tf_grappler.cmake) include(tf_core_profiler.cmake) +include(tf_core_eager_runtime.cmake) if(tensorflow_BUILD_CC_EXAMPLE) include(tf_tutorials.cmake) include(tf_label_image_example.cmake) @@ -533,4 +615,4 @@ if(tensorflow_BUILD_SHARED_LIB) endif() if(tensorflow_BUILD_CC_TESTS OR tensorflow_BUILD_PYTHON_TESTS) include(tf_tests.cmake) -endif() +endif() \ No newline at end of file diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 77242b34fd8302cb9104c50a83d4141607911e7f..df8b48dfc46124d3b9454d92ffb70dbcf1bc4217 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -5,10 +5,10 @@ CMAKE build is deprecated for TensorFlow. Please use `bazel` to build TF for all platforms. For details, see the [TensorFlow install guide](https://www.tensorflow.org/install/). -This directory contains CMake files for building TensorFlow on Microsoft -Windows. [CMake](https://cmake.org) is a cross-platform tool that can -generate build scripts for multiple build systems, including Microsoft -Visual Studio. 
+This directory contains CMake files for building TensorFlow on Microsoft Windows +and Linux. [CMake](https://cmake.org) is a cross-platform tool that can generate +build scripts for multiple build systems, including Microsoft Visual Studio and +GCC. "The method has not been tested on Mac OS X. **N.B.** We provide Linux build instructions primarily for the purpose of testing the build. We recommend using the standard Bazel-based build on @@ -17,12 +17,17 @@ Linux. Current Status -------------- -CMake can be used to build TensorFlow on Windows. See the [getting started documentation](https://www.tensorflow.org/install/source_windows) -for instructions on how to install a pre-built TensorFlow package on Windows. +CMake can be used to build TensorFlow on all platforms. See the +[getting started documentation](https://www.tensorflow.org/install/install_windows) +for instructions on how to install a pre-built TensorFlow package on Windows and +Linux. The procedure in MacOS is similar to the Linux build. ### Current known limitations -* It is not possible to load a custom Op library. -* GCS file system is not supported. + +* It is not possible to load a custom Op library. +* GCS file system is not supported. +* Debug build is not available since Python for Windows is no longer + distributed with a debug library. ## Building with CMake @@ -32,70 +37,88 @@ bindings. ### Prerequisites -* CMake version 3.5 or later. +* CMake version 3.5 or later. 
+ +* [Git](https://git-scm.com) + +* [SWIG](http://www.swig.org/download.html) + +* [Perl](https://www.perl.org/get.html) (optional, for SSL support build) + +* [Go](https://golang.org/) (optional, for SSL support build) -* [Git](https://git-scm.com) +* [NASM](http://www.nasm.us/)/[YASM](http://yasm.tortall.net/) (optional, for + SSL support build) -* [SWIG](http://www.swig.org/download.html) +* Additional pre-requisites for Microsoft Windows: -* Additional prerequisites for Microsoft Windows: - - Visual Studio 2015 - - Python 3.5 + - Visual Studio 2015 (latest version of MSVC 2017 is not supported by CUDA + yet, try it on your own risk) -* Additional prerequisites for Linux: - - Python 2.7 or later - - [Docker](https://www.docker.com/) (for automated testing) + - Python 3.5 -* Python dependencies: - - wheel - - NumPy 1.11.0 or later +* Additional prerequisites for Linux: + + - Python 2.7 or later + - [Docker](https://www.docker.com/) (for automated testing) + +* Python dependencies: + + - wheel + - NumPy 1.11.0 or later ### Known-good configurations -* Microsoft Windows 10 - - Microsoft Visual Studio Enterprise 2015 with Visual C++ 2015 - - [Anaconda 4.1.1 (Python 3.5 64-bit)](https://www.anaconda.com/download/) - - [Git for Windows version 2.9.2.windows.1](https://git-scm.com/download/win) - - [swigwin-3.0.10](http://www.swig.org/download.html) - - [NVidia CUDA Toolkit 8.0](https://developer.nvidia.com/cuda-downloads) - - [NVidia CUDNN 5.1](https://developer.nvidia.com/cudnn) - - [CMake 3.6](https://cmake.org/files/v3.6/cmake-3.6.3-win64-x64.msi) +* Microsoft Windows 10 + + - Microsoft Visual Studio Enterprise/ Community 2015 with Visual C++ 2015 + - [Anaconda 4.1.1 (Python 3.5 64-bit)](https://www.anaconda.com/download/) + - [Git for Windows version 2.9.2.windows.1](https://git-scm.com/download/win) + - [swigwin-3.0.10](http://www.swig.org/download.html) + - [NVidia CUDA Toolkit 9.0](https://developer.nvidia.com/cuda-downloads) + - [NVidia CUDNN 
7](https://developer.nvidia.com/cudnn) + - [CMake 3.6](https://cmake.org/files/v3.6/cmake-3.6.3-win64-x64.msi) -* Ubuntu 14.04 - - Makefile generator - - Docker 1.9.1 (for automated testing) +* Ubuntu 14.04 + + - Makefile generator + - Docker 1.9.1 (for automated testing) ### Current known limitations - - The Python package supports **Python 3.5 only**, because that is the only - version for which standard Python binaries exist and those binaries are - compatible with the TensorFlow runtime. (On Windows, the standard Python + +- The Python package supports **Python 3.5/3.6 only**, because these are the + only versions for which standard Python binaries exist and those binaries + are compatible with the TensorFlow runtime. (On Windows, the standard Python binaries for versions earlier than 3.5 were compiled with older compilers that do not have all of the features (e.g. C++11 support) needed to compile - TensorFlow. We welcome patches for making TensorFlow work with Python 2.7 - on Windows, but have not yet committed to supporting that configuration.) - - - The following Python APIs are not currently implemented: - * Loading custom op libraries via `tf.load_op_library()`. In order to use your - custom op, please put the source code under the tensorflow/core/user_ops - directory, and a shape function is required (not optional) for each op. - * Path manipulation functions (such as `tf.gfile.ListDirectory()`) are not - functional. - - - The `tf.contrib` libraries are not currently included in the PIP package. - - - The following operations are not currently implemented: - * `DepthwiseConv2dNative` - * `Digamma` - * `Erf` - * `Erfc` - * `Igamma` - * `Igammac` - * `ImmutableConst` - * `Lgamma` - * `Polygamma` - * `Zeta` - - - Google Cloud Storage support is not currently implemented. The GCS library + TensorFlow. We welcome patches for making TensorFlow work with Python 2.7 on + Windows, but have not yet committed to supporting that configuration.) 
+ +- The following Python APIs are not currently implemented: + + * Loading custom op libraries via `tf.load_op_library()`. In order to use + your custom op, please put the source code under the + tensorflow/core/user_ops directory, and a shape function is required + (not optional) for each op. + * Path manipulation functions (such as `tf.gfile.ListDirectory()`) are not + functional. + +- The `tf.contrib` libraries are not currently included in the PIP package. + +- The following operations are not currently implemented: + + * `DepthwiseConv2dNative` + * `Digamma` + * `Erf` + * `Erfc` + * `Igamma` + * `Igammac` + * `ImmutableConst` + * `Lgamma` + * `Polygamma` + * `Zeta` + +- Google Cloud Storage support is not currently implemented. The GCS library currently depends on `libcurl` and `boringssl`, and the Windows version could use standard Windows APIs for making HTTP requests and cryptography (for OAuth). Contributions are welcome for this feature. @@ -104,184 +127,383 @@ We are actively working on improving CMake and Windows support, and addressing these limitations. We would appreciate pull requests that implement missing ops or APIs. +# CMake GUI build (all platforms) + +Using the CMake GUI is a convenient way to generate C++ build projects. +The software supports Windows, MacOS and Linux, while POSIX platforms +provide an extra ccmake binary to run a command line GUI. The working principle +of cmake, ccmake and cmake-gui is the same; the only difference is that each provides +a suitable interface for project configuration and dependency setting. + +1. Pre-build checklist: The following binaries/libraries should be set in the + system path, otherwise you need to set them manually via cmake. 
+ * Compiler (GCC for Linux, MSVC for Windows) + * Make sure the compiler directory has been added to the system path + * CUDA 9.0 (GPU build) + * CUDNN (GPU build) + * NCCL (GPU build on Linux) + * SWIG (python binding) + * Perl (required if you need ssl support, optional) + * Go (required if you need ssl support, optional) + * NASM/YASM (required by grpc for ssl support, optional) +2. Start CMake GUI +3. Click on `Browse Source` and navigate to the folder + `/tensorflow/contrib/cmake` +4. Click on `Browse Build` and specify the location where you want tensorflow to + be built +5. Click on `Configure`; a new window will pop up, where you specify the + generator mode for the project generation. For Windows, choose `Visual + Studio Win64`, for Linux, choose `Unix Makefiles`, then + press `Finish`. Wait for a moment, the default project dependencies will be + generated automatically. +6. There are a few options with which you can customize your own build. **The settings + here are crucial for a successful build, please check all items carefully.** + + * `tensorflow_BUILD_ALL_KERNELS` should always be `on` + * `tensorflow_BUILD_CC_EXAMPLE` defaults to `on`. This can help you + to test the build (optional) + * `tensorflow_BUILD_CONTRIB_KERNELS` defaults to `on`, but it won't + affect tensorflow functionality; turn it `off` if you want a slim build. + (optional) + * `tensorflow_BUILD_PYTHON_BINDINGS` defaults to `on`. Set to `off` if + you don't need the python interface. If SWIG is not in the system path, you + need to set it manually. (optional) + * `tensorflow_BUILD_SHARED_LIB` defaults to `off`. Set to `on` if you + want the c++ interface. (optional) + * `tensorflow_ENABLE_GPU` defaults to `off`. Set to `on` if you want + GPU support. It will search for the CUDA and CUDNN dependencies if you have added + them to the system path, otherwise CMake will report an error and ask you + to set them manually. (optional) + * `tensorflow_ENABLE_GRPC_SUPPORT` defaults to `on`. 
For Linux build, + this option must always be `on`. This needs to be `on` for a gpu build. + Note that Perl, Go and NASM/YASM are required for this option if you + want to build grpc with official SSL support. + * `tensorflow_ENABLE_POSITION_INDEPENDENT_CODE` should always be `on` + * `tensorflow_ENABLE_SNAPPY_SUPPORT` should always be `on` + * `tensorflow_OPTIMIZE_FOR_NATIVE_ARCH` should always be `on` + * `CMAKE_INSTALL_PREFIX` is the location where the final package will be + installed. You may change it to your own preferred path (optional) + +7. After changing the configuration in step 6, press `Configure` again + +8. If no error is found, press `Generate` + +#### Windows + +1. Open `tensorflow.sln` in the build folder (Windows). Change build type from + `Debug` to `Release`. Choose `Build`->`Build Solution`. Compilation may take + several hours. If everything is alright, the output window will + show no errors. + + ##### Python + + In solution explorer, right click on `tf_python_build_pip_package` -> + `build`. It will generate the wheel file in + `/tf_python/dist`. Install with the following command: + + `pip install --upgrade tensorflow-.whl` + + ***The wheel name varies depending on your config. Change to your own wheel + filename.*** + + Note that some pip installations require a command prompt with administrator + rights. + + ##### C++ + + You can directly use the build folder tree for the C++ interface with cmake. If + you want to install for an api release, right click on `Install` -> + `build`. The headers and library will be installed in the directory specified + by `CMAKE_INSTALL_PREFIX` during configuration. + +1. On computers with less RAM, an out of heap space error may + appear. Building from the command prompt is an alternative to step 1. + + Open `VS2015 x64 Native Tools Command Prompt`. You can open it by pressing + `Start`, then typing the binary name. Use `VS2017 x64 Native Tools Command + Prompt` if you are using MSVC 2017. 
+ + ##### Python + + Directly build python wheel package by following command: + + `MSBuild /p:Configuration=Release + ` + + Remember to change `` to the + actual path of the file, it can be found at the root of build directory + + Install the wheel file generated as instructed by step 1. + + ##### C++ interface + + Build from VS native toolchain with following command: `MSBuild + /p:Configuration=Release ` + + Headers are discretely located in the build folders. Tensorflow library can + be found at `/Release`, namely `tensorflow.dll` and + `tensorflow.lib`. + + * Build to install for api release (optional): `MSBuild + /p:Configuration=Release ` + + Remember to change `` and + `` to the actual path of the file, it can be found + at the root of build directory. + +#### Linux/MacOS (command line GNU build) + +1. Open the terminal, change working directory to the one specified in step 3. + +2. Type the following command: + + `make -sj all` + + ##### Python + + **Important Note** CMake generated python wheel for Linux/MacOs is currently + under development. Please use bazel build. + + Follow code is an expected Linux/MacOS python package build after + development work is completed. + + ``` + make -sj tf_python_build_pip_package + cd tf_python + pip install --upgrade tensorflow-.whl + ``` + + ##### C++ interface + + `make -sj install` + + Where `` is the threads used for the compilation, change + to any integer less or equal to your computer's maxiumum thread number. + + Headers are discretely located in the build folders. Tensorflow library can + be found at ``, namely `tensorflow.so` (Linux) or + `tensorflow.dylib` (MacOS). + +#### Start a Tensorflow C++ project with CMake + +Here we assume that you have basic knowledge on gathering dependency with +`CMakeLists.txt`. Here we introduce how the C++ api works with +[official hello world tutorial](https://www.tensorflow.org/api_guides/cc/guide). + +1. 
Create a new working directory and create a new text file named + `CMakeLists.txt` and the c++ file `main.cxx` +2. Fill in the `main.cxx` with the code provided in + [official c++ api basic](https://www.tensorflow.org/api_guides/cc/guide). +3. Fill in the `CMakeLists.txt` with following code: ``` cmake + cmake_minimum_required (VERSION 2.6) project (tf_hello) + + # Tensorflow + + find_package(Tensorflow REQUIRED) + include_directories(${TENSORFLOW_INCLUDE_DIRS}) + + # compiler setting required by tensorflow, to be tested on all compilers + + # currently only tested on MSVC and GCC + + if (${CMAKE_CXX_COMPILER_ID} STREQUAL MSVC) add_definitions(-DCOMPILER_MSVC) + elseif (${CMAKE_CXX_COMPILER_ID} STREQUAL GNU) if + (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS "3") + add_definitions(-DCOMPILER_GCC3) else() add_definitions(-D__GNUC__) endif() + else() message(ERROR " compiler ${CMAKE_CXX_COMPILER_ID} not supported by + this CMakeList.txt, under development") endif() + + add_executable(tf_hello main.cxx) target_link_libraries(tf_hello + ${TENSORFLOW_LIBRARIES}) ``` + +4. Configure the folder with cmake-gui, an error should be prompted out, + requesting you to locate the folder containing `TensorflowConfig.cmake`. + This file can be found at `` or `` (for + those have build install in previous steps). + +5. Configure again, generate the project. + +6. Compile the project with `Release` config (Windows). For Linux users, just + compile the project. + +7. Copy the `tensorflow.dll`(Windows)/`tensorflow.so`(Linux) from build + directory to the build folder containing `tf_hello` binary. + +8. Run `tf_hello` binary + +# Step-by-step Windows build (command prompt) + +1. Install the prerequisites detailed above, and set up your environment. + + * When building with GPU support after installing the CUDNN zip file from + NVidia, append its bin directory to your PATH environment variable. 
In + case TensorFlow fails to find the CUDA dll's during initialization, + check your PATH environment variable. It should contain the directory of + the CUDA dlls and the directory of the CUDNN dll. For example: + + ``` + D:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin + D:\local\cuda\bin + ``` + + * When building with MKL support after installing + [MKL](https://software.intel.com/en-us/mkl) from INTEL, append its bin + directories to your PATH environment variable. + + In case TensorFlow fails to find the MKL dll's during initialization, + check your PATH environment variable. It should contain the directory of + the MKL dlls. For example: -Step-by-step Windows build -========================== - -1. Install the prerequisites detailed above, and set up your environment. - - * The following commands assume that you are using the Windows Command - Prompt (`cmd.exe`). You will need to set up your environment to use the - appropriate toolchain, i.e. the 64-bit tools. (Some of the binary targets - we will build are too large for the 32-bit tools, and they will fail with - out-of-memory errors.) The typical command to do set up your - environment is: - - ``` - D:\temp> "C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\bin\amd64\vcvarsall.bat" - ``` - - * When building with GPU support after installing the CUDNN zip file from NVidia, append its - bin directory to your PATH environment variable. - In case TensorFlow fails to find the CUDA dll's during initialization, check your PATH environment variable. - It should contain the directory of the CUDA dlls and the directory of the CUDNN dll. - For example: - - ``` - D:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin - D:\local\cuda\bin - ``` - - * When building with MKL support after installing [MKL](https://software.intel.com/en-us/mkl) from INTEL, append its bin directories to your PATH environment variable. 
- - In case TensorFlow fails to find the MKL dll's during initialization, check your PATH environment variable. - It should contain the directory of the MKL dlls. For example: - - ``` - D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\mkl - D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\compiler - D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\tbb\vc_mt - ``` - - - * We assume that `cmake` and `git` are installed and in your `%PATH%`. If - for example `cmake` is not in your path and it is installed in - `C:\Program Files (x86)\CMake\bin\cmake.exe`, you can add this directory - to your `%PATH%` as follows: - - ``` - D:\temp> set PATH="%PATH%;C:\Program Files (x86)\CMake\bin\cmake.exe" - ``` - -2. Clone the TensorFlow repository and create a working directory for your - build: - - ``` - D:\temp> git clone https://github.com/tensorflow/tensorflow.git - D:\temp> cd tensorflow\tensorflow\contrib\cmake - D:\temp\tensorflow\tensorflow\contrib\cmake> mkdir build - D:\temp\tensorflow\tensorflow\contrib\cmake> cd build - D:\temp\tensorflow\tensorflow\contrib\cmake\build> - ``` - -3. Invoke CMake to create Visual Studio solution and project files. - - **N.B.** This assumes that `cmake.exe` is in your `%PATH%` environment - variable. The other paths are for illustrative purposes only, and may - be different on your platform. The `^` character is a line continuation - and must be the last character on each line. - - ``` - D:\...\build> cmake .. -A x64 -DCMAKE_BUILD_TYPE=Release ^ - More? -DSWIG_EXECUTABLE=C:/tools/swigwin-3.0.10/swig.exe ^ - More? -DPYTHON_EXECUTABLE=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/python.exe ^ - More? -DPYTHON_LIBRARIES=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/libs/python35.lib - ``` - To build with GPU support add "^" at the end of the last line above following with: - ``` - More? -Dtensorflow_ENABLE_GPU=ON ^ - More? 
-DCUDNN_HOME="D:\...\cudnn" - ``` - To build with MKL support add "^" at the end of the last line above following with: - - ``` - More? -Dtensorflow_ENABLE_MKL_SUPPORT=ON ^ - More? -DMKL_HOME="D:\...\compilers_and_libraries" - ``` - - To enable SIMD instructions with MSVC, as AVX and SSE, define it as follows: - - ``` - More? -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX - ``` - - Note that the `-DCMAKE_BUILD_TYPE=Release` flag must match the build - configuration that you choose when invoking `msbuild`. The known-good - values are `Release` and `RelWithDebInfo`. The `Debug` build type is - not currently supported, because it relies on a `Debug` library for - Python (`python35d.lib`) that is not distributed by default. - - There are various options that can be specified when generating the - solution and project files: - - * `-DCMAKE_BUILD_TYPE=(Release|RelWithDebInfo)`: Note that the - `CMAKE_BUILD_TYPE` option must match the build configuration that you - choose when invoking MSBuild in step 4. The known-good values are - `Release` and `RelWithDebInfo`. The `Debug` build type is not currently - supported, because it relies on a `Debug` library for Python - (`python35d.lib`) that is not distributed by default. - - * `-Dtensorflow_BUILD_ALL_KERNELS=(ON|OFF)`. Defaults to `ON`. You can - build a small subset of the kernels for a faster build by setting this - option to `OFF`. - - * `-Dtensorflow_BUILD_CC_EXAMPLE=(ON|OFF)`. Defaults to `ON`. Generate - project files for a simple C++ - [example training program](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/tutorials/example_trainer.cc). - - * `-Dtensorflow_BUILD_PYTHON_BINDINGS=(ON|OFF)`. Defaults to `ON`. Generate - project files for building a PIP package containing the TensorFlow runtime - and its Python bindings. - - * `-Dtensorflow_ENABLE_GRPC_SUPPORT=(ON|OFF)`. Defaults to `ON`. Include - gRPC support and the distributed client and server code in the TensorFlow - runtime. 
- - * `-Dtensorflow_ENABLE_SSL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include - SSL support (for making secure HTTP requests) in the TensorFlow runtime. - This support is incomplete, and will be used for Google Cloud Storage - support. - - * `-Dtensorflow_ENABLE_GPU=(ON|OFF)`. Defaults to `OFF`. Include - GPU support. If GPU is enabled you need to install the CUDA 8.0 Toolkit and CUDNN 5.1. - CMake will expect the location of CUDNN in -DCUDNN_HOME=path_you_unzipped_cudnn. - - * `-Dtensorflow_BUILD_CC_TESTS=(ON|OFF)`. Defaults to `OFF`. This builds cc unit tests. - There are many of them and building will take a few hours. - After cmake, build and execute the tests with - ``` - MSBuild /p:Configuration=RelWithDebInfo ALL_BUILD.vcxproj - ctest -C RelWithDebInfo - ``` - - * `-Dtensorflow_BUILD_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python kernel tests. - After building the python wheel, you need to install the new wheel before running the tests. - To execute the tests, use - ``` - ctest -C RelWithDebInfo - ``` - - * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python tests on - serveral major packages. This option is only valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`. - After building the python wheel, you need to install the new wheel before running the tests. - To execute the tests, use - ``` - ctest -C RelWithDebInfo - ``` - - * `-Dtensorflow_ENABLE_MKL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL support. If MKL is enabled you need to install the [Intel Math Kernal Library](https://software.intel.com/en-us/mkl). - CMake will expect the location of MKL in -MKL_HOME=path_you_install_mkl. - - * `-Dtensorflow_ENABLE_MKLDNN_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL DNN support. MKL DNN is [Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)](https://github.com/intel/mkl-dnn). You have to add `-Dtensorflow_ENABLE_MKL_SUPPORT=ON` before including MKL DNN support. 
- - -4. Invoke MSBuild to build TensorFlow. - - To build the C++ example program, which will be created as a `.exe` - executable in the subdirectory `.\Release`: - - ``` - D:\...\build> MSBuild /p:Configuration=Release tf_tutorials_example_trainer.vcxproj - D:\...\build> Release\tf_tutorials_example_trainer.exe - ``` - - To build the PIP package, which will be created as a `.whl` file in the - subdirectory `.\tf_python\dist`: - - ``` - D:\...\build> MSBuild /p:Configuration=Release tf_python_build_pip_package.vcxproj - ``` + ``` + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\mkl + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\compiler + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\tbb\vc_mt + ``` + * We assume that `cmake` and `git` are installed and in your `%PATH%`. If + for example `cmake` is not in your path and it is installed in + `C:\Program Files (x86)\CMake\bin\cmake.exe`, you can add this directory + to your `%PATH%` as follows: + + ``` + D:\temp> set PATH="%PATH%;C:\Program Files (x86)\CMake\bin\cmake.exe" + ``` + +2. Clone the TensorFlow repository and create a working directory for your + build: + + ``` + D:\temp> git clone https://github.com/tensorflow/tensorflow.git + D:\temp> cd tensorflow\tensorflow\contrib\cmake + D:\temp\tensorflow\tensorflow\contrib\cmake> mkdir build + D:\temp\tensorflow\tensorflow\contrib\cmake> cd build + D:\temp\tensorflow\tensorflow\contrib\cmake\build> + ``` + +3. Invoke CMake to create Visual Studio solution and project files. + + **N.B.** This assumes that `cmake.exe` is in your `%PATH%` environment + variable. The other paths are for illustrative purposes only, and may be + different on your platform. The `^` character is a line continuation and + must be the last character on each line. + + ``` + D:\...\build> cmake .. -A x64 -Thost=x64 -DCMAKE_BUILD_TYPE=Release ^ + More? -DSWIG_EXECUTABLE=C:/tools/swigwin-3.0.10/swig.exe ^ + More? 
-DPYTHON_EXECUTABLE=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/python.exe ^ + More? -DPYTHON_LIBRARIES=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/libs/python35.lib + ``` + + To build with GPU support add "^" at the end of the last line above + following with: `More? -Dtensorflow_ENABLE_GPU=ON ^ More? + -DCUDNN_HOME="D:\...\cudnn"` To build with MKL support add "^" at the end of + the last line above following with: + + ``` + More? -Dtensorflow_ENABLE_MKL_SUPPORT=ON ^ + More? -DMKL_HOME="D:\...\compilers_and_libraries" + ``` + + To enable SIMD instructions with MSVC, as AVX and SSE, define it as follows: + + ``` + More? -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX + ``` + + Note that the `-DCMAKE_BUILD_TYPE=Release` flag must match the build + configuration that you choose when invoking `msbuild`. The known-good values + are `Release` and `RelWithDebInfo`. The `Debug` build type is not currently + supported, because it relies on a `Debug` library for Python + (`python35d.lib`) that is not distributed by default. + + The `-Thost=x64` flag will ensure that the 64 bit compiler and linker is + used when building. Without this flag, MSBuild will use the 32 bit toolchain + which is prone to compile errors such as "compiler out of heap space". + + There are various options that can be specified when generating the solution + and project files: + + * `-DCMAKE_BUILD_TYPE=(Release|RelWithDebInfo)`: Note that the + `CMAKE_BUILD_TYPE` option must match the build configuration that you + choose when invoking MSBuild in step 4. The known-good values are + `Release` and `RelWithDebInfo`. The `Debug` build type is not currently + supported, because it relies on a `Debug` library for Python + (`python35d.lib`) that is not distributed by default. + + * `-Dtensorflow_BUILD_ALL_KERNELS=(ON|OFF)`. Defaults to `ON`. You can + build a small subset of the kernels for a faster build by setting this + option to `OFF`. + + * `-Dtensorflow_BUILD_CC_EXAMPLE=(ON|OFF)`. 
Defaults to `ON`. Generate + project files for a simple C++ + [example training program](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/tutorials/example_trainer.cc). + + * `-Dtensorflow_BUILD_PYTHON_BINDINGS=(ON|OFF)`. Defaults to `ON`. + Generate project files for building a PIP package containing the + TensorFlow runtime and its Python bindings. + + * `-Dtensorflow_ENABLE_GRPC_SUPPORT=(ON|OFF)`. Defaults to `ON`. Include + gRPC support and the distributed client and server code in the + TensorFlow runtime. + + * `-Dtensorflow_ENABLE_SSL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include + SSL support (for making secure HTTP requests) in the TensorFlow runtime. + This support is incomplete, and will be used for Google Cloud Storage + support. + + * `-Dtensorflow_ENABLE_GPU=(ON|OFF)`. Defaults to `OFF`. Include GPU + support. If GPU is enabled you need to install the CUDA 8.0 Toolkit and + CUDNN 5.1. CMake will expect the location of CUDNN in + -DCUDNN_HOME=path_you_unzipped_cudnn. + + * `-Dtensorflow_BUILD_CC_TESTS=(ON|OFF)`. Defaults to `OFF`. This builds + cc unit tests. There are many of them and building will take a few + hours. After cmake, build and execute the tests with `MSBuild + /p:Configuration=RelWithDebInfo ALL_BUILD.vcxproj ctest -C + RelWithDebInfo` + + * `-Dtensorflow_BUILD_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This + enables python kernel tests. After building the python wheel, you need + to install the new wheel before running the tests. To execute the tests, + use `ctest -C RelWithDebInfo` + + * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This + enables python tests on serveral major packages. This option is only + valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`. + After building the python wheel, you need to install the new wheel + before running the tests. To execute the tests, use `ctest -C + RelWithDebInfo` + + * `-Dtensorflow_ENABLE_MKL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. 
Include + MKL support. If MKL is enabled you need to install the + [Intel Math Kernel Library](https://software.intel.com/en-us/mkl). CMake + will expect the location of MKL in -DMKL_HOME=path_you_install_mkl. + + * `-Dtensorflow_ENABLE_MKLDNN_SUPPORT=(ON|OFF)`. Defaults to `OFF`. + Include MKL DNN support. MKL DNN is [Intel(R) Math Kernel Library for + Deep Neural Networks (Intel(R) + MKL-DNN)](https://github.com/intel/mkl-dnn). You have to add + `-Dtensorflow_ENABLE_MKL_SUPPORT=ON` before including MKL DNN support. + +4. Invoke MSBuild to build TensorFlow. + + Set up the path to find MSBuild: `D:\temp> "C:\Program Files (x86)\Microsoft + Visual Studio 14.0\VC\bin\amd64\vcvarsall.bat"` + + To build the C++ example program, which will be created as a `.exe` + executable in the subdirectory `.\Release`: + + ``` + D:\...\build> MSBuild /p:Configuration=Release tf_tutorials_example_trainer.vcxproj + D:\...\build> Release\tf_tutorials_example_trainer.exe + ``` + + To build the PIP package, which will be created as a `.whl` file in the + subdirectory `.\tf_python\dist`: + + ``` + D:\...\build> MSBuild /p:Configuration=Release tf_python_build_pip_package.vcxproj + ``` Linux Continuous Integration build ================================== diff --git a/tensorflow/contrib/cmake/TensorflowConfig.cmake.in b/tensorflow/contrib/cmake/TensorflowConfig.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..cc04db6e952f53b8bb5416dde60b8173e60bf60e --- /dev/null +++ b/tensorflow/contrib/cmake/TensorflowConfig.cmake.in @@ -0,0 +1,16 @@ +# - Config file for the Tensorflow package +# It defines the following variables +# TENSORFLOW_INCLUDE_DIRS - include directories for Tensorflow +# TENSORFLOW_LIBRARIES - libraries to link against + +# Compute paths +get_filename_component(TENSORFLOW_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) +set(TENSORFLOW_INCLUDE_DIRS "@CONF_INCLUDE_DIRS@") + +# Our library dependencies (contains definitions for IMPORTED targets) +if(NOT 
TENSORFLOW_BINARY_DIR) + include("${TENSORFLOW_CMAKE_DIR}/TensorflowTargets.cmake") +endif() + +# These are IMPORTED targets created by TensorflowTargets.cmake +set(TENSORFLOW_LIBRARIES tensorflow) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/TensorflowConfigVersion.cmake.in b/tensorflow/contrib/cmake/TensorflowConfigVersion.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..2a9609ddb9c4ca864651818bdfae0f8fe290de31 --- /dev/null +++ b/tensorflow/contrib/cmake/TensorflowConfigVersion.cmake.in @@ -0,0 +1,11 @@ +set(PACKAGE_VERSION "@TENSORFLOW_VERSION@") + +# Check whether the requested PACKAGE_FIND_VERSION is compatible +if("${PACKAGE_VERSION}" VERSION_LESS "${PACKAGE_FIND_VERSION}") + set(PACKAGE_VERSION_COMPATIBLE FALSE) +else() + set(PACKAGE_VERSION_COMPATIBLE TRUE) + if ("${PACKAGE_VERSION}" VERSION_EQUAL "${PACKAGE_FIND_VERSION}") + set(PACKAGE_VERSION_EXACT TRUE) + endif() +endif() \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/abseil_cpp.cmake b/tensorflow/contrib/cmake/external/abseil_cpp.cmake new file mode 100644 index 0000000000000000000000000000000000000000..46a193971c5084523d432065f265fa7a9909f595 --- /dev/null +++ b/tensorflow/contrib/cmake/external/abseil_cpp.cmake @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +if (systemlib_ABSEIL_CPP) + + find_package(AbseilCpp REQUIRED + absl_base + absl_spinlock_wait + absl_dynamic_annotations + absl_malloc_internal + absl_throw_delegate + absl_int128 + absl_strings + str_format_internal + absl_bad_optional_access) + + include_directories(${ABSEIL_CPP_INCLUDE_DIR}) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${ABSEIL_CPP_LIBRARIES}) + + message(STATUS " abseil_cpp includes: ${ABSEIL_CPP_INCLUDE_DIR}") + message(STATUS " abseil_cpp libraries: ${ABSEIL_CPP_LIBRARIES}") + + add_custom_target(abseil_cpp) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) + +else (systemlib_ABSEIL_CPP) + + include (ExternalProject) + + set(abseil_cpp_INCLUDE_DIR ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp) + set(abseil_cpp_URL https://github.com/abseil/abseil-cpp/archive/e01d95528ea2137a4a27a88d1f57c6cb260aafed.tar.gz) + set(abseil_cpp_HASH SHA256=84043ed402d2a2a6ba4cdddb7e85118b1158fd81fe4ac3a14adc343d054c1e2e) + set(abseil_cpp_BUILD ${CMAKE_BINARY_DIR}/abseil_cpp/src/abseil_cpp-build) + + if(WIN32) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(abseil_cpp_STATIC_LIBRARIES + ${abseil_cpp_BUILD}/absl/base/Release/absl_base.lib + ${abseil_cpp_BUILD}/absl/base/Release/absl_dynamic_annotations.lib + ${abseil_cpp_BUILD}/absl/base/Release/absl_internal_malloc_internal.lib + ${abseil_cpp_BUILD}/absl/strings/Release/absl_strings.lib + ${abseil_cpp_BUILD}/absl/strings/Release/str_format_internal.lib + ${abseil_cpp_BUILD}/absl/types/Release/absl_bad_optional_access.lib) + else() + set(abseil_cpp_STATIC_LIBRARIES + ${abseil_cpp_BUILD}/absl/base/absl_base.lib + ${abseil_cpp_BUILD}/absl/base/absl_spinlock_wait.lib + ${abseil_cpp_BUILD}/absl/base/absl_dynamic_annotations.lib + ${abseil_cpp_BUILD}/absl/base/absl_malloc_internal.lib + ${abseil_cpp_BUILD}/absl/base/absl_throw_delegate.lib + ${abseil_cpp_BUILD}/absl/numeric/absl_int128.lib + 
${abseil_cpp_BUILD}/absl/strings/absl_strings.lib + ${abseil_cpp_BUILD}/absl/strings/str_format_internal.lib + ${abseil_cpp_BUILD}/absl/types/absl_bad_optional_access.lib) + endif() + else() + set(abseil_cpp_STATIC_LIBRARIES + ${abseil_cpp_BUILD}/absl/base/libabsl_base.a + ${abseil_cpp_BUILD}/absl/base/libabsl_spinlock_wait.a + ${abseil_cpp_BUILD}/absl/base/libabsl_dynamic_annotations.a + ${abseil_cpp_BUILD}/absl/base/libabsl_malloc_internal.a + ${abseil_cpp_BUILD}/absl/base/libabsl_throw_delegate.a + ${abseil_cpp_BUILD}/absl/numeric/libabsl_int128.a + ${abseil_cpp_BUILD}/absl/strings/libabsl_strings.a + ${abseil_cpp_BUILD}/absl/strings/libstr_format_internal.a + ${abseil_cpp_BUILD}/absl/types/libabsl_bad_optional_access.a) + endif() + + ExternalProject_Add(abseil_cpp + PREFIX abseil_cpp + URL ${abseil_cpp_URL} + URL_HASH ${abseil_cpp_HASH} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_BYPRODUCTS ${abseil_cpp_STATIC_LIBRARIES} + INSTALL_COMMAND "" + CMAKE_CACHE_ARGS + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} + -DCMAKE_BUILD_TYPE:STRING=Release + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + ) + + include_directories(${abseil_cpp_INCLUDE_DIR}) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${abseil_cpp_STATIC_LIBRARIES}) + + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES abseil_cpp) + +endif (systemlib_ABSEIL_CPP) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index b1e64aa55c80ad59cfdc0f4767c0282b4f73367f..e570c09ecb5e64130ed6f3375a51d74850cc3989 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f) +set(GRPC_TAG 
69b6c047bc767b4d80e7af4d00ccb7c45b683dae) if(WIN32) # We use unsecure gRPC because boringssl does not build on windows @@ -26,9 +26,9 @@ if(WIN32) set(grpc_SSL_PROVIDER NONE) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(grpc_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc_unsecure.lib - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/gpr.lib) + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/$(Configuration)/grpc++_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/$(Configuration)/grpc_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/$(Configuration)/gpr.lib) else() set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/grpc++_unsecure.lib @@ -43,8 +43,9 @@ else() ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/zlib/libz.a) endif() add_definitions(-DGRPC_ARES=0) @@ -66,7 +67,7 @@ ExternalProject_Add(grpc -DPROTOBUF_INCLUDE_DIRS:STRING=${PROTOBUF_INCLUDE_DIRS} -DPROTOBUF_LIBRARIES:STRING=${protobuf_STATIC_LIBRARIES} -DZLIB_ROOT:STRING=${ZLIB_INSTALL} - -DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER} + -DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER} ) # grpc/src/core/ext/census/tracing.c depends on the existence of openssl/rand.h. 
diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index 1a147e9c8e5a9fee17a81e37c9babe3c9ec0290b..32e6d78e508e25f76bd263e9d52b6574ca315f6c 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -59,6 +59,7 @@ ExternalProject_Add(png -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${png_INSTALL} -DZLIB_ROOT:STRING=${ZLIB_INSTALL} + -DPNG_TESTS:BOOL=OFF ) ## put png includes in the directory where they are expected diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index 56a57a2340ddc7f923c611c222a0399e279ad58a..773c37b309b1dff4ed28d24cd7d6140a63ec5bc6 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,7 +16,18 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG v3.6.1) + +# enable choose protobuf versions +SET(PROTOBUF_VERSION "3.6.1" CACHE STRING "Protobuf version") +SET_PROPERTY(CACHE PROTOBUF_VERSION PROPERTY STRINGS "3.4.0" "3.5.0" "3.6.1") + +if(${PROTOBUF_VERSION} STREQUAL "3.5.1") + set(PROTOBUF_TAG v3.6.1) +elseif(${PROTOBUF_VERSION} STREQUAL "3.5.0") + set(PROTOBUF_TAG 2761122b810fe8861004ae785cc3ab39f384d342) +elseif(${PROTOBUF_VERSION} STREQUAL "3.4.0") + set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) +endif() if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake b/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake new file mode 100644 index 0000000000000000000000000000000000000000..944ae3997a9489c13f65f93d9a7e61c21dd975c1 --- /dev/null +++ b/tensorflow/contrib/cmake/modules/FindAbseilCpp.cmake @@ -0,0 +1,72 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +find_path(ABSEIL_CPP_INCLUDE_DIR absl/base/config.h + HINTS "${ABSEIL_CPP_INCLUDE_DIR_HINTS}" + PATHS "$ENV{PROGRAMFILES}" + "$ENV{PROGRAMW6432}" + PATH_SUFFIXES "") + +if(EXISTS "${ABSEIL_CPP_INCLUDE_DIR}" AND NOT "${ABSEIL_CPP_INCLUDE_DIR}" STREQUAL "") + + if(NOT AbseilCpp_FIND_COMPONENTS) + # search all libraries if no COMPONENTS was requested + set(AbseilCpp_FIND_COMPONENTS + "absl_algorithm;absl_any;absl_bad_any_cast" + "absl_bad_optional_access;absl_base;absl_container;absl_debugging" + "absl_dynamic_annotations;absl_examine_stack;absl_failure_signal_handler" + "absl_int128;absl_leak_check;absl_internal_malloc_internal;absl_memory;absl_meta" + "absl_numeric;absl_optional;absl_span;absl_internal_spinlock_wait;absl_stack_consumption" + "absl_stacktrace;absl_str_format;absl_strings;absl_symbolize;absl_synchronization" + "absl_throw_delegate;absl_time;absl_utility;str_format_extension_internal" + "str_format_internal;test_instance_tracker_lib") + endif() + + foreach(LIBNAME ${AbseilCpp_FIND_COMPONENTS}) + + unset(ABSEIL_CPP_LIBRARY CACHE) + + find_library(ABSEIL_CPP_LIBRARY + NAMES ${LIBNAME} + HINTS ${ABSEIL_CPP_LIBRARIES_DIR_HINTS}) + + if(ABSEIL_CPP_LIBRARY) + list(APPEND ABSEIL_CPP_LIBRARIES ${ABSEIL_CPP_LIBRARY}) + else() + message(FATAL_ERROR "\n" + "abseil_cpp library \"${LIBNAME}\" not found in system path.\n" + "Please provide 
locations using: -DABSEIL_CPP_LIBRARIES_DIR_HINTS:STRING=\"PATH\"\n") + endif() + + endforeach() + + unset(LIBNAME CACHE) + unset(ABSEIL_CPP_LIBRARY CACHE) + + set(ABSEIL_CPP_FOUND TRUE) + message(STATUS "Found abseil_cpp libraries") + + set(ABSEIL_CPP_INCLUDE_DIR "${ABSEIL_CPP_INCLUDE_DIR}" CACHE PATH "" FORCE) + mark_as_advanced(ABSEIL_CPP_INCLUDE_DIR) + + set(ABSEIL_CPP_LIBRARIES "${ABSEIL_CPP_LIBRARIES}" CACHE PATH "" FORCE) + mark_as_advanced(ABSEIL_CPP_LIBRARIES) + +else() + + message(FATAL_ERROR "\n" + "abseil_cpp headers not found in system path.\n" + "Please provide locations using: -DABSEIL_CPP_INCLUDE_DIR_HINTS:STRING=\"PATH\"\n") + +endif() diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 6e72670142d560a364350bb4769f1153f884b0f6..96160568fa79291a7b391761373e1eaf0f70974e 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -57,6 +57,7 @@ tensorflow/python/ops tensorflow/python/ops/distributions tensorflow/python/ops/linalg tensorflow/python/ops/losses +tensorflow/python/ops/signal tensorflow/python/platform tensorflow/python/profiler tensorflow/python/profiler/internal @@ -279,10 +280,10 @@ tensorflow/contrib/linear_optimizer/kernels/g3doc tensorflow/contrib/linear_optimizer/python tensorflow/contrib/linear_optimizer/python/ops # TODO(drpngx): Fix failing imports -# tensorflow/contrib/lite -# tensorflow/contrib/lite/python -# tensorflow/contrib/lite/toco -# tensorflow/contrib/lite/toco/python +# tensorflow/lite +# tensorflow/lite/python +# tensorflow/lite/toco +# tensorflow/lite/toco/python tensorflow/contrib/lookup tensorflow/contrib/losses tensorflow/contrib/losses/python @@ -308,11 +309,6 @@ tensorflow/contrib/model_pruning/examples tensorflow/contrib/model_pruning/examples/cifar10 tensorflow/contrib/model_pruning/python tensorflow/contrib/model_pruning/python/layers -tensorflow/contrib/nccl -tensorflow/contrib/nccl/kernels 
-tensorflow/contrib/nccl/ops -tensorflow/contrib/nccl/python -tensorflow/contrib/nccl/python/ops tensorflow/contrib/nearest_neighbor tensorflow/contrib/nearest_neighbor/kernels tensorflow/contrib/nearest_neighbor/ops @@ -382,8 +378,6 @@ tensorflow/contrib/seq2seq/python/ops tensorflow/contrib/session_bundle tensorflow/contrib/session_bundle/example tensorflow/contrib/signal -tensorflow/contrib/signal/python -tensorflow/contrib/signal/python/ops tensorflow/contrib/slim tensorflow/contrib/slim/python tensorflow/contrib/slim/python/slim diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index 42afbd9105ef3789430606d909979ca308e2eaa8..013180c89083748b240ad061b342300e886d3568 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -6,7 +6,7 @@ tensorflow/contrib/boosted_trees/proto tensorflow/contrib/cloud/kernels tensorflow/contrib/decision_trees/proto tensorflow/contrib/gdr -tensorflow/contrib/lite/toco +tensorflow/lite/toco tensorflow/contrib/mpi tensorflow/contrib/mpi_collectives tensorflow/contrib/session_bundle diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index 7a30eb94f54b18a2a517615a315e23e09e1170d0..a04142bd249ed5e16beba11057d0efc1e191e31b 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== + ######################################################## # tf_c_framework library ######################################################## diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index 6c90cf398c69c8c1b22ea75e0c407f258e2535f9..6514ae50a4a35b35ba100af6997079294c22f9b8 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -149,11 +149,7 @@ add_library(tf_cc OBJECT ${tf_cc_srcs}) add_dependencies(tf_cc tf_cc_framework tf_cc_ops) if (WIN32) - if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") - set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow_internal.lib") - else() - set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib") - endif() + set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.lib") else (WIN32) set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal${CMAKE_SHARED_LIBRARY_SUFFIX}") endif (WIN32) diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index a54cbff33b66d63d7229fa2f50b8a4ca962111ed..d8884d464fb5974d77506561a9ed36110a3804c0 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -39,6 +39,8 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/*test*.h" "${tensorflow_source_dir}/tensorflow/core/*test*.cc" "${tensorflow_source_dir}/tensorflow/core/*main.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.h" "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/*.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu_device_factory.cc" 
"${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session.cc" diff --git a/tensorflow/contrib/cmake/tf_core_eager_runtime.cmake b/tensorflow/contrib/cmake/tf_core_eager_runtime.cmake new file mode 100644 index 0000000000000000000000000000000000000000..78e4c0d3035cdaefa1d0950f4270d60152c805af --- /dev/null +++ b/tensorflow/contrib/cmake/tf_core_eager_runtime.cmake @@ -0,0 +1,57 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +######################################################## +# tf_core_eager_runtime library +######################################################## +file(GLOB_RECURSE tf_core_eager_runtime_srcs + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*.h" +) + +file(GLOB_RECURSE tf_core_eager_runtime_exclude_srcs + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*test*.h" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/eager/*test*.cc" +) + +list(REMOVE_ITEM tf_core_eager_runtime_srcs ${tf_core_eager_runtime_exclude_srcs}) + +add_library(tf_core_eager_runtime OBJECT ${tf_core_eager_runtime_srcs}) +add_dependencies( + tf_core_eager_runtime + tf_c + tf_core_lib) + + +file(GLOB_RECURSE tf_c_eager_srcs + "${tensorflow_source_dir}/tensorflow/c/eager/*.cc" + "${tensorflow_source_dir}/tensorflow/c/eager/*.h" +) + +file(GLOB_RECURSE tf_c_eager_exlclude_srcs + "${tensorflow_source_dir}/tensorflow/c/eager/*test*.h" + "${tensorflow_source_dir}/tensorflow/c/eager/*test*.cc" +) + +list(REMOVE_ITEM tf_c_eager_srcs ${tf_c_eager_exlclude_srcs}) + +add_library(tf_c_eager OBJECT ${tf_c_eager_srcs}) +add_dependencies( + tf_c_eager + tf_core_eager_runtime + tf_c + tf_cc_framework + tf_cc_while_loop + tf_core_lib + tf_protos_cc) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 7e806685b8448cbd629985cdc00ed1193857abe6..d7b2a1339e047aba0a9424a53a63726805e89721 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -140,16 +140,19 @@ set(tf_proto_text_srcs "tensorflow/core/example/example.proto" "tensorflow/core/example/feature.proto" "tensorflow/core/framework/allocation_description.proto" + "tensorflow/core/framework/api_def.proto" 
"tensorflow/core/framework/attr_value.proto" "tensorflow/core/framework/cost_graph.proto" "tensorflow/core/framework/device_attributes.proto" "tensorflow/core/framework/function.proto" "tensorflow/core/framework/graph.proto" "tensorflow/core/framework/graph_transfer_info.proto" + "tensorflow/core/framework/iterator.proto" "tensorflow/core/framework/kernel_def.proto" "tensorflow/core/framework/log_memory.proto" "tensorflow/core/framework/node_def.proto" "tensorflow/core/framework/op_def.proto" + "tensorflow/core/framework/reader_base.proto" "tensorflow/core/framework/remote_fused_graph_execute_info.proto" "tensorflow/core/framework/resource_handle.proto" "tensorflow/core/framework/step_stats.proto" @@ -159,6 +162,7 @@ set(tf_proto_text_srcs "tensorflow/core/framework/tensor_shape.proto" "tensorflow/core/framework/tensor_slice.proto" "tensorflow/core/framework/types.proto" + "tensorflow/core/framework/variable.proto" "tensorflow/core/framework/versions.proto" "tensorflow/core/lib/core/error_codes.proto" "tensorflow/core/protobuf/cluster.proto" @@ -204,10 +208,10 @@ file(GLOB tf_core_platform_srcs "${tensorflow_source_dir}/tensorflow/core/framework/resource_handle.h" "${tensorflow_source_dir}/tensorflow/core/framework/resource_handle.cc") if (NOT tensorflow_ENABLE_GPU) - file(GLOB tf_core_platform_gpu_srcs + file(GLOB tf_core_platform_gpu_srcs_exclude "${tensorflow_source_dir}/tensorflow/core/platform/cuda_libdevice_path.*" "${tensorflow_source_dir}/tensorflow/core/platform/default/cuda_libdevice_path.*") - list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs}) + list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs_exclude}) else() file(GLOB tf_core_platform_srcs_exclude "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc") diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 
7b892ba248bc43cd885f295288c677ac97efaa06..d66e39ac07c7b7c9423fa7e878a9cefd94b867bd 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -68,14 +68,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/unique_dataset_op.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc" @@ -97,9 +89,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/libsvm/ops/libsvm_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" - "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc" 
"${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc" "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/ops/nearest_neighbor_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/resampler/kernels/resampler_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index bc753333dba4f67eee0114c4022743dd59a05982..310eed4ecbfdd30a3b3bdd4728c030fe70930797 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -13,13 +13,14 @@ # limitations under the License. # ============================================================================== set(tf_op_lib_names - "audio_ops" "array_ops" + "audio_ops" "batch_ops" "bitwise_ops" "boosted_trees_ops" "candidate_sampling_ops" "checkpoint_ops" + "collective_ops" "control_flow_ops" "ctc_ops" "cudnn_rnn_ops" @@ -27,13 +28,14 @@ set(tf_op_lib_names "dataset_ops" "decode_proto_ops" "encode_proto_ops" + "function_ops" "functional_ops" "image_ops" "io_ops" "linalg_ops" "list_ops" - "lookup_ops" "logging_ops" + "lookup_ops" "manip_ops" "math_ops" "nn_ops" @@ -43,10 +45,11 @@ set(tf_op_lib_names "remote_fused_graph_ops" "resource_variable_ops" "rpc_ops" + "scoped_allocator_ops" "script_ops" "sdca_ops" - "set_ops" "sendrecv_ops" + "set_ops" "sparse_ops" "spectral_ops" "state_ops" @@ -54,6 +57,7 @@ set(tf_op_lib_names "string_ops" "summary_ops" "training_ops" + "word2vec_ops" ) foreach(tf_op_lib_name ${tf_op_lib_names}) @@ -89,7 +93,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(coder 
"${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc") @@ -99,7 +102,6 @@ GENERATE_CONTRIB_OP_LIBRARY(image_distort_image "${tensorflow_source_dir}/tensor GENERATE_CONTRIB_OP_LIBRARY(image_sirds "${tensorflow_source_dir}/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc") GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(periodic_resample "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/ops/array_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nearest_neighbor "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/ops/nearest_neighbor_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(resampler "${tensorflow_source_dir}/tensorflow/contrib/resampler/ops/resampler_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 6d86daf5f174a3238ab92e5bba6085c904766766..8faccf8d55902e6701ebb4ce534b84705304fd5f 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -222,17 +222,17 @@ endforeach(python_module) add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD COMMAND ${CMAKE_COMMAND} -E 
make_directory - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite") + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/lite") add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python") + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/lite/python") add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD COMMAND ${CMAKE_COMMAND} -E touch - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/__init__.py") + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/lite/python/__init__.py") add_custom_command( TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E touch - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/lite.py) + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/lite/python/lite.py) # Generate the tensorflow.python.platform.build_info module. set(BUILD_INFO_PY "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/platform/build_info.py") @@ -313,15 +313,14 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) ${GENERATE_PYTHON_OP_LIB_DESTINATION} PARENT_SCOPE) endfunction() -GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("array_ops") +GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("batch_ops") GENERATE_PYTHON_OP_LIB("bitwise_ops") GENERATE_PYTHON_OP_LIB("boosted_trees_ops") -GENERATE_PYTHON_OP_LIB("math_ops") -GENERATE_PYTHON_OP_LIB("functional_ops") GENERATE_PYTHON_OP_LIB("candidate_sampling_ops") GENERATE_PYTHON_OP_LIB("checkpoint_ops") +GENERATE_PYTHON_OP_LIB("collective_ops") GENERATE_PYTHON_OP_LIB("control_flow_ops" ADDITIONAL_LIBRARIES $) GENERATE_PYTHON_OP_LIB("ctc_ops") @@ -332,14 +331,18 @@ GENERATE_PYTHON_OP_LIB("decode_proto_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_decode_proto_op.py) GENERATE_PYTHON_OP_LIB("encode_proto_ops" DESTINATION 
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_encode_proto_op.py) +GENERATE_PYTHON_OP_LIB("function_ops") +GENERATE_PYTHON_OP_LIB("functional_ops") GENERATE_PYTHON_OP_LIB("image_ops") GENERATE_PYTHON_OP_LIB("io_ops") GENERATE_PYTHON_OP_LIB("linalg_ops") GENERATE_PYTHON_OP_LIB("list_ops") GENERATE_PYTHON_OP_LIB("logging_ops") GENERATE_PYTHON_OP_LIB("lookup_ops") -GENERATE_PYTHON_OP_LIB("nn_ops") GENERATE_PYTHON_OP_LIB("manip_ops") +GENERATE_PYTHON_OP_LIB("math_ops") +GENERATE_PYTHON_OP_LIB("nn_ops") +GENERATE_PYTHON_OP_LIB("no_op") GENERATE_PYTHON_OP_LIB("parsing_ops") GENERATE_PYTHON_OP_LIB("random_ops") GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" @@ -347,17 +350,21 @@ GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" GENERATE_PYTHON_OP_LIB("resource_variable_ops") GENERATE_PYTHON_OP_LIB("rpc_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rpc/python/ops/gen_rpc_op.py) +GENERATE_PYTHON_OP_LIB("scoped_allocator_ops") GENERATE_PYTHON_OP_LIB("script_ops") GENERATE_PYTHON_OP_LIB("sdca_ops") +GENERATE_PYTHON_OP_LIB("sendrecv_ops") GENERATE_PYTHON_OP_LIB("set_ops") -GENERATE_PYTHON_OP_LIB("state_ops") GENERATE_PYTHON_OP_LIB("sparse_ops") GENERATE_PYTHON_OP_LIB("spectral_ops") +GENERATE_PYTHON_OP_LIB("state_ops") +GENERATE_PYTHON_OP_LIB("stateless_random_ops") GENERATE_PYTHON_OP_LIB("string_ops") GENERATE_PYTHON_OP_LIB("summary_ops") GENERATE_PYTHON_OP_LIB("user_ops") GENERATE_PYTHON_OP_LIB("training_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/training/gen_training_ops.py) +GENERATE_PYTHON_OP_LIB("word2vec_ops") GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_model_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_model_ops.py) @@ -373,8 +380,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION 
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) GENERATE_PYTHON_OP_LIB("contrib_coder_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/coder/python/ops/gen_coder_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_clustering_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops" @@ -393,11 +398,8 @@ GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py) GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_nccl_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py) GENERATE_PYTHON_OP_LIB("contrib_periodic_resample_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/periodic_resample/python/ops/gen_periodic_resample_op.py) - GENERATE_PYTHON_OP_LIB("contrib_nearest_neighbor_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nearest_neighbor/ops/gen_nearest_neighbor_ops.py) GENERATE_PYTHON_OP_LIB("contrib_resampler_ops" @@ -422,8 +424,6 @@ GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py) GENERATE_PYTHON_OP_LIB("contrib_gcs_config_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_gcs_config_ops.py) -GENERATE_PYTHON_OP_LIB("stateless_random_ops" - DESTINATION 
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/debug/ops/gen_debug_ops.py) @@ -526,11 +526,13 @@ if(WIN32) add_library(pywrap_tensorflow_internal_static STATIC ${pywrap_tensorflow_internal_src} $ + $ $ $ $ $ $ + $ $ $ $ @@ -583,11 +585,13 @@ endif(WIN32) add_library(pywrap_tensorflow_internal SHARED ${pywrap_tensorflow_internal_src} $ + $ $ $ $ $ $ + $ $ $ $ @@ -617,13 +621,28 @@ target_include_directories(pywrap_tensorflow_internal PUBLIC ${NUMPY_INCLUDE_DIR} ) -target_link_libraries(pywrap_tensorflow_internal PRIVATE +if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0) + # There is a bug in GCC 5 resulting in undefined reference to a __cpu_model function when + # linking to the tensorflow library. Adding the following libraries fixes it. + # See issue on github: https://github.com/tensorflow/tensorflow/issues/9593 + target_link_libraries(pywrap_tensorflow_internal PRIVATE ${tf_core_gpu_kernels_lib} ${tensorflow_EXTERNAL_LIBRARIES} tf_protos_cc tf_python_protos_cc ${PYTHON_LIBRARIES} + gcc_s + gcc ) +else() + target_link_libraries(pywrap_tensorflow_internal PRIVATE + ${tf_core_gpu_kernels_lib} + ${tensorflow_EXTERNAL_LIBRARIES} + tf_protos_cc + tf_python_protos_cc + ${PYTHON_LIBRARIES} +) +endif() if(WIN32) @@ -808,10 +827,10 @@ add_dependencies(tf_python_api tf_python_ops) ######################################################## # Parse tensorflow/python/tools/api/generator/BUILD to get list of generated files. 
-FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_gen.bzl api_generator_BUILD_text) -STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text}) -string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) -string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +FILE(READ ${tensorflow_source_dir}/tensorflow/python/tools/api/generator/api_init_files.bzl api_generator_BUILD_text) +STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text}) +string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text}) string(REPLACE "," ";" api_init_files_list ${api_init_files_text}) set(api_init_files "") diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index fdf522f1fd90ffc64acbe82381ef57a389645d61..62005dd113bfb80fbdf23afb6d4aa5f90a1e32de 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -23,6 +23,8 @@ if(WIN32) # we need. # add_library(tensorflow_static STATIC + $ + $ $ $ $ @@ -65,6 +67,8 @@ endif(WIN32) # tensorflow is a shared library containing all of the # TensorFlow runtime and the standard ops and kernels. 
add_library(tensorflow SHARED + $ + $ $ $ $ @@ -96,6 +100,27 @@ if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0) target_link_libraries(tensorflow PRIVATE gcc_s gcc) endif() +# Offer the user the choice of overriding the installation directories +set(INSTALL_LIB_DIR lib CACHE PATH "Installation directory for libraries") +set(INSTALL_BIN_DIR bin CACHE PATH "Installation directory for executables") +set(INSTALL_INCLUDE_DIR include CACHE PATH + "Installation directory for header files") +if(WIN32 AND NOT CYGWIN) + set(DEF_INSTALL_CMAKE_DIR cmake) +else() + set(DEF_INSTALL_CMAKE_DIR lib/cmake) +endif() +set(INSTALL_CMAKE_DIR ${DEF_INSTALL_CMAKE_DIR} CACHE PATH + "Installation directory for CMake files") + +# Make relative paths absolute (needed later on) +foreach(p LIB BIN INCLUDE CMAKE) + set(var INSTALL_${p}_DIR) + if(NOT IS_ABSOLUTE "${${var}}") + set(${var} "${CMAKE_INSTALL_PREFIX}/${${var}}") + endif() +endforeach() + if(WIN32) add_dependencies(tensorflow tensorflow_static) endif(WIN32) @@ -103,14 +128,57 @@ endif(WIN32) target_include_directories(tensorflow PUBLIC $) -install(TARGETS tensorflow EXPORT tensorflow_export - RUNTIME DESTINATION bin - LIBRARY DESTINATION lib - ARCHIVE DESTINATION lib) +# Add all targets to build-tree export set +export(TARGETS tensorflow + FILE ${PROJECT_BINARY_DIR}/TensorflowTargets.cmake) + +# Export the package for use from the build-tree +export(PACKAGE Tensorflow) + +# Create the TensorflowConfig.cmake and TensorflowConfigVersion files +file(RELATIVE_PATH REL_INCLUDE_DIR "${INSTALL_CMAKE_DIR}" + "${INSTALL_INCLUDE_DIR}") +# for the build tree +set(CONF_INCLUDE_DIRS "${tensorflow_source_dir}" + "${PROJECT_BINARY_DIR}" + "${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src" + "${CMAKE_CURRENT_BINARY_DIR}/nsync/install/include" # Please if there is a better directory + "${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/Eigen/" + "${CMAKE_CURRENT_BINARY_DIR}/external/eigen_archive/" + 
"${tensorflow_source_dir}/third_party/eigen3/" + "${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/") +configure_file(TensorflowConfig.cmake.in + "${PROJECT_BINARY_DIR}/TensorflowConfig.cmake" @ONLY) +# for the install tree, yet to be complete +set(CONF_INCLUDE_DIRS "\${TENSORFLOW_CMAKE_DIR}/${REL_INCLUDE_DIR}") +configure_file(TensorflowConfig.cmake.in + "${PROJECT_BINARY_DIR}/${CMAKE_FILES_DIRECTORY}/TensorflowConfig.cmake" @ONLY) +# for both +configure_file(TensorflowConfigVersion.cmake.in + "${PROJECT_BINARY_DIR}/TensorflowConfigVersion.cmake" @ONLY) + +# install(TARGETS tensorflow EXPORT tensorflow_export +# RUNTIME DESTINATION ${INSTALL_BIN_DIR} +# LIBRARY DESTINATION ${INSTALL_LIB_DIR} +# ARCHIVE DESTINATION ${INSTALL_LIB_DIR}) + +# install(EXPORT tensorflow_export +# FILE TensorflowConfig.cmake +# DESTINATION ${INSTALL_CMAKE_DIR}) -install(EXPORT tensorflow_export - FILE TensorflowConfig.cmake - DESTINATION lib/cmake) +install(FILES + "${PROJECT_BINARY_DIR}/${CMAKE_FILES_DIRECTORY}/TensorflowConfig.cmake" + "${PROJECT_BINARY_DIR}/TensorflowConfigVersion.cmake" + DESTINATION "${INSTALL_CMAKE_DIR}" COMPONENT dev) + +# install the export set for use with the install-tree +install(EXPORT TensorflowTargets + DESTINATION ${INSTALL_CMAKE_DIR}) + +install(TARGETS tensorflow EXPORT TensorflowTargets + RUNTIME DESTINATION ${INSTALL_BIN_DIR} + LIBRARY DESTINATION ${INSTALL_LIB_DIR} + ARCHIVE DESTINATION ${INSTALL_LIB_DIR}) # install necessary headers # tensorflow headers @@ -145,6 +213,10 @@ install(DIRECTORY ${tensorflow_source_dir}/third_party/eigen3/ # unsupported Eigen directory install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/ DESTINATION include/unsupported/Eigen) +# absl directory +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/abseil_cpp/src/abseil_cpp/absl/ + DESTINATION include/absl + FILES_MATCHING PATTERN "*.h") # mkl if (tensorflow_ENABLE_MKL_SUPPORT) install(DIRECTORY 
${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include/ diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD index 4bfd753bb1d1fc254c66a4f7eb1d6ac83a40cb70..7f96a103d4cd797bc733a41a673eac492419b4c6 100644 --- a/tensorflow/contrib/coder/BUILD +++ b/tensorflow/contrib/coder/BUILD @@ -13,12 +13,12 @@ load( "//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_custom_op_library", - "tf_custom_op_py_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", "tf_kernel_library", "tf_py_test", ) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") cc_library( name = "range_coder", diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index f83386b8a4246ff2d7acdd2190804296582ee945..e4566437c60ebb2da039e61c171fbe954a7355c9 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -7,6 +7,7 @@ package_group( includes = ["//tensorflow/compiler/jit:friends"], packages = [ "//tensorflow/...", + "//tensorflow_models/...", "//third_party/py/tensor2tensor/...", ], ) @@ -57,7 +58,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/compiler/jit:xla_ops_py", - "//tensorflow/contrib/tpu:tpu_lib", + "//tensorflow/compiler/jit/ops:xla_ops_grad", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", @@ -80,6 +81,7 @@ tf_py_test( "//tensorflow/python:control_flow_util", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:training", diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py index 873b03580d6f1d9cb25c79cb31989d43cdb8c9a7..f867cd15b67dbd43650d8012b4299845af7200a8 100644 --- a/tensorflow/contrib/compiler/xla.py +++ b/tensorflow/contrib/compiler/xla.py @@ -23,7 +23,7 @@ import contextlib from six.moves import xrange # pylint: disable=redefined-builtin from 
tensorflow.compiler.jit.ops import xla_ops -from tensorflow.contrib.tpu.python.tpu import tpu_function +from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops @@ -35,6 +35,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import function_utils from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect _XLA_COMPILE_ATTR = '_xla_compile_id' _MAX_WARNING_LINES = 5 @@ -179,14 +180,11 @@ class XLACompileContext(control_flow_ops.XLAControlFlowContext): if external_control_inputs: # Use an identity to pull control inputs as data inputs. Note that we # ignore ops which don't have outputs. TODO(phawkins): fix that. - with ops.control_dependencies(None): - self.Enter() - external_control_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_control_inputs - if x.outputs - ] - self.Exit() + external_control_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_control_inputs + if x.outputs + ] # pylint: disable=protected-access op._add_control_inputs(external_control_inputs) # pylint: enable=protected-access @@ -266,13 +264,13 @@ def _compile_internal(computation, inputs=None): inputs = [ops.convert_to_tensor(x) for x in inputs] input_arity = len(inputs) - arg_error = tpu_function.check_function_argument_count( + arg_error = check_function_argument_count( computation, input_arity, infeed_queue=None) if arg_error is not None: raise TypeError( 'Supplied computation cannot be called with the specified inputs. 
You ' 'specified %d inputs: %s, but the computation needs %s' % - (input_arity, str([i.name for i in inputs[0]]), arg_error)) + (input_arity, str([i.name for i in inputs]), arg_error)) cluster_name = ops.get_default_graph().unique_name('cluster') pivot = control_flow_ops.no_op(name=cluster_name + '/pivot') @@ -606,8 +604,8 @@ class _ModelFnWrapper(object): def estimator_model_fn(target_model_fn=None): """estimator_model_fn decorates a model_fn to be compiled for execution. - Currently only it only works with `TPUEstimator`. If you need to use it with - base `Estimator`, please add `tf.enable_resource_variables()` at beginning of + Currently it only works with `TPUEstimator`. If you need to use it with base + `Estimator`, please add `tf.enable_resource_variables()` at the beginning of your program. Example 1, decorating model_fn: @@ -645,3 +643,51 @@ def estimator_model_fn(target_model_fn=None): return tf_decorator.make_decorator(function, _ModelFnWrapper(function)) return decorated(target_model_fn) if target_model_fn else decorated + + +def check_function_argument_count(func, input_arity, infeed_queue): + """Validate the number of input arguments to an XLA function. + + Args: + func: the Python function that will be called to generate the body of an XLA + computation graph. + input_arity: the number of explicit arguments supplied by the caller. + infeed_queue: if not None, the infeed queue that will supply + additional arguments to the function. + + Returns: + None if function can be called with the supplied number of + arguments, or an error string if it cannot. 
+ """ + def format_error(complaint, quantity): + return '%s %d argument%s' % (complaint, quantity, '' + if quantity == 1 else 's') + + num_args_supplied = input_arity + if infeed_queue is not None: + num_args_supplied += infeed_queue.number_of_tuple_elements + arg_spec = tf_inspect.getargspec(func) + num_func_args = len(arg_spec.args) + if arg_spec.defaults is None: + num_func_defaults = 0 + else: + num_func_defaults = len(arg_spec.defaults) + min_func_args = num_func_args - num_func_defaults + if num_args_supplied < min_func_args: + # The required number of arguments is not enough to call the function. + if num_func_defaults == 0 and arg_spec.varargs is None: + return format_error('exactly', num_func_args) + else: + return format_error('at least', min_func_args) + if arg_spec.varargs is None and num_args_supplied > num_func_args: + # The required number of arguments is too many to call the function. + if num_func_defaults == 0: + return format_error('exactly', num_func_args) + else: + return format_error('at most', num_func_args) + # Reaching here means either + # 1) There are varargs, func can accept any number of arguments greater than + # the minimum. + # 2) Number of supplied arguments falls in range of acceptable argument count + # of func. 
+ return None diff --git a/tensorflow/contrib/compiler/xla_test.py b/tensorflow/contrib/compiler/xla_test.py index a306b56f63bd3b135b0231da89fb2e3445570740..3b49755afcf0753d31c0ce506dce42709b1ee8bc 100644 --- a/tensorflow/contrib/compiler/xla_test.py +++ b/tensorflow/contrib/compiler/xla_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.compiler import xla +from tensorflow.contrib.tpu.python.tpu import tpu_feed from tensorflow.python import summary from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops @@ -27,7 +28,6 @@ from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -48,7 +48,7 @@ class XLACompileContextTest(test.TestCase): histogram_summary = summary.histogram('histogram_summary', dummy_tensor) image_summary = summary.image('image_summary', dummy_tensor) scalar_summary = summary.scalar('scalar_summary', dummy_tensor) - tensor_summary = summary_ops.tensor_summary('tensor_summary', dummy_tensor) + tensor_summary = summary.tensor_summary('tensor_summary', dummy_tensor) summary.merge( [ audio_summary, histogram_summary, image_summary, scalar_summary, @@ -176,5 +176,81 @@ class XLACompileContextTest(test.TestCase): self.assertFalse(op.graph.is_fetchable(op.op)) +class CheckFunctionArgumentCountTest(test.TestCase): + + def testSimple(self): + """Tests that arg checker works for functions with no varargs or defaults. 
+ """ + + def func(x, y, z): + return x + y + z + + self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) + self.assertEqual('exactly 3 arguments', + xla.check_function_argument_count(func, 2, None)) + queue = tpu_feed.InfeedQueue(2) + self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) + self.assertEqual('exactly 3 arguments', + xla.check_function_argument_count(func, 2, queue)) + + def testDefaultArgs(self): + """Tests that arg checker works for a function with no varargs.""" + + def func(x, y, z=17): + return x + y + z + + self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) + self.assertEqual(None, xla.check_function_argument_count(func, 2, None)) + self.assertEqual('at least 2 arguments', + xla.check_function_argument_count(func, 1, None)) + self.assertEqual('at most 3 arguments', + xla.check_function_argument_count(func, 4, None)) + queue = tpu_feed.InfeedQueue(1) + self.assertEqual(None, xla.check_function_argument_count(func, 2, queue)) + self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) + self.assertEqual('at least 2 arguments', + xla.check_function_argument_count(func, 0, queue)) + self.assertEqual('at most 3 arguments', + xla.check_function_argument_count(func, 4, queue)) + + def testVarArgs(self): + """Tests that arg checker works for a function with varargs.""" + + def func(x, y, *z): + return x + y + len(z) + + self.assertEqual(None, xla.check_function_argument_count(func, 2, None)) + self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) + self.assertEqual(None, xla.check_function_argument_count(func, 4, None)) + self.assertEqual('at least 2 arguments', + xla.check_function_argument_count(func, 1, None)) + queue = tpu_feed.InfeedQueue(1) + self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) + self.assertEqual(None, xla.check_function_argument_count(func, 2, queue)) + self.assertEqual(None, 
xla.check_function_argument_count(func, 3, queue)) + self.assertEqual('at least 2 arguments', + xla.check_function_argument_count(func, 0, queue)) + + def testVarArgsAndDefaults(self): + """Tests that arg checker works for a function with varargs and defaults.""" + + def func(x, y, z=17, *q): # pylint: disable=keyword-arg-before-vararg + return x + y + z + len(q) + + self.assertEqual(None, xla.check_function_argument_count(func, 2, None)) + self.assertEqual(None, xla.check_function_argument_count(func, 3, None)) + self.assertEqual(None, xla.check_function_argument_count(func, 4, None)) + self.assertEqual(None, xla.check_function_argument_count(func, 5, None)) + self.assertEqual('at least 2 arguments', + xla.check_function_argument_count(func, 1, None)) + queue = tpu_feed.InfeedQueue(1) + self.assertEqual(None, xla.check_function_argument_count(func, 1, queue)) + self.assertEqual(None, xla.check_function_argument_count(func, 2, queue)) + self.assertEqual(None, xla.check_function_argument_count(func, 3, queue)) + self.assertEqual(None, xla.check_function_argument_count(func, 4, queue)) + self.assertEqual('at least 2 arguments', + xla.check_function_argument_count(func, 0, queue)) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py index 41258edd90866ae9f644a02c42dfe2dc589da998..6926c0d03fe38ab2d62cc588950c7f5a49b2aba1 100644 --- a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py +++ b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py @@ -74,8 +74,8 @@ class ConstrainedMinimizationProblem(object): if (constraints_shape.ndims is None or proxy_constraints_shape.ndims is None or - any([ii is None for ii in constraints_shape.as_list()]) or - any([ii is None for ii in proxy_constraints_shape.as_list()])): + any(ii is None 
for ii in constraints_shape.as_list()) or + any(ii is None for ii in proxy_constraints_shape.as_list())): raise ValueError( "constraints and proxy_constraints must have fully-known shapes") if constraints_shape != proxy_constraints_shape: diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py index 67f8ac2b9322f39b02c521f8b9cde3831c7889b8..fb0f849b33b0c5d28fff09eb5aac7f2c0d1adc0b 100644 --- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py @@ -82,7 +82,7 @@ def _project_multipliers_wrt_euclidean_norm(multipliers, radius): raise ValueError( "multipliers must be one dimensional (instead is %d-dimensional)" % multipliers_shape.ndims) - dimension = multipliers_shape[0].value + dimension = multipliers_shape.dims[0].value if dimension is None: raise ValueError("multipliers must have fully-known shape") diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py index a6cb1f62f059770c90bd1aeea391d841aed9aacf..14e6d8701124ba67cdff8140250b5078f6194693 100644 --- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py @@ -156,7 +156,7 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix): if matrix_shape[0] != matrix_shape[1]: raise ValueError("matrix must be square (instead has shape (%d,%d))" % (matrix_shape[0], matrix_shape[1])) - dimension = matrix_shape[0].value + dimension = matrix_shape.dims[0].value if dimension is None: raise ValueError("matrix must have fully-known shape") @@ -601,7 +601,7 @@ class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer): assert state_shape is not None assert state_shape.ndims == 2 
assert state_shape[0] == state_shape[1] - dimension = state_shape[0].value + dimension = state_shape.dims[0].value assert dimension is not None minimum_log_multiplier = standard_ops.log( diff --git a/tensorflow/contrib/copy_graph/python/__init__.py b/tensorflow/contrib/copy_graph/python/__init__.py index b9ff28eb0d7115ff5919c2f758f70ba388f5d4d2..5c1048e02a3104c958f7710ba97980d3353adbad 100644 --- a/tensorflow/contrib/copy_graph/python/__init__.py +++ b/tensorflow/contrib/copy_graph/python/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/contrib/copy_graph/python/util/__init__.py b/tensorflow/contrib/copy_graph/python/util/__init__.py index b9ff28eb0d7115ff5919c2f758f70ba388f5d4d2..5c1048e02a3104c958f7710ba97980d3353adbad 100644 --- a/tensorflow/contrib/copy_graph/python/util/__init__.py +++ b/tensorflow/contrib/copy_graph/python/util/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_test.py b/tensorflow/contrib/copy_graph/python/util/copy_test.py index ba97c7845635596c3f4f849044b6707ec43f5bbf..4d8651a79fde9b876d4fdd9b050e71d2eb7c893d 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_test.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_test.py @@ -26,15 +26,16 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -graph1 = ops.Graph() -graph2 = ops.Graph() - class CopyVariablesTest(test.TestCase): + def setUp(self): + self.graph1 = ops.Graph() + self.graph2 = ops.Graph() + def testVariableCopy(self): - with graph1.as_default(): + with self.graph1.as_default(): #Define a Variable in graph1 some_var = variables.VariableV1(2) #Initialize session @@ -43,13 +44,15 @@ class CopyVariablesTest(test.TestCase): variables.global_variables_initializer().run(session=sess1) #Make a copy of some_var in the defsult scope in graph2 - copy1 = copy_elements.copy_variable_to_graph(some_var, graph2) + copy1 = copy_elements.copy_variable_to_graph(some_var, self.graph2) #Make another copy with different scope - copy2 = copy_elements.copy_variable_to_graph(some_var, graph2, "test_scope") + copy2 = copy_elements.copy_variable_to_graph(some_var, + self.graph2, + "test_scope") #Initialize both the copies - with graph2.as_default(): + with self.graph2.as_default(): #Initialize Session sess2 = session_lib.Session() #Initialize the Variables @@ -67,9 +70,13 @@ class CopyVariablesTest(test.TestCase): class CopyOpsTest(test.TestCase): + def setUp(self): + self.graph1 = ops.Graph() + self.graph2 = ops.Graph() + def testOpsCopy(self): - with graph1.as_default(): + with self.graph1.as_default(): #Initialize a basic expression y = ax + b x = array_ops.placeholder("float") a = variables.VariableV1(3.0) @@ -82,21 +89,21 @@ class CopyOpsTest(test.TestCase): variables.global_variables_initializer().run(session=sess1) #First, 
initialize a as a Variable in graph2 - a1 = copy_elements.copy_variable_to_graph(a, graph2) + a1 = copy_elements.copy_variable_to_graph(a, self.graph2) #Initialize a1 in graph2 - with graph2.as_default(): + with self.graph2.as_default(): #Initialize session sess2 = session_lib.Session() #Initialize the Variable variables.global_variables_initializer().run(session=sess2) #Initialize a copy of y in graph2 - y1 = copy_elements.copy_op_to_graph(y, graph2, [a1]) + y1 = copy_elements.copy_op_to_graph(y, self.graph2, [a1]) #Now that y has been copied, x must be copied too. #Get that instance - x1 = copy_elements.get_copied_op(x, graph2) + x1 = copy_elements.get_copied_op(x, self.graph2) #Compare values of y & y1 for a sample input #and check if they match diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py index fe5e34d258fbc1508a0a85655f29c2c9bc8fa8b1..d53549048f33162ec89dfe957ca58a4bbb4e95c6 100644 --- a/tensorflow/contrib/crf/__init__.py +++ b/tensorflow/contrib/crf/__init__.py @@ -14,8 +14,6 @@ # ============================================================================== """Linear-chain CRF layer. -See the [CRF](https://tensorflow.org/api_guides/python/contrib.crf) guide. - @@crf_binary_score @@crf_decode @@crf_log_likelihood diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 43bb43129bfe1cb1c66f4965476f9b7f849658ad..40e159b8fcbd1864284e208cb15d9ed96119f840 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -38,12 +38,12 @@ tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run( [unary_scores, sequence_lengths, transition_params, train_op]) for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores, tf_sequence_lengths): -# Remove padding. -tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_] + # Remove padding. 
+ tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_] -# Compute the highest score and its tag sequence. -tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode( - tf_unary_scores_, tf_transition_params) + # Compute the highest score and its tag sequence. + tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode( + tf_unary_scores_, tf_transition_params) """ from __future__ import absolute_import @@ -54,6 +54,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops @@ -107,8 +108,10 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, return sequence_scores return utils.smart_cond( - pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1], - 1), + pred=math_ops.equal( + tensor_shape.dimension_value( + inputs.shape[1]) or array_ops.shape(inputs)[1], + 1), true_fn=_single_seq_fn, false_fn=_multi_seq_fn) @@ -157,8 +160,10 @@ def crf_multitag_sequence_score(inputs, tag_bitmap, sequence_lengths, transition_params=transition_params) return utils.smart_cond( - pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1], - 1), + pred=math_ops.equal( + tensor_shape.dimension_value( + inputs.shape[1]) or array_ops.shape(inputs)[1], + 1), true_fn=_single_seq_fn, false_fn=_multi_seq_fn) @@ -214,8 +219,10 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): return log_norm return utils.smart_cond( - pred=math_ops.equal(inputs.shape[1].value or - array_ops.shape(inputs)[1], 1), + pred=math_ops.equal( + tensor_shape.dimension_value( + inputs.shape[1]) or array_ops.shape(inputs)[1], + 1), true_fn=_single_seq_fn, false_fn=_multi_seq_fn) @@ -240,7 +247,7 @@ def crf_log_likelihood(inputs, provided by the caller or created in this function. 
""" # Get shape information. - num_tags = inputs.get_shape()[2].value + num_tags = tensor_shape.dimension_value(inputs.shape[2]) # Get the transition matrix if not provided. if transition_params is None: @@ -342,7 +349,7 @@ class CrfForwardRnnCell(rnn_cell.RNNCell): for the broadcast summation occurring within the cell. """ self._transition_params = array_ops.expand_dims(transition_params, 0) - self._num_tags = transition_params.get_shape()[0].value + self._num_tags = tensor_shape.dimension_value(transition_params.shape[0]) @property def state_size(self): @@ -428,7 +435,7 @@ class CrfDecodeForwardRnnCell(rnn_cell.RNNCell): summation occurring within the cell. """ self._transition_params = array_ops.expand_dims(transition_params, 0) - self._num_tags = transition_params.get_shape()[0].value + self._num_tags = tensor_shape.dimension_value(transition_params.shape[0]) @property def state_size(self): @@ -540,7 +547,7 @@ def crf_decode(potentials, transition_params, sequence_length): # For simplicity, in shape comments, denote: # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). - num_tags = potentials.get_shape()[2].value + num_tags = tensor_shape.dimension_value(potentials.shape[2]) # Computes forward decoding. Get last score and backpointers. 
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) @@ -583,7 +590,7 @@ def crf_decode(potentials, transition_params, sequence_length): return decode_tags, best_score return utils.smart_cond( - pred=math_ops.equal(potentials.shape[1].value or + pred=math_ops.equal(tensor_shape.dimension_value(potentials.shape[1]) or array_ops.shape(potentials)[1], 1), true_fn=_single_seq_fn, false_fn=_multi_seq_fn) diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index aeefa3cee62281c74388765ea5e2cbc7f16ff927..8d35622e393e15a2f2dfea7c75ad2c9f48aa7150 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -9,8 +9,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") -load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") @@ -44,10 +42,11 @@ tf_custom_op_py_library( cuda_py_test( name = "cudnn_rnn_ops_test", - size = "large", + size = "medium", srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"], additional_deps = [ ":cudnn_rnn_py", + "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python/ops/losses:losses", @@ -63,10 +62,10 @@ cuda_py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], - shard_count = 6, + shard_count = 2, tags = [ - "manual", - "requires_cudnn5", + "noasan", # http://b/62067814 + "requires-gpu-sm35", ], ) @@ -93,8 +92,7 @@ cuda_py_test( ], shard_count = 6, tags = [ - "manual", - "requires_cudnn5", + "noasan", # http://b/62067814 ], ) @@ -121,6 +119,5 @@ cuda_py_test( "noasan", # http://b/62067814 "nomsan", "notsan", - "requires_cudnn5", ], ) diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py index 
5d8c6191f8db9f96532aa78e4790a4665d3b4877..5320232268657fa73bcd3e86da49d6525e9b8db5 100644 --- a/tensorflow/contrib/cudnn_rnn/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/__init__.py @@ -24,6 +24,10 @@ @@CudnnGRUSaveable @@CudnnRNNReluSaveable @@CudnnRNNTanhSaveable +@@CudnnParamsFormatConverterLSTM +@@CudnnParamsFormatConverterGRU +@@CudnnParamsFormatConverterTanh +@@CudnnParamsFormatConverterRelu """ from __future__ import absolute_import @@ -48,6 +52,10 @@ _allowed_symbols = [ "CudnnGRUSaveable", "CudnnRNNReluSaveable", "CudnnRNNTanhSaveable", + "CudnnParamsFormatConverterLSTM", + "CudnnParamsFormatConverterGRU", + "CudnnParamsFormatConverterTanh", + "CudnnParamsFormatConverterRelu", ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index c59d3682d404e032d9f4bf81ef54ab456341cefa..a268415f0e65206294431a537be18cadbe1a1e84 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -18,24 +18,30 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import itertools import os import unittest +from absl.testing import parameterized import numpy as np from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework.test_util import TensorFlowTestCase from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradient_checker -from tensorflow.python.ops import math_ops +from tensorflow.python.ops 
import gradients_impl +from tensorflow.python.ops import init_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.platform import test @@ -56,710 +62,989 @@ CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER -def _CreateModel(rnn_mode, - num_layers, - num_units, - input_size, - input_mode="linear_input", - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, - dtype=dtypes.float32, - dropout=0.): - del input_mode - if rnn_mode == cudnn_rnn_ops.CUDNN_LSTM: - model_fn = cudnn_rnn_ops.CudnnLSTM - elif rnn_mode == cudnn_rnn_ops.CUDNN_GRU: - model_fn = cudnn_rnn_ops.CudnnGRU - elif rnn_mode == cudnn_rnn_ops.CUDNN_RNN_TANH: - model_fn = cudnn_rnn_ops.CudnnRNNTanh - elif rnn_mode == cudnn_rnn_ops.CUDNN_RNN_RELU: - model_fn = cudnn_rnn_ops.CudnnRNNRelu +def RunLSTM(sess, + num_units, + input_size, + batch_size, + time, + num_layers=1, + is_training=True, + dropout=0., + num_dirs=True, + dtype=dtypes.float32): + # TODO(jamesqin): add multi-layer tests. + # TODO(jamesqin): add multi-dir tests + assert num_layers == 1 + assert num_dirs == 1 + if is_training and not np.isclose(dropout, 0): + raise ValueError("dropout can not be 0. when test training.") + + # set graph level random seed and numpy random seed. 
+ random_seed.set_random_seed(0) + np.random.seed(0) + + inputs = variable_scope.get_variable( + "inputs", + initializer=np.random.rand(time, batch_size, + input_size).astype(dtype.as_numpy_dtype), + dtype=dtype) + initial_h_op = variable_scope.get_variable( + "initial_h_op", + initializer=np.random.rand(batch_size, + num_units).astype(dtype.as_numpy_dtype), + dtype=dtype) + initial_c_op = variable_scope.get_variable( + "initial_c_op", + initializer=np.random.rand(batch_size, + num_units).astype(dtype.as_numpy_dtype), + dtype=dtype) + + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, dtype=dtype, seed=19980904) + + with variable_scope.variable_scope("test", initializer=initializer): + w = variable_scope.get_variable( + "rnn/lstm_cell/kernel", + shape=[input_size + num_units, num_units * 4], + dtype=dtype) + b = variable_scope.get_variable( + "rnn/lstm_cell/bias", shape=[num_units * 4], dtype=dtype) + + # canonical lstm. must set forget_bias to 0. to align with cudnn lstm. + cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True) + outputs_op, state_tuple_op = rnn.dynamic_rnn( + cell, + inputs, + initial_state=rnn_cell_impl.LSTMStateTuple( + h=initial_h_op, c=initial_c_op), + dtype=dtype, + time_major=True, + scope=None) + + # Convert to cudnn opaque param. + format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM( + num_layers, num_units, input_size) + opaque_params = format_converter.tf_canonical_to_opaque([w, b]) + + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) + cu_initial_c_op = array_ops.expand_dims(initial_c_op, axis=0) + cu_outputs_op, cu_h_op, cu_c_op = cudnn_rnn_ops._cudnn_rnn( + inputs, + cu_initial_h_op, + cu_initial_c_op, + opaque_params, + dropout=dropout, + is_training=is_training, + rnn_mode=cudnn_rnn_ops.CUDNN_LSTM) + # Remove the trivial 1st dimension. 
+ cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple( + c=array_ops.squeeze(cu_c_op, axis=0), + h=array_ops.squeeze(cu_h_op, axis=0)) + + if is_training: + (inp_grad_op, hgrad_op, + cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients( + outputs_op, [inputs, initial_h_op, initial_c_op, w, b]) + + (cu_inp_grad_op, cu_hgrad_op, + cu_cgrad_op, opaque_grad_op) = gradients_impl.gradients( + cu_outputs_op, + [inputs, cu_initial_h_op, cu_initial_c_op, opaque_params]) + # Remove the trivial 1st dimension + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + # Remove the trivial 1st dimension + cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0) + + cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( + opaque_grad_op) + cu_wgrad_op = cu_wgrad_op[0] + cu_bgrad_op = cu_bgrad_op[0] + # cudnn lstm has 2 biases each gate. When converting to tf canonical format, + # the two biases are summed into one. Thus here bias gradient should be + # halved when comparing with tf lstm. + cu_bgrad_op *= 0.5 + + init_op = variables.global_variables_initializer() + sess.run(init_op) + + if is_training: + outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([ + outputs_op, state_tuple_op, inp_grad_op, + (hgrad_op, cgrad_op), wgrad_op, bgrad_op + ]) + (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad, + cu_bgrad) = sess.run([ + cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op, + (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op + ]) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "state_tuple: %s" % str(state_tuple)) + logging.vlog(1, "cu_state_tuple: %s" % str(cu_state_tuple)) + logging.vlog(1, "inp_grad: %s" % inp_grad) + logging.vlog(1, "cu_inp_grad: %s" % cu_inp_grad) + logging.vlog(1, "state_grad: %s" % str(state_grad)) + logging.vlog(1, "cu_state_grad: %s" % str(cu_state_grad)) + logging.vlog(1, "wgrad: %s" % str(wgrad)) + logging.vlog(1, "bgrad: %s" % str(bgrad)) + 
logging.vlog(1, "cu_wgrad: %s" % str(cu_wgrad)) + logging.vlog(1, "cu_bgrad: %s" % str(cu_bgrad)) + return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, + cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, + cu_bgrad) else: - raise ValueError("Invalid rnn_mode: %s" % rnn_mode) - return model_fn( - num_layers, - num_units, - input_size, - direction=direction, - dtype=dtype, - dropout=dropout) - - -def _CreateParamsSavable(params, - model, - base_variable_scope=None, - name="params_canonical"): - """Create a RNNParamsSaveable for the weight and bias parameters. + outputs, state_tuple = sess.run([outputs_op, state_tuple_op]) + cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op]) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "state_tuple: %s" % str(state_tuple)) + logging.vlog(1, "cu_state_tuple: %s" % str(cu_state_tuple)) + return outputs, cu_outputs, state_tuple, cu_state_tuple + + +# Basic set of RNN configs to test. They can be further extended in relevant +# test (e.g. adding num_dirs). +NAMED_RNN_TESTCASES = ({ + "testcase_name": "xsmall", + "num_units": 1, + "input_size": 1, + "batch_size": 1, + "time": 1, + "num_layers": 1, +}, { + "testcase_name": "small", + "num_units": 4, + "input_size": 4, + "batch_size": 4, + "time": 4, + "num_layers": 1, +}, { + "testcase_name": "medium", + "num_units": 128, + "input_size": 64, + "batch_size": 8, + "time": 16, + "num_layers": 1, +}, { + "testcase_name": "large", + "num_units": 128, + "input_size": 128, + "batch_size": 16, + "time": 32, + "num_layers": 1, +}) + + +def ExpandNamedTestCases(inputs, *remove_keys, **extra_configs): + """Expands testcase with new config dimensions. 
+ + Example: + inputs = ( + {'testcase_name': 'test1', 'gender': 'male'} + {'testcase_name': 'test2', 'gender': 'female'} + ) + remove_keys: empty + extra_configs = { + 'age': [40, 80] + 'height': [5, 6] + } + + Returns: + ( + {'testcase_name': 'test1_age_40_height_5','gender': 'male', 'age': + 40,'height': 5} + {'testcase_name': 'test1_age_40_height_6', 'gender': 'male', 'age': 40, + 'height': 6} + {'testcase_name': 'test1_age_80_height_5', 'gender': 'male', 'age': 80, + 'height': 5} + {'testcase_name': 'test1_age_80_height_6', 'gender': 'male', 'age': 80, + 'height': 6} + + {'testcase_name': 'test2_age_40_height_5', 'gender': 'female', 'age': + 40, + 'height': 5} + {'testcase_name': 'test2_age_40_height_6', 'gender': 'female', 'age': + 40, + 'height': 6} + {'testcase_name': 'test2_age_80_height_5', 'gender': 'female', 'age': + 80, + 'height': 5} + {'testcase_name': 'test2_age_80_height_6', 'gender': 'female', 'age': + 80, + 'height': 6} + ) Args: - params: a Variable for weight and bias parameters. - model: a CudnnRNN model. - base_variable_scope: a string, prefix of names of saved variables. - name: a string, name of the RNNParamsSaveable object. + inputs: A list of dictionary, each being a testcase. + *remove_keys: A list of keys into testcase which are not needed in new + testcases. + **extra_configs: A dict of new test dimension and applicable values in that + dimension. + Returns: - a RNNParamsSaveable object. + A list of dictionary with expanded test cases. 
""" - if model._rnn_mode == CUDNN_LSTM: - fn = cudnn_rnn_ops.CudnnLSTMSaveable - elif model._rnn_mode == CUDNN_GRU: - fn = cudnn_rnn_ops.CudnnGRUSaveable - elif model._rnn_mode == CUDNN_RNN_TANH: - fn = cudnn_rnn_ops.CudnnRNNTanhSaveable - elif model._rnn_mode == CUDNN_RNN_RELU: - fn = cudnn_rnn_ops.CudnnRNNReluSaveable - params_saveable = fn( - params, - model.num_layers, - model.num_units, - model.input_size, - model.input_mode, - model.direction, - scope=base_variable_scope, - name=name) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable) - return params_saveable - - -def _MinLSTMParamSize(num_layers, - num_units, - input_size, - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION): - if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION: - first_layer_weights = 4 * num_units * (num_units + input_size) - higher_layer_weights = 8 * (num_layers - 1) * num_units * num_units - all_biases = 8 * num_layers * num_units - return first_layer_weights + higher_layer_weights + all_biases - elif direction == cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION: - first_layer_weights = 4 * num_units * (num_units + input_size) - higher_layer_weights = (num_layers - 1) * ( - 4 * 2 * num_units * num_units + 4 * num_units**2) - all_biases = 8 * num_layers * num_units - return 2 * (first_layer_weights + higher_layer_weights + all_biases) - else: - raise ValueError("%s direction is not supported.") + res = [] + ordered_extra_configs = collections.OrderedDict(extra_configs) + keys = ordered_extra_configs.keys() + # A list of list of configs. + # The outer loop is iterating keys, the innner is values of one key. + combined_kv = [[(k, v) for v in ordered_extra_configs[k]] for k in keys] + logging.info("combined_kv: %s", combined_kv) + for inp in inputs: + # Each inp is a dict + for config in itertools.product(*combined_kv): + new_inp = dict(inp) + # config is a list in the form of [(k_i, v_j), (k_p, v_q), ...] 
+ suffix = ["%s_%s" % (p[0], str(p[1])) for p in config] + suffix = "_".join(suffix) + new_inp["testcase_name"] += "_" + suffix + for k, v in config: + new_inp[k] = v + # Remove not used keys from the new test case. + if remove_keys: + if not isinstance(remove_keys, (list, tuple)): + remove_keys = [remove_keys] + for k in remove_keys: + new_inp.pop(k, None) + logging.info("new_inp: %s", new_inp) + res.append(new_inp) + # Dedup, necessary if `remove_keys` is set. + return [dict(t) for t in {tuple(d.items()) for d in res}] -class CudnnRNNTestSaveRestore(TensorFlowTestCase): - def _CompareWeights(self, lhs, rhs): - self.assertEqual(len(lhs), len(rhs)) - for lw, rw in zip(lhs, rhs): - self.assertAllEqual(lw, rw) +class CudnnLSTMTest(TensorFlowTestCase, parameterized.TestCase): - def _CompareBiases(self, lhs, rhs, rnn_mode, num_layers, direction): - self.assertEqual(len(lhs), len(rhs)) - if rnn_mode == CUDNN_LSTM: - num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER - elif rnn_mode == CUDNN_GRU: - num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER - elif rnn_mode == CUDNN_RNN_TANH: - num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER - else: - num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER - num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2 - num_params_per_layer *= num_dirs - self.assertEqual(num_params_per_layer * num_layers, len(lhs)) - - for i in range(num_layers): - layer_lhs = lhs[i * num_params_per_layer: (i+1) * num_params_per_layer] - layer_rhs = rhs[i * num_params_per_layer: (i+1) * num_params_per_layer] - if direction == CUDNN_RNN_UNIDIRECTION: - self._CompareSingleLayerBiases(layer_lhs, layer_rhs) - else: - size = len(layer_lhs) - fw_lhs, bw_lhs = layer_lhs[:size//2], layer_lhs[size//2:] - fw_rhs, bw_rhs = layer_rhs[:size//2], layer_rhs[size//2:] - self._CompareSingleLayerBiases(fw_lhs, fw_rhs) - self._CompareSingleLayerBiases(bw_lhs, bw_rhs) - - def _CompareSingleLayerBiases(self, lhs, rhs): - self.assertEqual(len(lhs), len(rhs)) - - 
lf_lhs, rt_lhs = lhs[:len(lhs)//2], lhs[len(lhs)//2:] - lf_rhs, rt_rhs = rhs[:len(rhs)//2], rhs[len(rhs)//2:] - self.assertEqual(len(lf_lhs), len(rt_lhs)) - self.assertEqual(len(lf_rhs), len(rt_rhs)) - - sum_lhs, sum_rhs = [], [] - for lf, rt in zip(lf_lhs, rt_lhs): - sum_lhs.append(lf + rt) - for lf, rt in zip(lf_rhs, rt_rhs): - sum_rhs.append(lf + rt) - self.assertEqual(len(sum_lhs), len(sum_rhs)) - for lf, rt in zip(sum_lhs, sum_rhs): - self.assertAllEqual(lf, rt) + def _test_training_helper(self, + num_units, + input_size, + batch_size, + time, + num_layers, + dtype, + rtol=2e-6, + atol=2e-6): + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, cu_inp_grad, + state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunLSTM( + sess, num_units, input_size, batch_size, time, num_layers) - def _testSaveRestoreVariable(self, rnn_mode, direction, dtype): - num_layers = 2 - num_units = 7 - input_size = 3 - with ops.Graph().as_default(): - model = _CreateModel( - rnn_mode, - num_layers=num_layers, - num_units=num_units, - input_size=input_size, - direction=direction, - dtype=dtype) - random_seed.set_random_seed(1234) - params_size_t = model.params_size() - params = variables.Variable( - random_ops.random_uniform([params_size_t], dtype=dtype), - dtype=dtype, - validate_shape=False) - saveable = _CreateParamsSavable(params, model) - weights, biases = saveable._OpaqueParamsToCanonical() - reset_params = state_ops.assign( - params, - array_ops.zeros([params_size_t], dtype=dtype), - validate_shape=False) - save_path = os.path.join(self.get_temp_dir(), - "save-restore-variable-test") - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - # Passing graph explicitly, otherwise an old sess would be reused. 
- with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - val = saver.save(sess, save_path) - self.assertEqual(save_path, val) + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + for s, cu_s in zip(state_tuple, cu_state_tuple): + self.assertAllClose(s, cu_s, rtol=rtol, atol=atol) + for sg, cu_sg in zip(state_grad, cu_state_grad): + self.assertAllClose(sg, cu_sg, rtol=rtol, atol=atol) + self.assertAllClose(inp_grad, cu_inp_grad, rtol=rtol, atol=atol) + self.assertAllClose(bgrad, cu_bgrad, rtol=rtol, atol=atol) + self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol) - weights_v, biases_v = sess.run([weights, biases]) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_training_helper(num_units, input_size, batch_size, time, + num_layers, dtypes.float32) - sess.run(reset_params) - saver.restore(sess, save_path) - weights_v_restored, biases_v_restored = sess.run([weights, biases]) - - self._CompareWeights(weights_v, weights_v_restored) - self._CompareBiases(biases_v, biases_v_restored, rnn_mode, num_layers, - direction) - - def _testSaveRestoreTwoVariables(self, rnn_mode, direction, dtype): - num_layers = 2 - num_units = 7 - input_size = 3 - with ops.Graph().as_default(): - model = _CreateModel( - rnn_mode, - num_layers=num_layers, - num_units=num_units, - input_size=input_size, - direction=direction, - dtype=dtype) - random_seed.set_random_seed(1234) - params_size_t = model.params_size() - names = ["rnn_1", "rnn_2"] - param_vars = [ - variables.Variable( - random_ops.random_uniform([params_size_t], dtype=dtype), - dtype=dtype, - validate_shape=False) for name in names - ] - saveables = [] - for name, params in 
zip(names, param_vars): - saveables.append(_CreateParamsSavable(params, model, name, name)) - weights1, biases1 = saveables[0]._OpaqueParamsToCanonical() - weights2, biases2 = saveables[1]._OpaqueParamsToCanonical() - reset_params = [ - state_ops.assign( - params, - array_ops.zeros([params_size_t], dtype=dtype), - validate_shape=False) for params in param_vars - ] - save_path = os.path.join(self.get_temp_dir(), - "save-restore-variable-test") - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - # Passing graph explicitly, otherwise an old sess would be reused. - with self.test_session(use_gpu=True, - graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - val = saver.save(sess, save_path) - self.assertEqual(save_path, val) - weights1_v, biases1_v = sess.run([weights1, biases1]) - weights2_v, biases2_v = sess.run([weights2, biases2]) - - sess.run(reset_params) - saver.restore(sess, save_path) - weights1_v_restored, biases1_v_restored = sess.run([weights1, biases1]) - weights2_v_restored, biases2_v_restored = sess.run([weights2, biases2]) - - self._CompareWeights(weights1_v, weights1_v_restored) - self._CompareWeights(weights2_v, weights2_v_restored) - self._CompareBiases(biases1_v, biases1_v_restored, rnn_mode, num_layers, - direction) - self._CompareBiases(biases2_v, biases2_v_restored, rnn_mode, num_layers, - direction) - - def _testSaveRestoreOutput(self, rnn_mode, direction, dtype): - with ops.Graph().as_default(): - num_layers = 2 - num_units = 7 - input_size = 7 - seq_length = 10 - batch_size = 5 - dir_count = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2 - model = _CreateModel( - rnn_mode, + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs 
found") + self._test_training_helper( + num_units, + input_size, + batch_size, + time, + num_layers, + dtypes.float16, + rtol=5e-3, + atol=5e-4) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, num_layers, + is_training=False) + + self.assertAllClose(outputs, cu_outputs) + # h + self.assertAllClose(state_tuple.h, cu_state_tuple.h) + # c + self.assertAllClose(state_tuple.c, cu_state_tuple.c) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM( + sess, num_units, input_size, - direction=direction, - dtype=dtype) - params_size_t = model.params_size() - params = variables.Variable( - array_ops.ones([params_size_t], dtype=dtype), - validate_shape=False, - dtype=dtype) - _CreateParamsSavable(params, model) - save_path = os.path.join(self.get_temp_dir(), "save-restore-output-test") - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + batch_size, + time, + num_layers, + is_training=False, + dtype=dtypes.float16) - np.random.seed(1234) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - input_data = constant_op.constant( - np.random.randn(seq_length, batch_size, input_size), dtype=dtype) - input_h = constant_op.constant( - np.random.randn(num_layers * dir_count, batch_size, 
num_units), - dtype=dtype) - if has_input_c: - input_c = constant_op.constant( - np.random.randn(num_layers * dir_count, batch_size, num_units), - dtype=dtype) - outputs = model( - input_data=input_data, - input_h=input_h, - input_c=input_c, - params=params, - is_training=False) - else: - outputs = model( - input_data=input_data, - input_h=input_h, - params=params, - is_training=False) - total_sum = sum(map(math_ops.reduce_sum, outputs)) - # Passing graph explicitly, otherwise an old sess would be reused. - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - total_sum_v = sess.run(total_sum) - val = saver.save(sess, save_path) - self.assertEqual(save_path, val) - # Passing graph explicitly, otherwise an old sess would be reused. - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - reset_params = state_ops.assign( - params, - array_ops.zeros([params_size_t], dtype=dtype), - validate_shape=False) - sess.run(reset_params) - saver.restore(sess, save_path) - total_sum_v_restored = sess.run(total_sum) - self.assertAllClose(total_sum_v, total_sum_v_restored, atol=1e-5) + rtol, atol = 5e-3, 5e-4 + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + # h + self.assertAllClose( + state_tuple.h, cu_state_tuple.h, rtol=rtol, atol=atol) + # c + self.assertAllClose( + state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSaveRestore(self): - rnn_modes = [ - cudnn_rnn_ops.CUDNN_LSTM, cudnn_rnn_ops.CUDNN_GRU, - cudnn_rnn_ops.CUDNN_RNN_TANH, cudnn_rnn_ops.CUDNN_RNN_RELU - ] - directions = [ - cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, - cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION - ] - dtype_list = [dtypes.float32, dtypes.float64] - for rnn_mode, direction, dtype in itertools.product(rnn_modes, directions, - 
dtype_list): - self._testSaveRestoreVariable(rnn_mode, direction, dtype) - self._testSaveRestoreTwoVariables(rnn_mode, direction, dtype) - self._testSaveRestoreOutput(rnn_mode, direction, dtype) - - -class CudnnRNNTestParamsSize(TensorFlowTestCase): - - def _testOneLSTMParamsSize(self, num_layers, num_units, input_size, - direction): - logging.info("Testing one lstm param size with config: %s", locals()) - min_params_size = _MinLSTMParamSize(num_layers, num_units, input_size, - direction) - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - num_layers, + def test_inference_with_dropout(self, num_units, input_size, batch_size, time, + num_layers): + """Validates that dropout does not affect Cudnn Rnn inference.""" + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + # Hand-picked dropouts are used below (0. and 1.) + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + # 1st time w/o dropout. + (_, cu_outputs, _, cu_state_tuple) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=0.) + + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + (_, cu_outputs2, _, cu_state_tuple2) = RunLSTM( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=1.) + + self.assertAllClose(cu_outputs, cu_outputs2) + # h + self.assertAllClose(cu_state_tuple.h, cu_state_tuple2.h) + # c + self.assertAllClose(cu_state_tuple.c, cu_state_tuple2.c) + + +def RunGRU(sess, + num_units, + input_size, + batch_size, + time, + num_layers=1, + is_training=True, + dropout=0., + num_dirs=True, + dtype=dtypes.float32): + # TODO(jamesqin): add multi-layer tests. + # TODO(jamesqin): add multi-dir tests + assert num_layers == 1 + assert num_dirs == 1 + if is_training and not np.isclose(dropout, 0): + raise ValueError("dropout can not be 0. 
when test training.") + + # set graph level random seed and numpy random seed. + random_seed.set_random_seed(0) + np.random.seed(0) + + inputs = variable_scope.get_variable( + "inputs", + initializer=np.random.rand(time, batch_size, + input_size).astype(dtype.as_numpy_dtype), + dtype=dtype) + initial_h_op = variable_scope.get_variable( + "initial_h_op", + initializer=np.random.rand(batch_size, + num_units).astype(dtype.as_numpy_dtype), + dtype=dtype) + + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, dtype=dtype, seed=19980904) + with variable_scope.variable_scope("test", initializer=initializer): + gate_kernel = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/gates/kernel", + shape=[input_size + num_units, num_units * 2], + dtype=dtype) + gate_bias = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/gates/bias", + shape=[num_units * 2], + dtype=dtype) + candidate_inp_kernel = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/input_projection/kernel", + shape=[input_size, num_units], + dtype=dtype) + candidate_inp_bias = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/input_projection/bias", + shape=[num_units], + dtype=dtype) + candidate_hid_kernel = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/hidden_projection/kernel", + shape=[num_units, num_units], + dtype=dtype) + candidate_hid_bias = variable_scope.get_variable( + "rnn/cudnn_compatible_gru_cell/candidate/hidden_projection/bias", + shape=[num_units], + dtype=dtype) + + cell = cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units, reuse=True) + outputs_op, h_op = rnn.dynamic_rnn( + cell, + inputs, + initial_state=initial_h_op, + dtype=dtype, + time_major=True, + scope=None) + + ws = [gate_kernel, candidate_inp_kernel, candidate_hid_kernel] + bs = [gate_bias, candidate_inp_bias, candidate_hid_bias] + # Convert to cudnn opaque param. 
+ format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterGRU( + num_layers, num_units, input_size) + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + + cu_initial_h_op = array_ops.expand_dims(initial_h_op, axis=0) + cu_outputs_op, cu_h_op, _ = cudnn_rnn_ops._cudnn_rnn( + inputs, + cu_initial_h_op, + array_ops.zeros_like(cu_initial_h_op), # not used + opaque_params, + dropout=dropout, + is_training=is_training, + rnn_mode=cudnn_rnn_ops.CUDNN_GRU) + + if is_training: + (inp_grad_op, hgrad_op, gk_grad_op, cik_grad_op, chk_grad_op, gb_grad_op, + cib_grad_op, chb_grad_op) = gradients_impl.gradients( + outputs_op, [inputs, initial_h_op] + ws + bs) + + (cu_inp_grad_op, cu_hgrad_op, opaque_grad_op) = gradients_impl.gradients( + cu_outputs_op, [inputs, cu_initial_h_op, opaque_params]) + # Remove the trivial 1st dimension + cu_hgrad_op = array_ops.squeeze(cu_hgrad_op, axis=0) + + cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical( + opaque_grad_op) + (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op) = cu_wgrad_op + (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) = cu_bgrad_op + # cudnn gru has 2 biases for reset and update gates. When converting to tf + # canonical format, the two biases are summed into one. Thus here relevant + # bias gradient should be halved before comparing with tf gru. 
+ cu_gb_grad_op *= 0.5 + + init_op = variables.global_variables_initializer() + sess.run(init_op) + + if is_training: + outputs, h, inp_grad, hgrad, wgrad, bgrad = sess.run([ + outputs_op, h_op, inp_grad_op, hgrad_op, + (gk_grad_op, cik_grad_op, chk_grad_op), + (gb_grad_op, cib_grad_op, chb_grad_op) + ]) + (cu_outputs, cu_h, cu_inp_grad, cu_hgrad, cu_wgrad, cu_bgrad) = sess.run([ + cu_outputs_op, cu_h_op, cu_inp_grad_op, cu_hgrad_op, + (cu_gk_grad_op, cu_cik_grad_op, cu_chk_grad_op), + (cu_gb_grad_op, cu_cib_grad_op, cu_chb_grad_op) + ]) + # Remove the trivial 1st dimension + cu_h = np.squeeze(cu_h, axis=0) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "h: %s" % h) + logging.vlog(1, "cu_h: %s" % h) + logging.vlog(1, "inp_grad: %s" % inp_grad) + logging.vlog(1, "cu_inp_grad: %s" % cu_inp_grad) + logging.vlog(1, "hgrad: %s" % hgrad) + logging.vlog(1, "cu_hgrad: %s" % cu_hgrad) + logging.vlog(1, "wgrad: %s" % str(wgrad)) + logging.vlog(1, "bgrad: %s" % str(bgrad)) + logging.vlog(1, "cu_wgrad: %s" % str(cu_wgrad)) + logging.vlog(1, "cu_bgrad: %s" % str(cu_bgrad)) + return (outputs, cu_outputs, h, cu_h, inp_grad, cu_inp_grad, hgrad, + cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad) + else: + outputs, h = sess.run([outputs_op, h_op]) + cu_outputs, cu_h = sess.run([cu_outputs_op, cu_h_op]) + # Remove the trivial 1st dimension. 
+ cu_h = np.squeeze(cu_h, axis=0) + + logging.vlog(1, "outputs: %s" % outputs) + logging.vlog(1, "cu_outputs: %s" % cu_outputs) + logging.vlog(1, "h: %s" % h) + logging.vlog(1, "cu_h: %s" % h) + return outputs, cu_outputs, h, cu_h + + +class CudnnGRUTest(TensorFlowTestCase, parameterized.TestCase): + + def _test_training_helper(self, + num_units, + input_size, + batch_size, + time, + num_layers, + dtype, + rtol=2e-6, + atol=2e-6): + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, h, cu_h, inp_grad, cu_inp_grad, hgrad, + cu_hgrad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunGRU( + sess, num_units, input_size, batch_size, time, num_layers) + + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) + self.assertAllClose(hgrad, cu_hgrad, rtol=rtol, atol=atol) + self.assertAllClose(inp_grad, cu_inp_grad, rtol=rtol, atol=atol) + for bg, cu_bg in zip(bgrad, cu_bgrad): + self.assertAllClose(bg, cu_bg, rtol=rtol, atol=atol) + for wg, cu_wg in zip(wgrad, cu_wgrad): + self.assertAllClose(wg, cu_wg, rtol=rtol, atol=atol) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_training_helper(num_units, input_size, batch_size, time, + num_layers, dtypes.float32) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_training_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_training_helper( num_units, input_size, - direction=direction) - params_size = model.params_size() - with self.test_session(use_gpu=True, 
graph=ops.get_default_graph()) as sess: - params_size_v = sess.run(params_size) - self.assertLessEqual(min_params_size, params_size_v) + batch_size, + time, + num_layers, + dtypes.float16, + rtol=5e-3, + atol=5e-4) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testLSTMParamsSize(self): - test_configs = [ - [4, 200, 200], - [4, 200, 300], - [4, 200, 100], - [1, 100, 200], - [2, 200, 100], - [3, 200, 400], - ] - directions = [ - cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION, - cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION - ] - for (config, direction) in itertools.product(test_configs, directions): - num_layers, num_units, input_size = config - with ops.Graph().as_default(): - self._testOneLSTMParamsSize(num_layers, num_units, input_size, - direction) + def test_inference(self, num_units, input_size, batch_size, time, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, h, cu_h) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False) + self.assertAllClose(outputs, cu_outputs) + self.assertAllClose(h, cu_h) + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testLSTMParamsSizeShape(self): - with self.assertRaisesRegexp( - ValueError, "Shape must be rank 0 but is rank 1"): - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - constant_op.constant([4]), 200, 200, - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() - with self.assertRaisesRegexp( - ValueError, "Shape must be rank 0 but is rank 1"): - model = _CreateModel( - cudnn_rnn_ops.CUDNN_LSTM, - 4, constant_op.constant([200]), 200, - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() - with self.assertRaisesRegexp( - 
ValueError, "Shape must be rank 0 but is rank 1"): - model = _CreateModel( + def test_inference_fp16(self, num_units, input_size, batch_size, time, + num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + (outputs, cu_outputs, h, cu_h) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dtype=dtypes.float16) + + rtol, atol = 5e-3, 5e-4 + self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol) + self.assertAllClose(h, cu_h, rtol=rtol, atol=atol) + + @parameterized.named_parameters(*NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_inference_with_dropout(self, num_units, input_size, batch_size, time, + num_layers): + """Validates that dropout does not affect Cudnn Rnn inference.""" + # Hand-picked dropouts are used below (0. and 1.) + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + # 1st time w/o dropout. + (_, cu_outputs, _, cu_h) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=0.) + + with ops.Graph().as_default() as g: + with self.session(use_gpu=True, graph=g) as sess: + (_, cu_outputs2, _, cu_h2) = RunGRU( + sess, + num_units, + input_size, + batch_size, + time, + num_layers, + is_training=False, + dropout=1.) 
+ + self.assertAllClose(cu_outputs, cu_outputs2) + self.assertAllClose(cu_h[0], cu_h2[0]) + + +class CudnnParamsFormatConverterTest(TensorFlowTestCase, + parameterized.TestCase): + """Class for testing various format converters.""" + + def _test_lstm_helper(self, num_units, input_size, num_layers, direction): + with self.session(use_gpu=True) as sess: + random_seed.set_random_seed(0) + np.random.seed(0) + + num_dirs = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2 + format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM( + num_layers, num_units, input_size, direction=direction) + + ws, bs = [], [] + for _ in range(num_layers * num_dirs): + w = constant_op.constant( + np.random.rand(input_size + num_units, 4 * num_units), + dtype=dtypes.float32) + b = constant_op.constant( + np.random.rand(4 * num_units), dtype=dtypes.float32) + ws.append(w) + bs.append(b) + + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + opaque_params_size = cudnn_rnn_ops.cudnn_rnn_opaque_params_size( cudnn_rnn_ops.CUDNN_LSTM, - 4, 200, constant_op.constant([200]), - direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) - params_size = model.params_size() + num_layers, + num_units, + input_size, + direction=direction) + ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params) -class CudnnRNNTestInference(TensorFlowTestCase): + # Test tf_canonical_to_opaque() followed by opaque_to_tf_canonical() + # returns the original input. 
+ ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r]) + for w, w_r in zip(ws, ws_r): + self.assertAllClose(w, w_r) + for b, b_r in zip(bs, bs_r): + self.assertAllClose(b, b_r) - def _testOneSimpleInference(self, rnn_mode, num_layers, num_units, input_size, - batch_size, seq_length, dir_count, dropout, - expected, tolerance): - random_seed.set_random_seed(5678) - model = _CreateModel( - rnn_mode, - num_layers, - num_units, - input_size, - input_mode="auto_select", - direction=(cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION if dir_count == 1 - else cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION), - dropout=dropout) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - params_size_t = model.params_size() - input_data = array_ops.ones([seq_length, batch_size, input_size]) - input_h = array_ops.ones([num_layers * dir_count, batch_size, num_units]) - params = variables.Variable( - array_ops.ones([params_size_t]), validate_shape=False) - if has_input_c: - input_c = array_ops.ones([num_layers * dir_count, batch_size, num_units]) - output, output_h, output_c = model( - input_data=input_data, - input_h=input_h, - input_c=input_c, - params=params, - is_training=False) - else: - output, output_h = model( - input_data=input_data, - input_h=input_h, - params=params, - is_training=False) - output_sum = math_ops.reduce_sum(output) - output_h_sum = math_ops.reduce_sum(output_h) - total_sum = output_sum + output_h_sum - if has_input_c: - output_c_sum = math_ops.reduce_sum(output_c) - total_sum += output_c_sum - with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - total_sum_v = sess.run([total_sum]) + # Test opaque_params size lower bound + opaque_params_size_v = sess.run(opaque_params_size) + min_params_size = sum(x.size for x in ws) + np.sum(x.size for x in bs) + logging.info("min_parm_size: %d vs actual_opaque_param_size: %d", + min_params_size, opaque_params_size_v) + self.assertLessEqual(min_params_size, 
opaque_params_size_v) - self.assertAllClose( - total_sum_v[0], expected, atol=tolerance, rtol=tolerance) + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_lstm(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_lstm_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleInference(self): - test_configs = [ - { - "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM, - "expected": 231833.22, - "tolerance": 1e-2, - "shape": { - "num_layers": 4, - "num_units": 200, - "input_size": 200, - "batch_size": 20, - "seq_length": 10, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_GRU, - "expected": 56000, - "tolerance": 1e-2, - "shape": { - "num_layers": 4, - "num_units": 200, - "input_size": 200, - "batch_size": 20, - "seq_length": 10, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH, - "expected": 56000, - "tolerance": 1e-2, - "shape": { - "num_layers": 4, - "num_units": 200, - "input_size": 200, - "batch_size": 20, - "seq_length": 10, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU, - "expected": 130688, - "tolerance": 1e-2, - "shape": { - "num_layers": 2, - "num_units": 8, - "input_size": 4, - "batch_size": 4, - "seq_length": 2, - "dir_count": 1, - }, - }, - ] - # Cudnn scales result for dropout during training, therefore dropout has no - # impact for inference results. - # (lstm, gru, rnn_tanh are saturated in the test. 
rnn_relu case is most - # demonstrative of the dropout-invariant nature of CudnnRnn.) - dropouts = [0., 0.5, 1.] - for (config, dropout) in itertools.product(test_configs, dropouts): - rnn_mode = config["rnn_mode"] - expected = config["expected"] - tolerance = config["tolerance"] - shape = config["shape"] - with ops.Graph().as_default(): - self._testOneSimpleInference( - rnn_mode, shape["num_layers"], shape["num_units"], - shape["input_size"], shape["batch_size"], shape["seq_length"], - shape["dir_count"], dropout, expected, tolerance) - - -class CudnnRNNTestTraining(TensorFlowTestCase): - - def _testOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size, - batch_size, seq_length, dir_count, dropout, dtype, - delta, tolerance): - # Gradient checking runs two forward ops with almost the same input. Need to - # make sure the drop patterns across the two runs are the same. - logging.info("Training test with config: %s", locals()) - old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False)) - os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - random_seed.set_random_seed(5678) - direction = (cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION if dir_count == 1 - else cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) - model = _CreateModel( - rnn_mode, - num_layers, - num_units, - input_size, - direction=direction, - dtype=dtype, - dropout=dropout) - params_size_t = model.params_size() - input_data = variables.Variable( - random_ops.random_uniform( - [seq_length, batch_size, input_size], dtype=dtype), - dtype=dtype) - input_h = variables.Variable( - random_ops.random_uniform( - [num_layers * dir_count, batch_size, num_units], dtype=dtype), - dtype=dtype) - params = variables.Variable( - random_ops.random_uniform([params_size_t], dtype=dtype), - validate_shape=False, - dtype=dtype) - if has_input_c: - input_c = variables.Variable( - random_ops.random_uniform( - [num_layers * dir_count, batch_size, num_units], 
dtype=dtype), - dtype=dtype) - - output, output_h, output_c = model( - input_data=input_data, - input_h=input_h, - input_c=input_c, - params=params) - else: - output, output_h = model( - input_data=input_data, input_h=input_h, params=params) - output_sum = math_ops.reduce_sum(output) - output_h_sum = math_ops.reduce_sum(output_h) - total_sum = output_sum + output_h_sum - if has_input_c: - output_c_sum = math_ops.reduce_sum(output_c) - total_sum += output_c_sum - - with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: - params_size_v = sess.run(params_size_t) - inputs_and_shapes = [ - (input_data, [seq_length, batch_size, input_size]), - (input_h, [num_layers * dir_count, batch_size, num_units]), - (params, [params_size_v]), - ] - if has_input_c: - inputs_and_shapes.append( - (input_c, [num_layers * dir_count, batch_size, num_units]),) - sess.run(variables.global_variables_initializer()) - all_inputs = [entry[0] for entry in inputs_and_shapes] - all_shapes = [entry[1] for entry in inputs_and_shapes] - - err = gradient_checker.compute_gradient_error( - all_inputs, all_shapes, total_sum, [1], delta=delta) - - self.assertLess(err, tolerance) - os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state + def test_lstm_bidi(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_lstm_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) + + def _test_gru_helper(self, num_units, input_size, num_layers, direction): + with self.session(use_gpu=True) as sess: + random_seed.set_random_seed(0) + np.random.seed(0) + + num_dirs = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2 + format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterGRU( + num_layers, num_units, input_size, direction=direction) + ws, bs = [], [] + for _ in range(num_layers * num_dirs): + gate_kernel = constant_op.constant( + np.random.rand(input_size + num_units, num_units * 
2), + dtype=dtypes.float32) + gate_bias = constant_op.constant( + np.random.rand(num_units * 2), dtype=dtypes.float32) + candidate_inp_kernel = constant_op.constant( + np.random.rand(input_size, num_units), dtype=dtypes.float32) + candidate_inp_bias = constant_op.constant( + np.random.rand(num_units), dtype=dtypes.float32) + candidate_hid_kernel = constant_op.constant( + np.random.rand(num_units, num_units), dtype=dtypes.float32) + candidate_hid_bias = constant_op.constant( + np.random.rand(num_units), dtype=dtypes.float32) + ws.extend([gate_kernel, candidate_inp_kernel, candidate_hid_kernel]) + bs.extend([gate_bias, candidate_inp_bias, candidate_hid_bias]) + + opaque_params = format_converter.tf_canonical_to_opaque(ws + bs) + opaque_params_size = cudnn_rnn_ops.cudnn_rnn_opaque_params_size( + cudnn_rnn_ops.CUDNN_GRU, + num_layers, + num_units, + input_size, + direction=direction) + + ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params) + + # Test tf_canonical_to_opaque() followed by opaque_to_tf_canonical() + # returns the original input. 
+ ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r]) + for w, w_r in zip(ws, ws_r): + self.assertAllClose(w, w_r) + for b, b_r in zip(bs, bs_r): + self.assertAllClose(b, b_r) + + # Test opaque_params size lower bound + opaque_params_size_v = sess.run(opaque_params_size) + min_params_size = sum(x.size for x in ws) + sum(x.size for x in bs) + logging.info("min_parm_size: %d vs actual_opaque_param_size: %d", + min_params_size, opaque_params_size_v) + self.assertLessEqual(min_params_size, opaque_params_size_v) + + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_gru(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_gru_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION) + + @parameterized.named_parameters((c["testcase_name"], c["num_units"], + c["input_size"], c["num_layers"]) + for c in NAMED_RNN_TESTCASES) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleTraining(self): - test_configs = [ - { - "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_GRU, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - "dir_count": 1, - }, - }, 
- { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU, - "dtype": dtypes.float64, - "delta": 1e-4, - "tolerance": 5e-6, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - "dir_count": 1, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_LSTM, - "dtype": dtypes.float32, - "tolerance": 1.5e-2, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_GRU, - "dtype": dtypes.float32, - "tolerance": 4e-3, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_TANH, - "dtype": dtypes.float32, - "tolerance": 5e-3, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - { - "rnn_mode": cudnn_rnn_ops.CUDNN_RNN_RELU, - "dtype": dtypes.float32, - "tolerance": 5e-1, - "shape": { - "num_layers": 2, - "num_units": 3, - "input_size": 4, - "batch_size": 3, - "seq_length": 4, - }, - }, - ] - dropouts = [0., 0.5, 1.] 
- dir_counts = [1] - for config, dropout, dir_count in itertools.product(test_configs, dropouts, - dir_counts): - rnn_mode = config["rnn_mode"] - dtype = config.get("dtype", dtypes.float32) - delta = config.get("delta", 1e-3) - tolerance = config["tolerance"] - shape = config["shape"] - with ops.Graph().as_default(): - self._testOneSimpleTraining(rnn_mode, shape["num_layers"], - shape["num_units"], shape["input_size"], - shape["batch_size"], shape["seq_length"], - dir_count, dropout, dtype, delta, tolerance) + def test_gru_bidi(self, num_units, input_size, num_layers): + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + self._test_gru_helper(num_units, input_size, num_layers, + cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION) + + +class CudnnRnnSaveRestoreTest(TensorFlowTestCase, parameterized.TestCase): + """Class for testing various Cudnn Rnn SaveableObjects.""" + + def _create_opaque_param(self, + rnn_mode, + num_units, + input_size, + num_layers, + direction, + name=None): + param_size_t = cudnn_rnn_ops.cudnn_rnn_opaque_params_size( + rnn_mode, num_layers, num_units, input_size, direction=direction) + init_val = random_ops.random_uniform([param_size_t]) + return variable_scope.get_variable( + name or "opaque_param", initializer=init_val, validate_shape=False) + + def _create_saveable(self, opaque_param, rnn_mode, num_units, input_size, + num_layers, direction): + if rnn_mode == CUDNN_LSTM: + fn = cudnn_rnn_ops.CudnnLSTMSaveable + elif rnn_mode == CUDNN_GRU: + fn = cudnn_rnn_ops.CudnnGRUSaveable + elif rnn_mode == CUDNN_RNN_TANH: + fn = cudnn_rnn_ops.CudnnRNNTanhSaveable + elif rnn_mode == CUDNN_RNN_RELU: + fn = cudnn_rnn_ops.CudnnRNNReluSaveable + saveable = fn( + opaque_param, num_layers, num_units, input_size, direction=direction) + return saveable + + def _compare_weights(self, lhs, rhs): + self.assertLen(rhs, len(lhs)) + for lw, rw in zip(lhs, rhs): + self.assertAllEqual(lw, rw) + + def _compare_biases(self, lhs, rhs): + self.assertLen(rhs, 
len(lhs)) + for lf, rt in zip(lhs, rhs): + self.assertAllEqual(lf, rt) + + @parameterized.named_parameters( + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, "time", "batch_size", **{ + "rnn_mode": [ + CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_RELU, CUDNN_RNN_TANH + ], + "direction": [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + })) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_save_restore_variable(self, rnn_mode, num_units, input_size, + num_layers, direction): + # Verify the restored opaque param, once converted to tf_canonical format, + # is the same as the tf canonicals of the pre-restored param. + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + opaque_param = self._create_opaque_param(rnn_mode, num_units, input_size, + num_layers, direction) + saveable = self._create_saveable(opaque_param, rnn_mode, num_units, + input_size, num_layers, direction) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + weights_op, biases_op = saveable.format_converter.opaque_to_tf_canonical( + saveable._variables) + + save_path = os.path.join(self.get_temp_dir(), "save_restore_var_test") + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + + init_op = variables.global_variables_initializer() + reset_op = state_ops.assign(opaque_param, + array_ops.zeros_like(opaque_param)) + sess.run(init_op) + self.assertEqual(save_path, saver.save(sess, save_path)) + + # Get the tf canonical vals before reset-restore + weights, biases = sess.run([weights_op, biases_op]) + + # Reset the opaque param value + sess.run(reset_op) + # Assert reset happened. + weights_z, biases_z = sess.run([weights_op, biases_op]) + for w in weights_z: + self.assertAllClose(w, np.zeros_like(w)) + for b in biases_z: + self.assertAllClose(b, np.zeros_like(b)) + + # Restore opaque param value from checkpoint. 
+ saver.restore(sess, save_path) + weights_r, biases_r = sess.run([weights_op, biases_op]) + self._compare_weights(weights, weights_r) + self._compare_biases(biases, biases_r) + + @parameterized.named_parameters( + ExpandNamedTestCases( + NAMED_RNN_TESTCASES, "time", "batch_size", **{ + "rnn_mode": [ + CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_RELU, CUDNN_RNN_TANH + ], + "direction": [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + })) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def test_save_restore_multi_variables(self, rnn_mode, num_units, input_size, + num_layers, direction): + # Verify the restored opaque param, once converted to tf_canonical format, + # is the same as the tf canonicals of the pre-restored param. + if not context.context().num_gpus(): + self.skipTest("No GPUs found") + with self.session(use_gpu=True) as sess: + opaque_params = [] + saveables = [] + num_opaque_params = 2 + for i in range(num_opaque_params): + opaque_params.append( + self._create_opaque_param( + rnn_mode, + num_units, + input_size, + num_layers, + direction, + name="opaque_param_%d" % i)) + saveable = self._create_saveable(opaque_params[i], rnn_mode, num_units, + input_size, num_layers, direction) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saveables.append(saveable) + + weights_ops, biases_ops = [], [] + for i in range(num_opaque_params): + weights_op, biases_op = ( + saveables[i].format_converter.opaque_to_tf_canonical( + saveables[i]._variables)) + weights_ops.append(weights_op) + biases_ops.append(biases_op) + + save_path = os.path.join(self.get_temp_dir(), "save_restore_var_test") + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + + init_op = variables.global_variables_initializer() + reset_ops = [] + for i in range(num_opaque_params): + reset_ops.append( + state_ops.assign(opaque_params[i], + array_ops.zeros_like(opaque_params[i]))) + sess.run(init_op) + self.assertEqual(save_path, 
saver.save(sess, save_path)) + + # Get the tf canonical vals before reset-restore + for i in range(num_opaque_params): + weights, biases = sess.run([weights_ops[i], biases_ops[i]]) + + # Reset the opaque param value + sess.run(reset_ops[i]) + + # Assert reset happened. + weights_z, biases_z = sess.run([weights_ops[i], biases_ops[i]]) + for w in weights_z: + self.assertAllClose(w, np.zeros_like(w)) + for b in biases_z: + self.assertAllClose(b, np.zeros_like(b)) + + # Restore opaque param value from checkpoint. + saver.restore(sess, save_path) + weights_r, biases_r = sess.run([weights_ops[i], biases_ops[i]]) + self._compare_weights(weights, weights_r) + self._compare_biases(biases, biases_r) if __name__ == "__main__": diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 57793a8ff5e2ec49dfea42c08eb9456cb2875eab..7e1b4062ce435f3ab4216e90b4f5fcbab984c1dc 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -536,7 +536,9 @@ class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase): save_path = os.path.join(self.get_temp_dir(), "save-restore-variable-test") saver = saver_lib.Saver() - weights, biases = model.rnn.saveable._OpaqueParamsToCanonical() + weights, biases = ( + model.rnn.saveable.format_converter._opaque_to_cu_canonical( + model.rnn.saveable._variables)) opaque_params = rnn.trainable_variables[0] # CudnnTestModel() creates CudnnOpaqueParamsSaveable that helps saver save # Cudnn vars in canonical format. 
@@ -583,8 +585,12 @@ class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase): dtype=dtype) opaque_params = (model1.rnn.trainable_variables[0], model2.rnn.trainable_variables[0]) - weights1, biases1 = model1.rnn.saveable._OpaqueParamsToCanonical() - weights2, biases2 = model2.rnn.saveable._OpaqueParamsToCanonical() + saveable1 = model1.rnn.saveable + weights1, biases1 = saveable1.format_converter._opaque_to_cu_canonical( + saveable1._variables) + saveable2 = model1.rnn.saveable + weights2, biases2 = saveable2.format_converter._opaque_to_cu_canonical( + saveable2._variables) reset_params = [ state_ops.assign(params, array_ops.zeros_like(params, dtype=dtype)) @@ -1039,8 +1045,8 @@ class CudnnRNNTestParamsSize(test_util.TensorFlowTestCase): # Min param size estimate = sum(weights.size) + sum(biases.size) min_params_size = ( - np.sum(map(np.prod, rnn.canonical_weight_shapes)) + - np.sum([sp[0] for sp in rnn.canonical_bias_shapes])) + sum(map(np.prod, rnn.canonical_weight_shapes)) + + sum(sp[0] for sp in rnn.canonical_bias_shapes)) opaque_params = rnn.trainable_variables[0] with self.test_session(use_gpu=True, graph=ops.get_default_graph()): @@ -1184,7 +1190,8 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase): num_grads = [self._ComputeNumericGrad(sess, y, x, delta) for x in xs] self.assertEqual(len(sym_grads), len(num_grads)) - for sym, num in zip(sym_grads, num_grads): + for x, sym, num in zip(xs, sym_grads, num_grads): + logging.info("Comparing gradients for input: %s", x.name) self.assertFalse(np.any(np.isnan(sym))) self.assertFalse(np.any(np.isnan(num))) self.assertAllClose(sym, num, atol=tolerance, rtol=tolerance) @@ -1225,18 +1232,18 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase): params = rnn.trainable_variables[0] inputs = variables.Variable( - random_ops.random_uniform( - [seq_length, batch_size, input_size], dtype=dtype), - dtype=dtype) + random_ops.random_uniform([seq_length, batch_size, input_size], + dtype=dtype), + 
dtype=dtype).read_value() input_h = variables.Variable( random_ops.random_uniform( [num_layers * dir_count, batch_size, num_units], dtype=dtype), - dtype=dtype) + dtype=dtype).read_value() if has_input_c: input_c = variables.Variable( random_ops.random_uniform( [num_layers * dir_count, batch_size, num_units], dtype=dtype), - dtype=dtype) + dtype=dtype).read_value() initial_state = (input_h, input_c) else: initial_state = (input_h,) @@ -1262,7 +1269,7 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase): def _TestSimpleTrainingHelper(self, rnn_mode, test_configs): dropouts = [0, 0.5, 1.] - v2_options = [str(False), str(True)] + v2_options = [False, True] for config, dropout, use_v2 in itertools.product(test_configs, dropouts, v2_options): dtype = config.get("dtype", dtypes.float32) @@ -1270,6 +1277,9 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase): tolerance = config.get("tolerance", 1e-6) dir_count = config.get("dir_count", 1) shape = config["shape"] + if dtype == dtypes.float64: + # TODO(jamesqin): b/117848763 + use_v2 = False with ops.Graph().as_default(): self._TestOneSimpleTraining( rnn_mode, shape["num_layers"], shape["num_units"], @@ -1519,7 +1529,7 @@ if __name__ == "__main__": parser.add_argument( "--grad_check_num_samples", type=int, - default=5, + default=1, help="Number of samples to run for gradient check.") FLAGS, unparsed = parser.parse_known_args() sys.argv = [argv0] + unparsed diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py index f09466b631f69d6234573dd5eafada650421c117..60229af374be869005139921483793156e5e7a05 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py @@ -27,5 +27,10 @@ from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibl from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell from 
tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTMSaveable +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterGRU +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterLSTM +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterRelu +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnParamsFormatConverterTanh from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable + diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index e26d56c8579e110d61c73c6154b82f47f0093687..8e25637ed91a1559b321ea96efbfaa2910f67158 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -21,6 +21,7 @@ from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.engine import input_spec from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -322,7 +323,7 @@ class _CudnnRNN(base_layer.Layer): raise ValueError("The last dimension of the inputs to `CudnnRNN` " "should be defined. 
Found `None`.") self._input_size = input_shape[-1].value - self.input_spec = base_layer.InputSpec(ndim=3, axes={-1: self._input_size}) + self.input_spec = input_spec.InputSpec(ndim=3, axes={-1: self._input_size}) self._set_scope(None) @@ -356,7 +357,8 @@ class _CudnnRNN(base_layer.Layer): "Partitioner is not supported for Cudnn RNN layer variables, using " "it will create forward-compatibility issues with future " "CUDA/CuDNN generations.") - # Initialize opaque params with a tensor. + # Initialize opaque params with a tensor with unknown shape, thus couldn't + # use self.add_variable(name, shape, initializer, ...) self.kernel = vs.get_variable( "opaque_kernel", dtype=self._plain_dtype, initializer=opaque_params_t, validate_shape=False) @@ -387,11 +389,11 @@ class _CudnnRNN(base_layer.Layer): output_states: a tuple of tensor(s) of the same shape and structure as `initial_state`. Raises: - ValueError: initial_state is not a tuple. + TypeError: initial_state is not a tuple. """ if initial_state is not None and not isinstance(initial_state, tuple): - raise ValueError("Invalid initial_state type: %s, expecting tuple.", - type(initial_state)) + raise TypeError("Invalid initial_state type: %s, expecting tuple." % + initial_state) dtype = self.dtype inputs = ops.convert_to_tensor(inputs, dtype=dtype) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 2c92f31788378c2a9f01183bc04b035668b59b59..1ce29b42d52ff67477161278ed11016c2e73041d 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -74,7 +74,7 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): - """Cudnn Compatible GRUCell. + r"""Cudnn Compatible GRUCell. A GRU impl akin to `tf.nn.rnn_cell.GRUCell` to use along with `tf.contrib.cudnn_rnn.CudnnGRU`. 
The latter's params can be used by @@ -177,172 +177,60 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): return new_h, new_h -# TODO(yaozhang): make sure we only save the canonical version of params and -# don't save the platform-specific version to avoid potential race -# conditions where params is updated by both versions when being restored. -# Currently, checkpointing will function properly, despite that we save both -# versions, because Saver restores customized savables after Variables. -# However, it is good to not rely on this restoring order of Saver and to -# avoid unnecessary storage. Add a test to check only the canonical version is -# saved. -class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): - """Abstract SaveableObject implementation handling Cudnn opaque params.""" +class CudnnParamsFormatConverter(object): + """Abstract class that converts between params of Cudnn Rnn and TF Rnn.""" def __init__(self, - opaque_params, num_layers, num_units, input_size, input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - scope=None, - name="cudnn_rnn_saveable"): - """Creates a CudnnOpaqueParamsSaveable object. - - CudnnOpaqueParamsSaveable is saveable/restorable in a checkpoint file - and is used to save/restore the weights and biases parameters in a - canonical format which is directly consumable by platform-independent tf - RNN cells. Parameters are saved as tensors layer by layer with weight - tensors followed by bias tensors, and forward direction followed by - backward direction (if applicable). When restoring, a user could name - param_variables as desired, and restore weight and bias tensors to these - variables. - - For CudnnRNNRelu or CudnnRNNTanh, there are 2 tensors per weight and per - bias for each layer: tensor 0 is applied to the input from the previous - layer and tensor 1 to the recurrent input. 
- - For CudnnLSTM, there are 8 tensors per weight and per bias for each - layer: tensor 0-3 are applied to the input from the previous layer and - tensor 4-7 to the recurrent input. Tensor 0 and 4 are for the input gate; - tensor 1 and 5 the forget gate; tensor 2 and 6 the new memory gate; - tensor 3 and 7 the output gate. - - For CudnnGRU, there are 6 tensors per weight and per bias for each layer: - tensor 0-2 are applied to the input from the previous layer and - tensor 3-5 to the recurrent input. Tensor 0 and 3 are for the reset gate; - tensor 1 and 4 the update gate; tensor 2 and 5 the new memory gate. + direction=CUDNN_RNN_UNIDIRECTION): + """Constructor. Args: - opaque_params: a variable, Cudnn RNN opaque params. num_layers: the number of layers for the RNN model. num_units: the number of units within the RNN model. input_size: the size of the input, it could be different from the - num_units. + num_units. input_mode: indicate whether there is a linear projection between the - input and the actual computation before the first layer. It could be - 'linear_input', 'skip_input' or 'auto_select'. - 'linear_input' (default) always applies a linear projection of input - onto RNN hidden state. (standard RNN behavior). - 'skip_input' is only allowed when input_size == num_units; - 'auto_select' implies 'skip_input' when input_size == num_units; - otherwise, it implies 'linear_input'. + input and the actual computation before the first layer. It could be one + of 'linear_input', 'skip_input' or 'auto_select'. * 'linear_input' + (default) always applies a linear projection of input onto RNN hidden + state. (standard RNN behavior). * 'skip_input' is only allowed when + input_size == num_units; * 'auto_select' implies 'skip_input' when + input_size == num_units; otherwise, it implies 'linear_input'. direction: the direction model that the model operates. 
Could be either - 'unidirectional' or 'bidirectional' - scope: string of VariableScope, the scope of equivalent subgraph - consisting only platform-independent tf RNN cells. - name: the name of the CudnnOpaqueParamsSaveable object. + 'unidirectional' or 'bidirectional' """ - # Define in subclasses. self._num_layers = num_layers self._input_size = input_size self._num_units = num_units self._input_mode = input_mode self._direction = direction - if scope is not None: - scope_name = scope.name if isinstance(scope, vs.VariableScope) else scope - self._scope = scope_name or None - else: - self._scope = None - - self._variables = opaque_params self._num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 self._num_params = ( self._num_params_per_layer * self._num_layers * self._num_dirs) - weights, biases = self._OpaqueParamsToCanonical() - (weights, weight_names), (biases, bias_names) = self._TransformCanonical( - weights, biases) - # We currently don't use slice_spec. It might be useful in a distributed - # setting where each parameter server node stores a slice of variable, - # instead of having the master pull all slices and then save them. 
- slice_spec = "" - params = weights + biases - self._weight_names = weight_names - self._bias_names = bias_names - self._param_names = weight_names + bias_names - prefixed_param_names = weight_names + bias_names - if self._scope: - prefixed_param_names = [ - "%s/%s" % (self._scope, pn) for pn in prefixed_param_names] - specs = [ - saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name) - for param, param_name in zip(params, prefixed_param_names) - ] - super(CudnnOpaqueParamsSaveable, self).__init__( - array_ops.identity(self._variables), specs, name) - - def restore(self, restored_tensors, restored_shapes): - weights, biases = self._ReverseTransformCanonical(restored_tensors) - weights = [array_ops.reshape(w, [-1]) for w in weights] - opaque_params = self._CanonicalToOpaqueParams(weights, biases) - - return state_ops.assign( - self._variables, opaque_params, validate_shape=False) + def tf_canonical_to_opaque(self, tf_canonicals): + r"""Converts tf canonical weights to cudnn opaque param.""" + cu_weights, cu_biases = self._tf_canonical_to_cu_canonical(tf_canonicals) + cu_weights = [array_ops.reshape(w, [-1]) for w in cu_weights] + opaque_params = self._cu_canonical_to_opaque(cu_weights, cu_biases) + return opaque_params - def _checkpointable_save(self, save_buffer): - weights, biases = self._OpaqueParamsToCanonical() - with ops.device("gpu:0"): - (weights, _), (biases, _) = self._TransformCanonical( - weights, biases) - for name, tensor in zip(self._param_names, weights + biases): - save_buffer[name] = array_ops.identity(tensor) + def opaque_to_tf_canonical(self, opaque_param): + r"""Converts cudnn opaque param to tf canonical weights.""" + cu_weights, cu_biases = self._opaque_to_cu_canonical(opaque_param) + weights, biases = self._cu_canonical_to_tf_canonical(cu_weights, cu_biases) + return weights, biases - def _checkpointable_restore(self, restore_buffer): - tensors = [array_ops.identity(restore_buffer[name]) - for name in self._param_names] - return 
self.restore( - restored_tensors=tensors, - restored_shapes=None # Unused - ) - - def _add_checkpointable_dependencies(self, checkpointable, dtype): - """Add canonical weight dependencies to `checkpointable`. - - When saving or restoring, converts to or from the opaque buffer - format. Weights are saved and loaded in the configuration expected by - cuDNN-compatible cells. - - Args: - checkpointable: An object inheriting from `CheckpointableBase` to add - dependencies too (typically the cuDNN `Layer`). - dtype: The dtype for the canonical parameter Tensors. - """ - split_dependencies = split_dependency.split_dependency( - component_names=self._param_names, - component_dtypes=(dtype,) * len(self._param_names), - fill_save_buffer_fn=self._checkpointable_save, - consume_restore_buffer_fn=self._checkpointable_restore) - self._checkpointable_track_params(checkpointable, split_dependencies) - - def _checkpointable_track_params(self, checkpointable, params): - """Tracks parameters in a canonical configuration.""" - return # NotImplementedError raised by the Layer. - - def _TFCanonicalNamePrefix(self, layer, is_fwd=True): - if self._direction == CUDNN_RNN_UNIDIRECTION: - return "rnn/multi_rnn_cell/cell_%d/%s" % (layer, self._rnn_cell_name) - else: - if is_fwd: - return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/fw/%s" % - (layer, self._rnn_cell_name)) - else: - return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/bw/%s" % - (layer, self._rnn_cell_name)) - - def _OpaqueParamsToCanonical(self): + def _opaque_to_cu_canonical(self, opaque_param): """Converts opaque params to Cudnn canonical format. + Args: + opaque_param: An opaque tensor storing cudnn rnn params (weights and + biases). Returns: 2 list for weights and biases respectively. 
""" @@ -351,14 +239,14 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): num_layers=self._num_layers, num_units=self._num_units, input_size=self._input_size, - params=self._variables, + params=opaque_param, num_params=self._num_params, rnn_mode=self._rnn_mode, input_mode=self._input_mode, direction=self._direction) return (weights, biases) - def _CanonicalToOpaqueParams(self, cu_weights, cu_biases): + def _cu_canonical_to_opaque(self, cu_weights, cu_biases): """Converts from Cudnn canonical format to opaque params. Args: @@ -378,7 +266,7 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): input_mode=self._input_mode, direction=self._direction) - def _TransformCanonical(self, cu_weights, cu_biases): + def _cu_canonical_to_tf_canonical(self, cu_weights, cu_biases): r"""Transform from Cudnn canonical to tf canonical. The elements of argument lists are laid out in the following format: @@ -398,46 +286,43 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): cu_weights: a list of tensors of Cudnn canonical weights. cu_biases: a list of tensors of Cudnn canonical biases. Returns: - 2 tuples, one for weights and the other for bias. - Each tuple has two lists: the 1st for transformed tf canonical tensors - and the 2nd for the names of the tensors under which they are saved. + 1 tuple, tf canonical weights and biases. 
""" tf_weights, tf_biases = [], [] - tf_weights_names, tf_bias_names = [], [] layer_weights_num = self._num_params_per_layer * self._num_dirs layer_biases_num = layer_weights_num for i in range(self._num_layers): - layer_weights = cu_weights[i * layer_weights_num: - (i + 1) * layer_weights_num] + layer_weights = cu_weights[i * layer_weights_num:(i + 1) * + layer_weights_num] layer_biases = cu_biases[i * layer_biases_num:(i + 1) * layer_biases_num] if self._direction == CUDNN_RNN_UNIDIRECTION: - prefix = self._TFCanonicalNamePrefix(i) - self._TransformSingleLayerCanonical(layer_weights, layer_biases, prefix, - tf_weights, tf_weights_names, - tf_biases, tf_bias_names) + self._cu_canonical_to_tf_canonical_single_layer( + layer_weights, layer_biases, tf_weights, tf_biases) else: - fw_prefix = self._TFCanonicalNamePrefix(i, is_fwd=True) - bw_prefix = self._TFCanonicalNamePrefix(i, is_fwd=False) - fw_weights = layer_weights[:len(layer_weights) // 2] bw_weights = layer_weights[len(layer_weights) // 2:] fw_biases = layer_biases[:len(layer_biases) // 2] bw_biases = layer_biases[len(layer_biases) // 2:] - self._TransformSingleLayerCanonical(fw_weights, fw_biases, fw_prefix, - tf_weights, tf_weights_names, - tf_biases, tf_bias_names) - - self._TransformSingleLayerCanonical(bw_weights, bw_biases, bw_prefix, - tf_weights, tf_weights_names, - tf_biases, tf_bias_names) - return (tf_weights, tf_weights_names), (tf_biases, tf_bias_names) - - def _TransformSingleLayerCanonical(self, cu_weights, cu_biases, prefix, - tf_weights, tf_weights_names, tf_biases, - tf_bias_names): + self._cu_canonical_to_tf_canonical_single_layer( + fw_weights, + fw_biases, + tf_weights, + tf_biases, + ) + + self._cu_canonical_to_tf_canonical_single_layer( + bw_weights, + bw_biases, + tf_weights, + tf_biases, + ) + return (tf_weights, tf_biases) + + def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, + tf_weights, tf_biases): r"""Transform single layer Cudnn canonicals to tf 
canonicals. The elements of cu_weights, cu_biases are laid out in the following format: @@ -447,15 +332,12 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): Args: cu_weights: a list of tensors, single layer weights. cu_biases: a list of tensors, single layer biases. - prefix: the shared prefix of all tensor names. tf_weights: a list where transformed weights are stored. - tf_weights_names: a list where names of transformed weights are stored. tf_biases: a list where transformed biases are stored. - tf_bias_names: a list where names of transformed biases are stored. """ raise NotImplementedError("Abstract method") - def _ReverseTransformCanonical(self, tf_canonicals): + def _tf_canonical_to_cu_canonical(self, tf_canonicals): r"""Transform from tf canonical to Cudnn canonical. This is the reverse routine of _TransformCanonical(). @@ -502,30 +384,27 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): return cu_weights, cu_biases def _cudnn_to_tf_weights(self, *cu_weights): - r"""Stitching cudnn canonical weights to generate tf canonical weights.""" + r"""Stitches cudnn canonical weights to generate tf canonical weights.""" raise NotImplementedError("Abstract method") def _tf_to_cudnn_weights(self, layer, *tf_weights): - r"""Reverse the operations in StitchWeights().""" + r"""Reverses the operations in StitchWeights().""" raise NotImplementedError("Abstract method") def _cudnn_to_tf_biases(self, *biases): - r"""Stitching cudnn canonical biases to generate tf canonical biases.""" + r"""Stitches cudnn canonical biases to generate tf canonical biases.""" raise NotImplementedError("Abstract method") def _tf_to_cudnn_biases(self, *tf_biases): - r"""Reverse the operations in StitchBiases().""" + r"""Reverses the operations in StitchBiases().""" raise NotImplementedError("Abstract method") -class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): - """SaveableObject implementation handling Cudnn LSTM opaque params.""" - +class 
CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter): + """Helper class that converts between params of Cudnn and TF LSTM.""" _rnn_mode = CUDNN_LSTM _num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER - _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleLSTMCell.__name__) - def _cudnn_to_tf_gate_params(self, *cu_gate_order): i_g, f_g, c_g, o_g = cu_gate_order return [i_g, c_g, f_g, o_g] @@ -603,44 +482,16 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): # Return ifco order for Cudnn LSTM. return b_wi, b_wf, b_wc, b_wo, b_ri, b_rf, b_rc, b_ro - def _TransformSingleLayerCanonical(self, weights, biases, prefix, tf_weights, - tf_weights_names, tf_biases, - tf_bias_names): - (w,) = self._cudnn_to_tf_weights(*weights) - (b,) = self._cudnn_to_tf_biases(*biases) - + def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, + tf_weights, tf_biases): + (w,) = self._cudnn_to_tf_weights(*cu_weights) + (b,) = self._cudnn_to_tf_biases(*cu_biases) tf_weights.append(w) - tf_weights_names.append(prefix + "/kernel") - tf_biases.append(b) - tf_bias_names.append(prefix + "/bias") - - def _checkpointable_track_params(self, checkpointable, params): - """Track parameters for compatibility with CudnnCompatibleLSTMCell.""" - biases = [] - weights = [] - for name in self._weight_names: - weights.append(params[name]) - for name in self._bias_names: - biases.append(params[name]) - assert len(params) == len(weights) + len(biases) - if len(weights) == 1 and len(biases) == 1: - # For single-layer cells, allow substituting a cell with no MultiRNNCell - # wrapping. 
- kernel, = weights # pylint: disable=unbalanced-tuple-unpacking - bias, = biases # pylint: disable=unbalanced-tuple-unpacking - checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access - checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access - assert len(biases) == len(weights) - for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): - cell = checkpointable_lib.Checkpointable() - checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access - cell.bias = bias - cell.kernel = kernel -class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): - """SaveableObject implementation handling Cudnn GRU opaque params.""" +class CudnnParamsFormatConverterGRU(CudnnParamsFormatConverter): + """Helper class that converts between params of Cudnn and TF GRU.""" _rnn_mode = CUDNN_GRU _num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER @@ -702,29 +553,18 @@ class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): b_ri, b_rr = array_ops.split(br, 2, axis=0) return b_wi, b_wr, b_wh, b_ri, b_rr, b_rh - def _TransformSingleLayerCanonical(self, weights, biases, prefix, tf_weights, - tf_weights_names, tf_biases, - tf_bias_names): + def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases, + tf_weights, tf_biases): # pylint: disable=invalid-name - W_ir, w_h, r_h = self._cudnn_to_tf_weights(*weights) - b_ir, b_wh, b_rh = self._cudnn_to_tf_biases(*biases) + W_ir, w_h, r_h = self._cudnn_to_tf_weights(*cu_weights) + b_ir, b_wh, b_rh = self._cudnn_to_tf_biases(*cu_biases) # pylint: enable=invalid-name - tf_weights.extend([W_ir, w_h, r_h]) - tf_weights_names.append(prefix + "/gates/kernel") - tf_weights_names.append(prefix + "/candidate/input_projection/kernel") - tf_weights_names.append(prefix + "/candidate/hidden_projection/kernel") - tf_biases.extend([b_ir, b_wh, b_rh]) - tf_bias_names.append(prefix + "/gates/bias") - tf_bias_names.append(prefix + 
"/candidate/input_projection/bias") - tf_bias_names.append(prefix + "/candidate/hidden_projection/bias") - -class CudnnRNNSimpleSaveable(CudnnLSTMSaveable): - """SaveableObject implementation handling Cudnn RNN Tanh opaque params.""" - _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__) +class CudnnParamsFormatConverterBasic(CudnnParamsFormatConverterLSTM): + """Helper class that converts between params of Cudnn and TF Relu/Tanh RNN.""" def _cudnn_to_tf_weights(self, *cu_weights): r"""Stitching cudnn canonical weights to generate tf canonical weights.""" @@ -766,18 +606,270 @@ class CudnnRNNSimpleSaveable(CudnnLSTMSaveable): return b_i, b_h -class CudnnRNNTanhSaveable(CudnnRNNSimpleSaveable): - """SaveableObject implementation handling Cudnn RNN Tanh opaque params.""" +class CudnnParamsFormatConverterTanh(CudnnParamsFormatConverterBasic): + """Helper class that converts between params of Cudnn and TF Tanh RNN.""" _rnn_mode = CUDNN_RNN_TANH _num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER -class CudnnRNNReluSaveable(CudnnRNNSimpleSaveable): - """SaveableObject implementation handling Cudnn RNN Relu opaque params.""" +class CudnnParamsFormatConverterRelu(CudnnParamsFormatConverterBasic): + """Helper class that converts between params of Cudnn and TF Relu RNN.""" _rnn_mode = CUDNN_RNN_RELU _num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER +# TODO(yaozhang): make sure we only save the canonical version of params and +# don't save the platform-specific version to avoid potential race +# conditions where params is updated by both versions when being restored. +# Currently, checkpointing will function properly, despite that we save both +# versions, because Saver restores customized savables after Variables. +# However, it is good to not rely on this restoring order of Saver and to +# avoid unnecessary storage. Add a test to check only the canonical version is +# saved. 
+class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): + """Abstract SaveableObject implementation handling Cudnn opaque params.""" + + def __init__(self, + opaque_params, + num_layers, + num_units, + input_size, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + scope=None, + name="cudnn_rnn_saveable"): + """Creates a CudnnOpaqueParamsSaveable object. + + CudnnOpaqueParamsSaveable is saveable/restorable in a checkpoint file + and is used to save/restore the weights and biases parameters in a + canonical format which is directly consumable by platform-independent tf + RNN cells. Parameters are saved as tensors layer by layer with weight + tensors followed by bias tensors, and forward direction followed by + backward direction (if applicable). When restoring, a user could name + param_variables as desired, and restore weight and bias tensors to these + variables. + + For CudnnRNNRelu or CudnnRNNTanh, there are 2 tensors per weight and per + bias for each layer: tensor 0 is applied to the input from the previous + layer and tensor 1 to the recurrent input. + + For CudnnLSTM, there are 8 tensors per weight and per bias for each + layer: tensor 0-3 are applied to the input from the previous layer and + tensor 4-7 to the recurrent input. Tensor 0 and 4 are for the input gate; + tensor 1 and 5 the forget gate; tensor 2 and 6 the new memory gate; + tensor 3 and 7 the output gate. + + For CudnnGRU, there are 6 tensors per weight and per bias for each layer: + tensor 0-2 are applied to the input from the previous layer and + tensor 3-5 to the recurrent input. Tensor 0 and 3 are for the reset gate; + tensor 1 and 4 the update gate; tensor 2 and 5 the new memory gate. + + Args: + opaque_params: a variable, Cudnn RNN opaque params. + num_layers: the number of layers for the RNN model. + num_units: the number of units within the RNN model. + input_size: the size of the input, it could be different from the + num_units. 
+ input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. 'linear_input' (default) + always applies a linear projection of input onto RNN hidden state. + (standard RNN behavior). 'skip_input' is only allowed when input_size == + num_units; 'auto_select' implies 'skip_input' when input_size == + num_units; otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + scope: string of VariableScope, the scope of equivalent subgraph + consisting only platform-independent tf RNN cells. + name: the name of the CudnnOpaqueParamsSaveable object. + """ + # Define in subclasses. + self._num_layers = num_layers + self._input_size = input_size + self._num_units = num_units + self._input_mode = input_mode + self._direction = direction + if scope is not None: + scope_name = scope.name if isinstance(scope, vs.VariableScope) else scope + self._scope = scope_name or None + else: + self._scope = None + + self._variables = opaque_params + self._num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 + # Defined in subclasses. + self._format_converter = None + + tf_weights, tf_biases = ( + self.format_converter.opaque_to_tf_canonical(self._variables)) + tf_weight_names, tf_bias_names = self._tf_canonical_names() + # We currently don't use slice_spec. It might be useful in a distributed + # setting where each parameter server node stores a slice of variable, + # instead of having the master pull all slices and then save them. 
+ slice_spec = "" + params = tf_weights + tf_biases + self._weight_names = tf_weight_names + self._bias_names = tf_bias_names + self._param_names = tf_weight_names + tf_bias_names + prefixed_param_names = tf_weight_names + tf_bias_names + if self._scope: + prefixed_param_names = [ + "%s/%s" % (self._scope, pn) for pn in prefixed_param_names + ] + specs = [ + saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name) + for param, param_name in zip(params, prefixed_param_names) + ] + super(CudnnOpaqueParamsSaveable, self).__init__( + array_ops.identity(self._variables), specs, name) + + @property + def format_converter(self): + if self._format_converter is None: + self._format_converter = self._format_converter_cls( + self._num_layers, self._num_units, self._input_size, self._input_mode, + self._direction) + return self._format_converter + + def restore(self, restored_tensors, restored_shapes): + opaque_params = self.format_converter.tf_canonical_to_opaque( + restored_tensors) + return state_ops.assign( + self._variables, opaque_params, validate_shape=False) + + def _checkpointable_save(self, save_buffer): + weights, biases = self.format_converter.opaque_to_tf_canonical( + self._variables) + for name, tensor in zip(self._param_names, weights + biases): + save_buffer[name] = array_ops.identity(tensor) + + def _checkpointable_restore(self, restore_buffer): + tensors = [ + array_ops.identity(restore_buffer[name]) for name in self._param_names + ] + return self.restore( + restored_tensors=tensors, + restored_shapes=None # Unused + ) + + def _add_checkpointable_dependencies(self, checkpointable, dtype): + """Add canonical weight dependencies to `checkpointable`. + + When saving or restoring, converts to or from the opaque buffer + format. Weights are saved and loaded in the configuration expected by + cuDNN-compatible cells. + + Args: + checkpointable: An object inheriting from `CheckpointableBase` to add + dependencies too (typically the cuDNN `Layer`). 
+ dtype: The dtype for the canonical parameter Tensors. + """ + split_dependencies = split_dependency.split_dependency( + component_names=self._param_names, + component_dtypes=(dtype,) * len(self._param_names), + fill_save_buffer_fn=self._checkpointable_save, + consume_restore_buffer_fn=self._checkpointable_restore) + self._checkpointable_track_params(checkpointable, split_dependencies) + + def _checkpointable_track_params(self, checkpointable, params): + """Tracks parameters in a canonical configuration.""" + return # NotImplementedError raised by the Layer. + + def _tf_canonical_names(self): + tf_weights_names, tf_biases_names = [], [] + for i in range(self._num_layers): + if self._direction == CUDNN_RNN_UNIDIRECTION: + prefix = self._tf_canonical_name_prefix(i) + self._tf_canonical_names_single_layer(prefix, tf_weights_names, + tf_biases_names) + else: + fwd_prefix = self._tf_canonical_name_prefix(i, is_fwd=True) + bak_prefix = self._tf_canonical_name_prefix(i, is_fwd=False) + + self._tf_canonical_names_single_layer(fwd_prefix, tf_weights_names, + tf_biases_names) + self._tf_canonical_names_single_layer(bak_prefix, tf_weights_names, + tf_biases_names) + return tf_weights_names, tf_biases_names + + def _tf_canonical_name_prefix(self, layer, is_fwd=True): + if self._direction == CUDNN_RNN_UNIDIRECTION: + return "rnn/multi_rnn_cell/cell_%d/%s" % (layer, self._rnn_cell_name) + else: + if is_fwd: + return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/fw/%s" % + (layer, self._rnn_cell_name)) + else: + return ("stack_bidirectional_rnn/cell_%d/bidirectional_rnn/bw/%s" % + (layer, self._rnn_cell_name)) + + def _tf_canonical_names_single_layer(self, prefix, tf_weights_names, + tf_biases_names): + raise NotImplementedError("Abstract method") + + +class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): + """SaveableObject implementation handling Cudnn LSTM opaque params.""" + + _format_converter_cls = CudnnParamsFormatConverterLSTM + _rnn_cell_name = 
base_layer.to_snake_case(CudnnCompatibleLSTMCell.__name__) + + def _tf_canonical_names_single_layer(self, prefix, tf_weights_names, + tf_bias_names): + tf_weights_names.append(prefix + "/kernel") + tf_bias_names.append(prefix + "/bias") + + def _checkpointable_track_params(self, checkpointable, params): + """Track parameters for compatibility with CudnnCompatibleLSTMCell.""" + biases = [] + weights = [] + for name in self._weight_names: + weights.append(params[name]) + for name in self._bias_names: + biases.append(params[name]) + assert len(params) == len(weights) + len(biases) + if len(weights) == 1 and len(biases) == 1: + # For single-layer cells, allow substituting a cell with no MultiRNNCell + # wrapping. + kernel, = weights # pylint: disable=unbalanced-tuple-unpacking + bias, = biases # pylint: disable=unbalanced-tuple-unpacking + checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access + checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access + assert len(biases) == len(weights) + for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): + cell = checkpointable_lib.Checkpointable() + checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access + cell.bias = bias + cell.kernel = kernel + + +class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): + """SaveableObject implementation handling Cudnn GRU opaque params.""" + + _format_converter_cls = CudnnParamsFormatConverterGRU + _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleGRUCell.__name__) + + def _tf_canonical_names_single_layer(self, prefix, tf_weights_names, + tf_bias_names): + tf_weights_names.append(prefix + "/gates/kernel") + tf_weights_names.append(prefix + "/candidate/input_projection/kernel") + tf_weights_names.append(prefix + "/candidate/hidden_projection/kernel") + + tf_bias_names.append(prefix + "/gates/bias") + tf_bias_names.append(prefix + 
"/candidate/input_projection/bias") + tf_bias_names.append(prefix + "/candidate/hidden_projection/bias") + + +class CudnnRNNTanhSaveable(CudnnLSTMSaveable): + _format_converter_cls = CudnnParamsFormatConverterTanh + _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__) + + +class CudnnRNNReluSaveable(CudnnLSTMSaveable): + _format_converter_cls = CudnnParamsFormatConverterRelu + _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__) + + _cudnn_rnn_common_doc_string = """ Cudnn RNN has an opaque parameter buffer that can be used for inference and training. But it is possible that the layout of the parameter buffers @@ -850,7 +942,7 @@ def _get_num_params(rnn_mode, num_layers, direction): elif rnn_mode == CUDNN_RNN_TANH: num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER else: - raise ValueError("Invalid \'rnn_mode\': %s", rnn_mode) + raise ValueError("Invalid \'rnn_mode\': %s" % rnn_mode) num_params = num_layers * num_params_per_layer if direction != CUDNN_RNN_UNIDIRECTION: num_params *= 2 @@ -918,7 +1010,7 @@ def _cudnn_rnn(inputs, "seed2": seed2, "name": name } - if use_cudnn_v2 is not "1": + if use_cudnn_v2 != "1": outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) else: outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args) @@ -1582,7 +1674,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): """ if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION): - raise ValueError("Invalid direction: %s", direction) + raise ValueError("Invalid direction: %s" % direction) super(_CudnnRNNNoInputC, self).__init__( self._rnn_mode, diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py index 0456463a1928cf226010670b90a5d574579e0411..6c5f8c6b00975b3fba041271309a93cecd9f5057 100644 --- a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py +++ 
b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py @@ -46,7 +46,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -88,7 +88,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -115,9 +115,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((3, 10))) - iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) - .make_initializable_iterator()) + iterator = dataset_ops.make_initializable_iterator( + dataset.apply(batching.assert_element_shape(wrong_shapes))) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -142,7 +141,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): tensor_shape.TensorShape((3, 4))) self.assertEqual(actual_shapes, result.output_shapes) - iterator = result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -184,7 +183,7 @@ class AssertElementShapeTest(test_base.DatasetTestBase): result = dataset.apply(batching.assert_element_shape(expected_shapes)) self.assertEqual(expected_shapes, result.output_shapes) - iterator = 
result.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(result) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: @@ -211,9 +210,8 @@ class AssertElementShapeTest(test_base.DatasetTestBase): wrong_shapes = (tensor_shape.TensorShape(2), tensor_shape.TensorShape((None, 10))) - iterator = ( - dataset.apply(batching.assert_element_shape(wrong_shapes)) - .make_initializable_iterator()) + iterator = dataset_ops.make_initializable_iterator( + dataset.apply(batching.assert_element_shape(wrong_shapes))) init_op = iterator.initializer get_next = iterator.get_next() with self.cached_session() as sess: diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index d2a72272db159755ac2d741bcdbce9ec646d928e..b9840b1ff1a3df5a05db0e64f436637220f49f80 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -23,6 +23,7 @@ import shutil from tensorflow.contrib.data.python.ops import readers from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -48,7 +49,7 @@ class LMDBDatasetTest(test_base.DatasetTestBase): num_repeats = 2 dataset = readers.LMDBDataset(filenames).repeat(num_repeats) - iterator = dataset.make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator(dataset) init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index c5a786232252432481566e3cde23e9310df172cc..2527706709fae8e459aca3489324d4db3c784be6 100644 --- 
a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -63,13 +63,13 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count) -> # _SlideDataset(window_size, window_shift, window_stride). - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(count).apply( sliding.sliding_window_batch( window_size=window_size_t, window_shift=window_shift_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer get_next = iterator.get_next() @@ -127,13 +127,13 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(count) -> _SlideDataset(window_size, stride, window_stride). 
- iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) .repeat(count).apply( sliding.sliding_window_batch( window_size=window_size_t, stride=stride_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer get_next = iterator.get_next() @@ -173,12 +173,12 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): window_shift_t = array_ops.placeholder(dtypes.int64, shape=[]) window_stride_t = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count_t).apply( sliding.sliding_window_batch( window_size=window_size_t, window_shift=window_shift_t, - window_stride=window_stride_t)).make_initializable_iterator()) + window_stride=window_stride_t))) init_op = iterator.initializer with self.cached_session() as sess: @@ -204,9 +204,9 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch( - window_size=5, window_shift=3)).make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator( + dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(window_size=5, window_shift=3))) init_op = iterator.initializer get_next = iterator.get_next() @@ -233,9 +233,9 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): values=array_ops.fill([math_ops.to_int32(i)], i), dense_shape=[i]) - iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( - sliding.sliding_window_batch( - window_size=5, window_shift=3)).make_initializable_iterator() + iterator = dataset_ops.make_initializable_iterator( + 
dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(window_size=5, window_shift=3))) init_op = iterator.initializer get_next = iterator.get_next() @@ -265,11 +265,10 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): return sparse_tensor.SparseTensorValue( indices=[[0]], values=(i * [1]), dense_shape=[1]) - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.range(10).map(_sparse).apply( sliding.sliding_window_batch(window_size=4, window_shift=2)).apply( - sliding.sliding_window_batch(window_size=3, window_shift=1)) - .make_initializable_iterator()) + sliding.sliding_window_batch(window_size=3, window_shift=1))) init_op = iterator.initializer get_next = iterator.get_next() @@ -305,11 +304,10 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): yield [4.0, 5.0, 6.0] yield [7.0, 8.0, 9.0, 10.0] - iterator = ( + iterator = dataset_ops.make_initializable_iterator( dataset_ops.Dataset.from_generator( generator, dtypes.float32, output_shapes=[None]).apply( - sliding.sliding_window_batch(window_size=3, window_shift=1)) - .make_initializable_iterator()) + sliding.sliding_window_batch(window_size=3, window_shift=1))) next_element = iterator.get_next() with self.cached_session() as sess: diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 34dc2379d0cb38f8f6962fa42efe21b793bc8d65..0fb406f1167053a128646c5c692986b0ce016f1e 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -188,8 +188,7 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:function", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", + "//tensorflow/python/data/util:structure", ], ) diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 
4601376dff47e161962e92678883039c4b88bab7..c0152156a1ba70297adb7054622b15ca04f859cd 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -21,10 +21,9 @@ from tensorflow.python.data.experimental.ops import optimization from tensorflow.python.data.experimental.ops import readers from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.util import deprecation @@ -355,7 +354,7 @@ def read_batch_features(file_pattern, shuffle=randomize_input, num_epochs=num_epochs, shuffle_buffer_size=capacity) - iterator = dataset.make_one_shot_iterator() + iterator = dataset_ops.make_one_shot_iterator(dataset) outputs = iterator.get_next() return outputs @@ -379,15 +378,13 @@ class LMDBDataset(dataset_ops.DatasetSource): (key value) pairs sequentially. For example: ```python + tf.enable_eager_execution() + dataset = tf.contrib.lmdb.LMDBDataset("/foo/bar.mdb") - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() + # Prints the (key, value) pairs inside a lmdb file. - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break + for key, value in dataset: + print(key, value) ``` Args: filenames: A `tf.string` tensor containing one or more filenames. 
@@ -398,18 +395,10 @@ class LMDBDataset(dataset_ops.DatasetSource): def _as_variant_tensor(self): return gen_experimental_dataset_ops.experimental_lmdb_dataset( - self._filenames, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) - - @property - def output_classes(self): - return ops.Tensor, ops.Tensor - - @property - def output_shapes(self): - return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) + self._filenames, **dataset_ops.flat_structure(self)) @property - def output_types(self): - return dtypes.string, dtypes.string + def _element_structure(self): + return structure.NestedStructure( + (structure.TensorStructure(dtypes.string, []), + structure.TensorStructure(dtypes.string, []))) diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index bcc383587c54bd89502313f9328bc06c49046a87..5c6ee6bfdc7167d14b292f8f763adafca4e3a72c 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -18,11 +18,10 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest +from tensorflow.python.data.util import structure from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops from tensorflow.python.util import deprecation @@ -40,8 +39,13 @@ class _SlideDataset(dataset_ops.UnaryDataset): self._window_shift = ops.convert_to_tensor( window_shift, dtype=dtypes.int64, name="window_shift") + input_structure = structure.convert_legacy_structure( + input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) + self._structure = input_structure._batch(None) # pylint: 
disable=protected-access + def _as_variant_tensor(self): - return gen_dataset_ops.slide_dataset( + return ged_ops.experimental_sliding_window_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access window_size=self._window_size, window_shift=self._window_shift, @@ -49,20 +53,8 @@ class _SlideDataset(dataset_ops.UnaryDataset): **dataset_ops.flat_structure(self)) @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - input_shapes = self._input_dataset.output_shapes - return nest.pack_sequence_as(input_shapes, [ - tensor_shape.vector(None).concatenate(s) - for s in nest.flatten(self._input_dataset.output_shapes) - ]) - - @property - def output_types(self): - return self._input_dataset.output_types + def _element_structure(self): + return self._structure @deprecation.deprecated_args( diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD index a87a5624c88d1d0af10055261dad55937ed6aeb0..3ecd755d86f6be47910aebbdb46d335d165427d8 100644 --- a/tensorflow/contrib/distribute/BUILD +++ b/tensorflow/contrib/distribute/BUILD @@ -26,7 +26,6 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/contrib/distribute/python:collective_all_reduce_strategy", - "//tensorflow/contrib/distribute/python:cross_tower_ops", "//tensorflow/contrib/distribute/python:mirrored_strategy", "//tensorflow/contrib/distribute/python:monitor", "//tensorflow/contrib/distribute/python:one_device_strategy", @@ -35,6 +34,7 @@ py_library( "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/distribute:cross_device_ops", "//tensorflow/python/distribute:distribute_config", "//tensorflow/python/distribute:distribute_coordinator", ], diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 
2e025765e4aaab7114aa6e3e79336e48a71b5b55..8a8dc159ade6f2a4a9b5ec29055ea4848492b29f 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -20,7 +20,7 @@ on many GPUs on one machine. Essentially, we create copies of all variables in the model's layers on each device. We then use all-reduce to combine gradients across the devices before applying them to the variables to keep them in sync. * [`CollectiveAllReduceStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/CollectiveAllReduceStrategy): -This is a version of `MirroredStrategy` for multi-working training. It uses +This is a version of `MirroredStrategy` for multi-worker training. It uses a collective op to do all-reduce. This supports between-graph communication and synchronization, and delegates the specifics of the all-reduce implementation to the runtime (as opposed to encoding it in the graph). This allows it to perform @@ -31,8 +31,8 @@ fault-tolerance to allow training to continue when there is worker failure. * [`ParameterServerStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/ParameterServerStrategy): This strategy supports using parameter servers either for multi-GPU local training or asynchronous multi-machine training. When used to train locally, -variables are not mirrored, instead they placed on the CPU and operations are -replicated across all local GPUs. In a multi-machine setting, some are +variables are not mirrored, instead they are placed on the CPU and operations +are replicated across all local GPUs. In a multi-machine setting, some are designated as workers and some as parameter servers. Each variable is placed on one parameter server. Computation operations are replicated across all GPUs of the workers. 
@@ -46,6 +46,9 @@ Let's see how to scale to multiple GPUs on one machine using `MirroredStrategy` Take a very simple model consisting of a single layer: ```python +import tensorflow as tf +from tensorflow import keras + inputs = tf.keras.layers.Input(shape=(1,)) predictions = tf.keras.layers.Dense(1)(inputs) model = tf.keras.models.Model(inputs=inputs, outputs=predictions) @@ -90,8 +93,8 @@ Similarly, we can also call `evaluate` and `predict` as before using appropriate datasets. ```python -model.evaluate(eval_dataset) -model.predict(predict_dataset) +model.evaluate(eval_dataset, steps=1) +model.predict(predict_dataset, steps=1) ``` That's all you need to train your model with Keras on multiple GPUs with @@ -131,7 +134,7 @@ def model_fn(features, labels, mode): return tf.estimator.EstimatorSpec(mode, loss=loss) if mode == tf.estimator.ModeKeys.TRAIN: - train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss_fn()) + train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss) return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) ``` @@ -190,7 +193,7 @@ in the input function gives a solid boost in performance. When using For multi-worker training, no code change is required to the `Estimator` code. You can run the same model code for all tasks in your cluster including parameter servers and the evaluator. But you need to use -`tf.estimator.train_and_evaluator`, explicitly specify `num_gpus_per_workers` +`tf.estimator.train_and_evaluate`, explicitly specify `num_gpus_per_workers` for your strategy object, and set "TF\_CONFIG" environment variables for each binary running in your cluster. We'll provide a Kubernetes template in the [tensorflow/ecosystem](https://github.com/tensorflow/ecosystem) repo which sets @@ -245,19 +248,17 @@ Let's use the same example for multi-worker. We'll start a cluster with 3 workers doing synchronous all-reduce training. 
In the following code snippet, we start multi-worker training using `tf.estimator.train_and_evaluate`: - ```python def model_main(): - estimator = ... distribution = tf.contrib.distribute.CollectiveAllReduceStrategy( num_gpus_per_worker=2) config = tf.estimator.RunConfig(train_distribute=distribution) + estimator = tf.estimator.Estimator(model_fn=model_fn, config=config) train_spec = tf.estimator.TrainSpec(input_fn=input_fn) eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) ``` - **Note**: You don't have to set "TF\_CONFIG" manually if you use our provided Kubernetes template. @@ -324,13 +325,13 @@ start training. On your laptop, you can run ```python -estimator = ... distribution = tf.contrib.distribute.CollectiveAllReduceStrategy( num_gpus_per_worker=2) config = tf.estimator.RunConfig( experimental_distribute=tf.contrib.distribute.DistributeConfig( train_distribute=distribution, remote_cluster={"worker": ["host1:port", "host2:port", "host3:port"]})) +estimator = tf.estimator.Estimator(model_fn=model_fn, config=config) train_spec = tf.estimator.TrainSpec(input_fn=input_fn) eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 823fe6a917f4f31ab6822e4bb1130d62ff45f0c9..8ec73654e30e4967f318c558ba94301e84a206e4 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -25,13 +25,13 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.distribute.python.collective_all_reduce_strategy import CollectiveAllReduceStrategy -from tensorflow.contrib.distribute.python.cross_tower_ops import * from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy from tensorflow.contrib.distribute.python.monitor 
import Monitor from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy +from tensorflow.python.distribute.cross_device_ops import * from tensorflow.python.distribute.distribute_config import DistributeConfig from tensorflow.python.distribute.distribute_coordinator import run_standard_tensorflow_server from tensorflow.python.training.distribute import * @@ -41,27 +41,30 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'AllReduceCrossTowerOps', + 'AllReduceCrossDeviceOps', 'CollectiveAllReduceStrategy', - 'CrossTowerOps', + 'CrossDeviceOps', 'DistributeConfig', 'DistributionStrategy', + 'DistributionStrategyExtended', 'MirroredStrategy', 'Monitor', + 'MultiWorkerAllReduce', 'OneDeviceStrategy', 'ParameterServerStrategy', - 'ReductionToOneDeviceCrossTowerOps', + 'ReductionToOneDeviceCrossDeviceOps', 'Step', 'StandardInputStep', 'StandardSingleLossStep', - 'TowerContext', + 'ReplicaContext', 'TPUStrategy', - 'get_cross_tower_context', + 'get_cross_replica_context', 'get_distribution_strategy', 'get_loss_reduction', - 'get_tower_context', + 'get_replica_context', 'has_distribution_strategy', - 'require_tower_context', + 'in_cross_replica_context', + 'require_replica_context', 'run_standard_tensorflow_server', 'UpdateContext', ] diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 8267612236bcf2946c033d3e5071eee935d2c03a..4c9c35da5a36aa8149d15c8d1c25e4dfaa6a07c1 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -16,45 +16,26 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") # TODO(priyag): Figure out testonly issues that are preventing us from # including 
our tests in pip for now. -py_library( - name = "values", - srcs = ["values.py"], - visibility = ["//tensorflow:internal"], - deps = [ - ":input_ops", - ":prefetching_ops_v2", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:device_util", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python/eager:context", - "//tensorflow/python/training/checkpointable:base", - "@six_archive//:six", - ], -) - cuda_py_test( name = "values_test", srcs = ["values_test.py"], additional_deps = [ + ":combinations", ":mirrored_strategy", ":multi_worker_test_base", - ":values", + "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python:errors", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute:device_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", - "//tensorflow/python:device_util", "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", ], @@ -68,25 +49,9 @@ py_library( srcs = ["mirrored_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":cross_tower_ops", - ":shared_variable_creator", - ":values", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:device", - "//tensorflow/python:device_util", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - "//tensorflow/python:pywrap_tensorflow", - 
"//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/distribute:multi_worker_util", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:tape", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/distribute:values", ], ) @@ -95,16 +60,17 @@ py_library( srcs = ["parameter_server_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":cross_tower_ops", ":mirrored_strategy", - ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/distribute:cross_device_ops", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", ], ) @@ -116,7 +82,7 @@ cuda_py_test( ":combinations", ":multi_worker_test_base", ":parameter_server_strategy", - ":values", + ":strategy_test_lib", "@absl_py//absl/testing:parameterized", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -127,10 +93,12 @@ cuda_py_test( "//tensorflow/python:gradients", "//tensorflow/python:layers", "//tensorflow/python:session", + "//tensorflow/python:tensor_util", "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/estimator:estimator_py", ], @@ -145,12 +113,13 @@ py_library( srcs = ["one_device_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":values", - "//tensorflow/contrib/eager/python:datasets", "//tensorflow/python:array_ops", - "//tensorflow/python:distribute", + 
"//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "@six_archive//:six", ], @@ -161,16 +130,16 @@ py_library( srcs = ["collective_all_reduce_strategy.py"], visibility = ["//tensorflow:internal"], deps = [ - ":cross_tower_ops", - ":cross_tower_utils", ":mirrored_strategy", - ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:collective_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:training", + "//tensorflow/python/distribute:cross_device_ops", + "//tensorflow/python/distribute:cross_device_utils", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", ], ) @@ -187,11 +156,11 @@ py_library( "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", - "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", "//tensorflow/python:layers", "//tensorflow/python:training", "//tensorflow/python:variables", + "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", @@ -212,10 +181,10 @@ py_library( ":tpu_strategy", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/optimizer_v2:training", - "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/eager:context", "@absl_py//absl/testing:parameterized", ], @@ -233,28 +202,6 @@ py_test( ], ) -py_test( - name = "mirrored_strategy_test", - srcs = ["mirrored_strategy_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - 
deps = [ - ":mirrored_strategy", - ":multi_worker_test_base", - ":strategy_test_lib", - "//tensorflow/python:constant_op", - "//tensorflow/python:distribute", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - ], -) - py_test( name = "one_device_strategy_test", srcs = ["one_device_strategy_test.py"], @@ -270,35 +217,32 @@ py_test( ], ) +# TODO(priyag): Rename this test to mirrored_strategy_test cuda_py_test( name = "mirrored_strategy_multigpu_test", srcs = ["mirrored_strategy_multigpu_test.py"], additional_deps = [ + ":combinations", ":mirrored_strategy", ":multi_worker_test_base", - ":values", ":strategy_test_lib", - "//tensorflow/python:distribute", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:layers", "//tensorflow/python:state_ops", "//tensorflow/python:variable_scope", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python/distribute:distribute_lib", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], + shard_count = 5, tags = [ "guitar", - "no_pip", "multi_and_single_gpu", - # Do not perform the extra analysis on this test, because it is already - # performed for the `:mirrored_strategy_test` target. 
- "no_oss", - "noasan", - "notap", - "notsan", + "no_pip", ], ) @@ -315,6 +259,7 @@ py_library( "//tensorflow/python:client_testlib", "//tensorflow/python:distributed_framework_test_lib", "//tensorflow/python:session", + "//tensorflow/python:util", "//tensorflow/python/estimator:estimator_py", "//third_party/py/numpy", ], @@ -336,12 +281,15 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":one_device_strategy", - ":values", "//tensorflow/contrib/tpu:tpu_lib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:tensor_util", "//tensorflow/python:util", + "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:values", ], ) @@ -351,7 +299,6 @@ cuda_py_test( additional_deps = [ ":collective_all_reduce_strategy", ":combinations", - ":cross_tower_utils", ":multi_worker_test_base", ":strategy_test_lib", "@absl_py//absl/testing:parameterized", @@ -367,6 +314,7 @@ cuda_py_test( "//tensorflow/python:layers", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/distribute:cross_device_utils", "//tensorflow/python/eager:context", "//tensorflow/python/estimator:estimator_py", ], @@ -411,6 +359,24 @@ cuda_py_test( ], ) +cuda_py_test( + name = "moving_averages_test", + srcs = ["moving_averages_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python/eager:test", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:training", + "//tensorflow/python:variables", + ], + tags = [ + "no_pip", + ], +) + cuda_py_test( name = "optimizer_v2_test", srcs = ["optimizer_v2_test.py"], @@ -448,15 +414,31 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", + "no_oss", # http://b/119349471 + "no_pip", + "tf_integration_test", + ], +) + +cuda_py_test( + name = 
"keras_optimizer_v2_test", + srcs = ["keras_optimizer_v2_test.py"], + additional_deps = [ + ":keras_test_lib", + ], + tags = [ + "multi_and_single_gpu", + "no_oss", # http://b/119349471 "no_pip", + "tf_integration_test", ], ) cuda_py_test( name = "estimator_training_test", - size = "enormous", srcs = ["estimator_training_test.py"], additional_deps = [ + ":collective_all_reduce_strategy", ":combinations", ":mirrored_strategy", ":multi_worker_test_base", @@ -464,7 +446,9 @@ cuda_py_test( "//third_party/py/numpy", "//tensorflow/contrib/optimizer_v2:training", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/distribute", + "//tensorflow/python/distribute:distribute_config", + "//tensorflow/python/distribute:distribute_coordinator", + "//tensorflow/python/distribute:distribute_coordinator_context", "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/feature_column", @@ -472,9 +456,15 @@ cuda_py_test( "//tensorflow/python:platform", "//tensorflow/python:summary", ], + shard_count = 48, tags = [ "multi_and_single_gpu", "no_pip", + # TODO(b/118768923): Re-enable {a,m,t}san test. 
+ "noasan", + "nomsan", + "notsan", + "no_oss", # http://b/119349471 ], ) @@ -550,52 +540,16 @@ cuda_py_test( ], ) -py_library( - name = "shared_variable_creator", - srcs = ["shared_variable_creator.py"], - visibility = ["//tensorflow:internal"], -) - -py_test( - name = "shared_variable_creator_test", - srcs = ["shared_variable_creator_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":shared_variable_creator", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:test", - ], -) - -py_library( - name = "cross_tower_utils", - srcs = ["cross_tower_utils.py"], - srcs_version = "PY2AND3", - deps = [ - ":values", - "//tensorflow/contrib/all_reduce:all_reduce_py", - "//tensorflow/contrib/nccl:nccl_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:collective_ops", - "//tensorflow/python:device", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - ], -) - cuda_py_test( - name = "cross_tower_utils_test", - srcs = ["cross_tower_utils_test.py"], + name = "cross_device_utils_test", + srcs = ["cross_device_utils_test.py"], additional_deps = [ ":combinations", - ":cross_tower_utils", "@absl_py//absl/testing:parameterized", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/distribute:cross_device_utils", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], @@ -604,41 +558,20 @@ cuda_py_test( ], ) -py_library( - name = "cross_tower_ops", - srcs = ["cross_tower_ops.py"], - srcs_version = "PY2AND3", - deps = [ - ":cross_tower_utils", - ":values", - "//tensorflow/python:array_ops", - "//tensorflow/python:device_lib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", - 
"//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - "@six_archive//:six", - ], -) - cuda_py_test( - name = "cross_tower_ops_test", - size = "large", - srcs = ["cross_tower_ops_test.py"], + name = "cross_device_ops_test", + srcs = ["cross_device_ops_test.py"], additional_deps = [ ":combinations", - ":cross_tower_ops", ":multi_worker_test_base", ":mirrored_strategy", - ":values", "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/distribute:cross_device_ops", + "//tensorflow/python/distribute:values", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], @@ -648,63 +581,6 @@ cuda_py_test( ], ) -py_library( - name = "prefetching_ops_v2", - srcs = ["prefetching_ops_v2.py"], - deps = [ - "//tensorflow/contrib/data/python/ops:prefetching_ops", - "//tensorflow/python:experimental_dataset_ops_gen", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - "//tensorflow/python/data/util:sparse", - ], -) - -cuda_py_test( - name = "prefetching_ops_v2_test", - srcs = ["prefetching_ops_v2_test.py"], - additional_deps = [ - ":prefetching_ops_v2", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - ], -) - -py_library( - name = "input_ops", - srcs = ["input_ops.py"], - visibility = ["//tensorflow:internal"], - deps = [ - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/util:nest", - ], -) - -cuda_py_test( - name = "input_ops_test", - srcs = ["input_ops_test.py"], - additional_deps = [ - ":input_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:batching", - 
"//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/python:errors", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:io_ops", - "//tensorflow/python/data/ops:readers", - "//tensorflow/python:util", - ], - tags = [ - "no_pip", - ], -) - py_library( name = "keras_test_lib", testonly = 1, @@ -715,6 +591,7 @@ py_library( "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:client_testlib", "//tensorflow/python:training", + "//tensorflow/python/eager:test", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras", "//third_party/py/numpy", @@ -731,6 +608,7 @@ cuda_py_test( shard_count = 16, tags = [ "multi_and_single_gpu", + "no_oss", # TODO(b/117919883): Fix python error. "no_pip", "no_windows_gpu", "notsan", @@ -743,7 +621,6 @@ py_library( srcs = ["metrics_v1_test.py"], deps = [ ":combinations", - "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", "//tensorflow/python:variables", diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py index 865dba803f562e0ab98341dd8343e3c72b03d39b..31bd0e996a247a2fc01405fb3b8172a40853d698 100644 --- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -43,10 +43,12 @@ class CheckpointUtilsWithDistributionStrategyTest( distribution=[combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], - in_tower_mode=[True, False], + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], + in_replica_mode=[True, False], mode=["graph"])) - def testInitFromCheckpoint(self, 
distribution, in_tower_mode): + def testInitFromCheckpoint(self, distribution, in_replica_mode): checkpoint_dir = self.get_temp_dir() with self.cached_session() as session: v1_value, v2_value, _, _ = checkpoint_utils_test._create_checkpoints( @@ -68,8 +70,8 @@ class CheckpointUtilsWithDistributionStrategyTest( self.assertAllEqual(v2_value, self.evaluate(v2)) with ops.Graph().as_default() as g, distribution.scope(): - if in_tower_mode: - distribution.call_for_each_tower(init_and_verify, g) + if in_replica_mode: + distribution.call_for_each_replica(init_and_verify, args=[g]) else: init_and_verify(g) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index 9809204f8f107270b5a7b51e65e06afdae7d96b8..5c50a20490482856becedf7b1379d2a0583d9a11 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -18,12 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import cross_tower_utils +import copy + from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -32,7 
+36,7 @@ from tensorflow.python.platform import tf_logging as logging # TODO(yuefengz): support in-graph replication. -class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): +class CollectiveAllReduceStrategy(distribute_lib.DistributionStrategy): """Distribution strategy that uses collective ops for all-reduce. It is similar to the MirroredStrategy but it uses collective ops for @@ -53,6 +57,17 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): num_gpus_per_worker: number of local GPUs or GPUs per worker, the default is 0 meaning CPU only. """ + super(CollectiveAllReduceStrategy, self).__init__( + CollectiveAllReduceExtended(self, num_gpus_per_worker)) + + +class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): + """Implementation of CollectiveAllReduceStrategy.""" + + def __init__(self, container_strategy, num_gpus_per_worker): + distribute_lib.DistributionStrategyExtended.__init__( + self, container_strategy) + self._cross_device_ops = None self._num_gpus_per_worker = num_gpus_per_worker self._initialize_local_worker(num_gpus_per_worker) @@ -62,19 +77,19 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): self._num_workers = 1 if num_gpus_per_worker: - local_devices = [ + local_devices = tuple( "/device:GPU:%d" % i for i in range(num_gpus_per_worker) - ] + ) else: - local_devices = ["/device:CPU:0"] + local_devices = ("/device:CPU:0",) + self._worker_device = device_util.canonicalize("/device:CPU:0") - self._collective_keys = cross_tower_utils.CollectiveKeys() - super(CollectiveAllReduceStrategy, self).__init__( - devices=local_devices, - cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce( - num_workers=1, - num_gpus_per_worker=num_gpus_per_worker, - collective_keys=self._collective_keys)) + self._collective_keys = cross_device_utils.CollectiveKeys() + self._initialize_local(local_devices) + self._cross_tower_ops = cross_device_ops_lib.CollectiveAllReduce( + 
num_workers=self._num_workers, + num_gpus_per_worker=num_gpus_per_worker, + collective_keys=self._collective_keys) self._cluster_spec = None self._task_type = None @@ -89,13 +104,12 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): if task_type is None or task_id is None: raise ValueError("When `cluster_spec` is given, you must also specify " "`task_type` and `task_id`") - if task_type not in ["chief", "worker"]: + if task_type not in ("chief", "worker"): raise ValueError( "Unrecognized task_type: %r, valid task types are: \"chief\", " "\"worker\"." % task_type) cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._num_workers = len(cluster_spec.as_dict().get("worker", [])) + len( - cluster_spec.as_dict().get("chief", [])) + self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type) if not self._num_workers: raise ValueError("No `worker` or `chief` tasks can be found in " "`cluster_spec`.") @@ -103,22 +117,21 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, task_id) - worker_device = "/job:%s/task:%d" % (task_type, task_id) + self._worker_device = "/job:%s/task:%d" % (task_type, task_id) if num_gpus_per_worker: - local_devices = [ - "%s/device:GPU:%d" % (worker_device, i) + local_devices = tuple( + "%s/device:GPU:%d" % (self._worker_device, i) for i in range(num_gpus_per_worker) - ] + ) else: - local_devices = [worker_device] + local_devices = (self._worker_device,) - self._collective_keys = cross_tower_utils.CollectiveKeys() - super(CollectiveAllReduceStrategy, self).__init__( - devices=local_devices, - cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce( - num_workers=self._num_workers, - num_gpus_per_worker=num_gpus_per_worker, - collective_keys=self._collective_keys)) + self._collective_keys = cross_device_utils.CollectiveKeys() + self._initialize_local(local_devices) + self._cross_tower_ops = 
cross_device_ops_lib.CollectiveAllReduce( + num_workers=self._num_workers, + num_gpus_per_worker=num_gpus_per_worker, + collective_keys=self._collective_keys) # Add a default device so that ops without specified devices will not end up # on other workers. @@ -160,7 +173,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): if i > 0: # Give replicas meaningful distinct names: var0name = index[devices[0]].name.split(":")[0] - # We append a / to variable names created on towers with id > 0 to + # We append a / to variable names created on replicas with id > 0 to # ensure that we ignore the name scope and instead use the given # name as the absolute name of the variable. kwargs["name"] = "%s/replica_%d/" % (var0name, i) @@ -202,17 +215,40 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): return mirrored_strategy._create_mirrored_variable( devices, _real_mirrored_creator, *args, **kwargs) - def distribute_dataset(self, dataset_fn): + def _distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" # TODO(yuefengz): shard the dataset. 
- return values.PerDeviceDataset( + return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._devices, True) - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): + def _make_dataset_iterator(self, dataset): + worker_device_pairs = [(self._worker_device, self._devices)] + return values.DatasetIterator(dataset, worker_device_pairs, + self._num_replicas_in_sync) + + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + """Distributes the dataset to each local GPU.""" + if self._cluster_spec is None: + input_pipeline_id = 0 + else: + input_pipeline_id = multi_worker_util.id_in_cluster( + self._cluster_spec, self._task_type, self._task_id) + input_context = distribute_lib.InputContext( + num_input_pipelines=self._num_workers, + input_pipeline_id=input_pipeline_id, + num_replicas_in_sync=self._num_replicas_in_sync) + + return values.InputFunctionIterator( + input_fn, [(self._worker_device, self._devices)], [input_context]) + + def _configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): """Configures the object. Args: @@ -232,8 +268,25 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec, task_type, task_id) - if not session_config or not self._cluster_spec: - return + if session_config: + session_config.CopyFrom(self._update_config_proto(session_config)) + + def _update_config_proto(self, config_proto): + updated_config = copy.deepcopy(config_proto) + # Enable the scoped allocator optimization for CollectiveOps. This + # optimization converts many small all-reduces into fewer larger + # all-reduces. 
+ rewrite_options = updated_config.graph_options.rewrite_options + rewrite_options.scoped_allocator_optimization = ( + rewriter_config_pb2.RewriterConfig.ON) + # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op = + # ["CollectiveReduce"]. Since we can't assign to a repeated proto field, we + # clear and then append. + del rewrite_options.scoped_allocator_opts.enable_op[:] + rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce") + + if not self._cluster_spec: + return updated_config assert self._task_type assert self._task_id is not None @@ -241,34 +294,28 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): # Collective group leader is needed for collective ops to coordinate # workers. if "chief" in self._cluster_spec.jobs: - session_config.experimental.collective_group_leader = ( + updated_config.experimental.collective_group_leader = ( "/job:chief/replica:0/task:0") else: if "worker" not in self._cluster_spec.jobs: raise ValueError( "You must have `chief` or `worker` jobs in the `cluster_spec`.") - session_config.experimental.collective_group_leader = ( + updated_config.experimental.collective_group_leader = ( "/job:worker/replica:0/task:0") # The device filters prevent communication between workers. - del session_config.device_filters[:] - session_config.device_filters.append( + del updated_config.device_filters[:] + updated_config.device_filters.append( "/job:%s/task:%d" % (self._task_type, self._task_id)) - # The scoped_allocator_optimization is to optimize graphs for collective - # ops. 
- rewrite_options = session_config.graph_options.rewrite_options - rewrite_options.scoped_allocator_optimization = ( - rewriter_config_pb2.RewriterConfig.ON) - del rewrite_options.scoped_allocator_opts.enable_op[:] - rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce") + return updated_config @property - def between_graph(self): + def experimental_between_graph(self): return True @property - def should_init(self): + def experimental_should_init(self): return True @property @@ -278,3 +325,12 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): @property def should_save_summary(self): return self._is_chief + + @property + def _num_replicas_in_sync(self): + return len(self._devices) * self._num_workers + + # TODO(priyag): Delete this once all strategies use global batch size. + @property + def _global_batch_size(self): + return False diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 6796a23d464d344554ae9654e0992e30df5ad213..8a9e583f0afaac37a2057bae9b1ed79de43d68bc 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -23,13 +23,19 @@ import numpy as np from tensorflow.contrib.distribute.python import collective_all_reduce_strategy from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import 
reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops @@ -51,11 +57,6 @@ class CollectiveAllReduceStrategyTestBase( collective_key_base = 0 def setUp(self): - self._run_options = config_pb2.RunOptions() - self._run_options.experimental.collective_graph_key = 6 - - self._sess_config = config_pb2.ConfigProto() - # We use a different key_base for each test so that collective keys won't be # reused. # TODO(yuefengz, tucker): enable it to reuse collective keys in different @@ -66,33 +67,38 @@ class CollectiveAllReduceStrategyTestBase( def _get_test_object(self, task_type, task_id, num_gpus=0): distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( num_gpus_per_worker=num_gpus) + session_config = config_pb2.ConfigProto() if task_type and task_id is not None: distribution.configure( - session_config=self._sess_config, + session_config=session_config, cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id) - collective_keys = cross_tower_utils.CollectiveKeys( + collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 * num_gpus + CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_start=num_gpus * 100 + CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_with_id_start=num_gpus * 10000 + CollectiveAllReduceStrategyTestBase.collective_key_base) - distribution._collective_keys = collective_keys - distribution._cross_tower_ops._collective_keys = collective_keys + distribution.extended._collective_keys = collective_keys + distribution.extended._inferred_cross_device_ops._collective_keys = ( + collective_keys) if task_type and task_id is not None: - return distribution, 
'grpc://' + self._cluster_spec[task_type][task_id] + return distribution, 'grpc://' + self._cluster_spec[task_type][ + task_id], session_config else: - return distribution, '' + return distribution, '', session_config def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): - d, master_target = self._get_test_object(task_type, task_id, num_gpus) + d, master_target, config = self._get_test_object(task_type, task_id, + num_gpus) with ops.Graph().as_default(), \ - self.test_session(config=self._sess_config, - target=master_target) as sess, \ + self.cached_session(config=config, + target=master_target) as sess, \ d.scope(): - l = core.Dense(1, use_bias=False, name='gpu_%d' % d._num_gpus_per_worker) + l = core.Dense(1, use_bias=False, + name='gpu_%d' % d.extended._num_gpus_per_worker) def loss_fn(x): y = array_ops.reshape(l(x), []) - constant_op.constant(1.) @@ -117,7 +123,7 @@ class CollectiveAllReduceStrategyTestBase( def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_tower(grad_fn, one) + g_v = d.call_for_each_replica(grad_fn, args=[one]) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] @@ -126,8 +132,8 @@ class CollectiveAllReduceStrategyTestBase( before_list.append(fetched) with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. 
- g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.extended.reduce_to( + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -135,14 +141,13 @@ class CollectiveAllReduceStrategyTestBase( before_out, after_out = step() - if context.num_gpus() < d._num_gpus_per_worker: + if context.num_gpus() < d.extended._num_gpus_per_worker: return True - sess.run( - variables.global_variables_initializer(), options=self._run_options) + sess.run(variables.global_variables_initializer()) for i in range(10): - b, a = sess.run((before_out, after_out), options=self._run_options) + b, a = sess.run((before_out, after_out)) if i == 0: before, = b after, = a @@ -154,7 +159,8 @@ class CollectiveAllReduceStrategyTestBase( return error_after < error_before def _test_complex_model(self, task_type, task_id, num_gpus): - d, master_target = self._get_test_object(task_type, task_id, num_gpus) + d, master_target, config = self._get_test_object(task_type, task_id, + num_gpus) def model_fn(): """Mnist model with synthetic input.""" @@ -193,10 +199,10 @@ class CollectiveAllReduceStrategyTestBase( return train_op with ops.Graph().as_default(), \ - self.test_session(config=self._sess_config, - target=master_target) as sess: + self.cached_session(config=config, + target=master_target) as sess: with d.scope(): - train_op = d.call_for_each_tower(model_fn) + train_op = d.call_for_each_replica(model_fn) train_op = d.group(d.unwrap(train_op)) sess.run(variables.global_variables_initializer()) @@ -204,11 +210,11 @@ class CollectiveAllReduceStrategyTestBase( return True def _test_variable_initialization(self, task_type, task_id, num_gpus): - distribution, master_target = self._get_test_object(task_type, task_id, - num_gpus) + distribution, master_target, config = self._get_test_object( + task_type, task_id, num_gpus) with ops.Graph().as_default(), \ - 
self.test_session(config=self._sess_config, - target=master_target) as sess, \ + self.cached_session(config=config, + target=master_target) as sess, \ distribution.scope(): def model_fn(): @@ -219,27 +225,55 @@ class CollectiveAllReduceStrategyTestBase( 1.0, 10.0, dtype=dtypes.float32)) return array_ops.identity(x) - x = distribution.call_for_each_tower(model_fn) - reduced_x = distribution.unwrap( - distribution.reduce( - variable_scope.VariableAggregation.MEAN, x, - destinations='/cpu:0'))[0] + x = distribution.call_for_each_replica(model_fn) + reduced_x = distribution.reduce(reduce_util.ReduceOp.MEAN, x) x = distribution.unwrap(x)[0] - sess.run( - variables.global_variables_initializer(), options=self._run_options) + sess.run(variables.global_variables_initializer()) - x_value, reduced_x_value = sess.run( - [x, reduced_x], options=self._run_options) + x_value, reduced_x_value = sess.run([x, reduced_x]) self.assertTrue( np.allclose(x_value, reduced_x_value, atol=1e-5), msg=('x_value = %r, reduced_x_value = %r' % (x_value, reduced_x_value))) return np.allclose(x_value, reduced_x_value, atol=1e-5) + def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, + expected_values): + distribution, master_target, config = self._get_test_object( + task_type, task_id, num_gpus) + devices = distribution.extended.worker_devices + + with ops.Graph().as_default(), \ + self.cached_session(config=config, + target=master_target) as sess: + iterator = distribution.make_input_fn_iterator(input_fn) + sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + sess.run([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate 
again. + sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + class DistributedCollectiveAllReduceStrategyTest( - CollectiveAllReduceStrategyTestBase, parameterized.TestCase): + CollectiveAllReduceStrategyTestBase, + strategy_test_lib.DistributionTestBase, + parameterized.TestCase): @classmethod def setUpClass(cls): @@ -247,6 +281,16 @@ class DistributedCollectiveAllReduceStrategyTest( cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=0) + def test_num_replicas_in_sync(self): + distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=2) + distribution.configure(cluster_spec=self._cluster_spec, task_type='worker', + task_id=0) + num_workers = len(self._cluster_spec.get('chief', []) + + self._cluster_spec.get('worker', [])) + self.assertEqual(2 * num_workers, + distribution.num_replicas_in_sync) + @combinations.generate( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) def testMinimizeLossGraph(self, num_gpus): @@ -257,7 +301,7 @@ class DistributedCollectiveAllReduceStrategyTest( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) def testVariableInitialization(self, num_gpus): if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._run_between_graph_clients( self._test_variable_initialization, self._cluster_spec, @@ -267,10 +311,56 @@ class DistributedCollectiveAllReduceStrategyTest( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) def testComplexModel(self, num_gpus): if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._run_between_graph_clients( self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) + # TODO(yuefengz): Update how we use num_gpus 
and required_gpus + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testMakeInputFnIterator(self, num_gpus): + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(100) + # We use CPU as the device when num_gpus = 0 + devices_per_worker = max(1, num_gpus) + expected_values = [[i+j for j in range(devices_per_worker)] + for i in range(0, 100, devices_per_worker)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=3*devices_per_worker, + expected_num_input_pipelines=3, + expected_input_pipeline_id=1) # because task_id = 1 + self._test_input_fn_iterator('worker', 1, num_gpus, + input_fn, expected_values) + + def testUpdateConfigProto(self): + distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=2) + distribution.configure( + cluster_spec=self._cluster_spec, task_type='worker', task_id=1) + + config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) + rewrite_options = config_proto.graph_options.rewrite_options + rewrite_options.scoped_allocator_opts.enable_op.append('to_be_removed') + + new_config = distribution.update_config_proto(config_proto) + + # Verify group leader + self.assertEqual('/job:worker/replica:0/task:0', + new_config.experimental.collective_group_leader) + + # Verify device filters. + self.assertEqual(['/job:worker/task:1'], new_config.device_filters) + + # Verify rewrite options. 
+ new_rewrite_options = new_config.graph_options.rewrite_options + self.assertEqual(rewriter_config_pb2.RewriterConfig.ON, + new_rewrite_options.scoped_allocator_optimization) + self.assertEqual(['CollectiveReduce'], + new_rewrite_options.scoped_allocator_opts.enable_op) + class DistributedCollectiveAllReduceStrategyTestWithChief( CollectiveAllReduceStrategyTestBase, parameterized.TestCase): @@ -281,10 +371,6 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=0, has_chief=True) - def setUp(self): - super(DistributedCollectiveAllReduceStrategyTestWithChief, self).setUp() - self._run_options.experimental.collective_graph_key = 7 - @combinations.generate( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) def testMinimizeLossGraph(self, num_gpus): @@ -310,21 +396,37 @@ class DistributedCollectiveAllReduceStrategyTestWithChief( self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) -class LocalCollectiveAllReduceStrategy( - CollectiveAllReduceStrategyTestBase, parameterized.TestCase): +class LocalCollectiveAllReduceStrategy(CollectiveAllReduceStrategyTestBase, + strategy_test_lib.DistributionTestBase, + parameterized.TestCase): def testMinimizeLossGraph(self, num_gpus=2): # Collective ops doesn't support strategy with one device. if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._test_minimize_loss_graph(None, None, num_gpus) def testComplexModel(self, num_gpus=2): # Collective ops doesn't support strategy with one device. if context.num_gpus() < num_gpus: - return + self.skipTest('Not enough GPUs') self._test_complex_model(None, None, num_gpus) + def testMakeInputFnIterator(self, num_gpus=2): + # Collective ops doesn't support strategy with one device. 
+ if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(10) + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_gpus, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + self._test_input_fn_iterator(None, None, num_gpus, + input_fn, expected_values) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index cff4b0a463e43b1e63fa9e9e96a9df6ee193b506..365ce5cdec79f1914f0c9ccdf59a7dc59e6f819e 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -53,11 +53,11 @@ from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2 from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 +from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.training import adagrad from tensorflow.python.training import adam -from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import gradient_descent from tensorflow.python.training import rmsprop from tensorflow.python.util import tf_inspect @@ -168,6 +168,8 @@ def _augment_with_special_arguments(test_method): if GPU_TEST: self.skipTest("Test that doesn't require GPUs.") elif context.num_gpus() < required_gpus: + # TODO(priyag): Consider allowing tests in graph mode using soft + # placement. self.skipTest( "{} GPUs are not available for this test. {} GPUs are available". 
format(required_gpus, context.num_gpus())) @@ -190,7 +192,7 @@ def _augment_with_special_arguments(test_method): kwargs_to_pass[arg] = kwargs[arg] if mode == "eager": - with ops.Graph().as_default(), context.eager_mode(): + with context.eager_mode(): if distribution: kwargs_to_pass["distribution"] = distribution.strategy test_method(**kwargs_to_pass) @@ -335,40 +337,58 @@ tpu_strategy_one_step = NamedDistribution( "TPUOneStep", lambda: tpu_lib.TPUStrategy( TPUClusterResolver(""), steps_per_run=1), required_tpu=True) -# Note that we disable prefetching for testing since prefetching makes -# the input non-deterministic. +mirrored_strategy_with_one_cpu = NamedDistribution( + "Mirrored1CPU", + lambda: mirrored_lib.MirroredStrategy(["/cpu:0"])) +mirrored_strategy_with_one_gpu = NamedDistribution( + "Mirrored1GPU", + lambda: mirrored_lib.MirroredStrategy(["/gpu:0"]), + required_gpus=1) mirrored_strategy_with_gpu_and_cpu = NamedDistribution( "MirroredCPUAndGPU", - lambda: mirrored_lib.MirroredStrategy( - ["/gpu:0", "/cpu:0"], prefetch_on_device=False), + lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/cpu:0"]), required_gpus=1) mirrored_strategy_with_two_gpus = NamedDistribution( "Mirrored2GPUs", - lambda: mirrored_lib.MirroredStrategy( - ["/gpu:0", "/gpu:1"], prefetch_on_device=False), + lambda: mirrored_lib.MirroredStrategy(["/gpu:0", "/gpu:1"]), + required_gpus=2) +core_mirrored_strategy_with_one_cpu = NamedDistribution( + "CoreMirrored1CPU", + lambda: mirrored_lib.CoreMirroredStrategy(["/cpu:0"])) +core_mirrored_strategy_with_one_gpu = NamedDistribution( + "CoreMirrored1GPU", + lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0"]), + required_gpus=1) +core_mirrored_strategy_with_gpu_and_cpu = NamedDistribution( + "CoreMirroredCPUAndGPU", + lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0", "/cpu:0"]), + required_gpus=1) +core_mirrored_strategy_with_two_gpus = NamedDistribution( + "CoreMirrored2GPUs", + lambda: mirrored_lib.CoreMirroredStrategy(["/gpu:0", 
"/gpu:1"]), required_gpus=2) -adam_optimizer_v1_fn = NamedObject( - "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1)) gradient_descent_optimizer_v1_fn = NamedObject( "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2)) adagrad_optimizer_v1_fn = NamedObject( "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001)) +adam_optimizer_v1_fn = NamedObject("AdamV1", + lambda: adam.AdamOptimizer(0.001, epsilon=1)) rmsprop_optimizer_v1_fn = NamedObject( "RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001)) -optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn, - adagrad_optimizer_v1_fn] -adam_optimizer_v2_fn = NamedObject( - "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1)) +optimizers_v1 = [gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn] + gradient_descent_optimizer_v2_fn = NamedObject( "GradientDescentV2", lambda: gradient_descent_v2.GradientDescentOptimizer(0.2)) adagrad_optimizer_v2_fn = NamedObject( "AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001)) -optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn, - adagrad_optimizer_v2_fn] +adam_optimizer_v2_fn = NamedObject( + "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1)) + +optimizers_v2 = [gradient_descent_optimizer_v2_fn, adagrad_optimizer_v2_fn] graph_and_eager_modes = ["graph", "eager"] @@ -377,8 +397,11 @@ def distributions_and_v1_optimizers(): """A common set of combination with DistributionStrategies and Optimizers.""" return combine( distribution=[ - one_device_strategy, mirrored_strategy_with_gpu_and_cpu, - mirrored_strategy_with_two_gpus + one_device_strategy, + mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus, + core_mirrored_strategy_with_gpu_and_cpu, + core_mirrored_strategy_with_two_gpus, ], optimizer_fn=optimizers_v1) @@ -387,7 +410,10 @@ def distributions_and_v2_optimizers(): """DistributionStrategies and V2 Optimizers.""" return combine( distribution=[ - one_device_strategy, 
mirrored_strategy_with_gpu_and_cpu, - mirrored_strategy_with_two_gpus + one_device_strategy, + mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus, + core_mirrored_strategy_with_gpu_and_cpu, + core_mirrored_strategy_with_two_gpus, ], optimizer_fn=optimizers_v2) diff --git a/tensorflow/contrib/distribute/python/cross_device_ops_test.py b/tensorflow/contrib/distribute/python/cross_device_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e9521c1c1115ffdbdcf375ad4017bacb962832 --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_device_ops_test.py @@ -0,0 +1,580 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for CrossDeviceOps.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values as value_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def _make_per_replica(values, devices, regroup=False): + devices = cross_device_ops_lib.get_devices_from(devices) + assert len(values) == len(devices) + + # We simulate the result of regroup called on PerReplica which strips the + # PerReplica wrapper if it has only one value. + if len(values) == 1 and regroup: + with ops.device(devices[0]): + placed_v = array_ops.identity(values[0]) + return placed_v + + index = {} + for d, v in zip(devices, values): + with ops.device(d): + placed_v = array_ops.identity(v) + index[d] = placed_v + return value_lib.PerReplica(index) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def _fake_mirrored(value, devices): + """Create a faked Mirrored object for testing. + + All components of the returned Mirrored have the same objects, which is not + true in reality. 
+ """ + devices = cross_device_ops_lib.get_devices_from(devices) + return value_lib.Mirrored( + {d: v for d, v in zip(devices, [value] * len(devices))}) + + +def _make_indexed_slices(values, indices, dense_shape, device): + with ops.device(device): + tensor = ops.IndexedSlices( + values=constant_op.constant(values), + indices=constant_op.constant(indices), + dense_shape=constant_op.constant(dense_shape)) + return tensor + + +def _make_mirrored_indexed_slices(devices, values, indices, dense_shape): + return value_lib.Mirrored({ + d: _make_indexed_slices(values, indices, dense_shape, d) for d in devices + }) + + +_cpu_device = "/device:CPU:0" + + +class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase): + + def _assert_indexed_slices_equal(self, left, right): + self.assertIsInstance(left, ops.IndexedSlices) + self.assertIsInstance(right, ops.IndexedSlices) + self.assertEqual(device_util.resolve(left.device), + device_util.resolve(right.device)) + self.assertAllEqual( + self.evaluate(ops.convert_to_tensor(left)), + self.evaluate(ops.convert_to_tensor(right))) + + def _assert_values_equal(self, left, right): + if isinstance(left, list): + for l, r in zip(left, right): + self._assert_values_equal(l, r) + else: + self.assertEqual(type(left), type(right)) + self.assertEqual(set(left.devices), set(right.devices)) + if isinstance(list(left._index.values())[0], ops.IndexedSlices): + for (d, v) in left._index.items(): + self._assert_indexed_slices_equal(v, right._index[d]) + elif context.executing_eagerly(): + self.assertEqual([v.numpy() for v in left._index.values()], + list(right._index.values())) + else: + with self.cached_session() as sess: + self.assertEqual( + sess.run(list(left._index.values())), list(right._index.values())) + + def _testReductionAndBroadcast(self, cross_device_ops, distribution): + devices = distribution.extended.worker_devices + + values = [constant_op.constant(float(d)) for d in range(len(devices))] + per_replica = 
_make_per_replica(values, devices) + mean = (len(devices) - 1.) / 2. + + values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))] + per_replica_2 = _make_per_replica(values_2, devices) + mean_2 = mean + 1. + + destination_mirrored = _fake_mirrored(1., devices) + destination_different = _fake_mirrored(1., _cpu_device) + destination_str = _cpu_device + + all_destinations = [ + destination_mirrored, destination_different, destination_str, + ] + + # test reduce() + for destinations in all_destinations: + self._assert_values_equal( + cross_device_ops.reduce( + reduce_util.ReduceOp.MEAN, + per_replica, + destinations=destinations), + _fake_mirrored(mean, destinations)) + self._assert_values_equal( + cross_device_ops.reduce( + reduce_util.ReduceOp.MEAN, + per_replica_2, + destinations=destinations), + _fake_mirrored(mean_2, destinations)) + self._assert_values_equal( + cross_device_ops.reduce( + reduce_util.ReduceOp.SUM, per_replica, + destinations=destinations), + _fake_mirrored(mean * len(devices), destinations)) + self._assert_values_equal( + cross_device_ops.reduce( + reduce_util.ReduceOp.SUM, + per_replica_2, + destinations=destinations), + _fake_mirrored(mean_2 * len(devices), destinations)) + + # test batch_reduce() + for d1, d2 in itertools.product(all_destinations, all_destinations): + self._assert_values_equal( + cross_device_ops.batch_reduce( + reduce_util.ReduceOp.MEAN, + [(per_replica, d1), (per_replica_2, d2)]), + [ + _fake_mirrored(mean, d1), + _fake_mirrored(mean_2, d2) + ]) + self._assert_values_equal( + cross_device_ops.batch_reduce( + reduce_util.ReduceOp.SUM, + [(per_replica, d1), (per_replica_2, d2)]), + [ + _fake_mirrored(mean * len(devices), d1), + _fake_mirrored(mean_2 * len(devices), d2) + ]) + + # test broadcast() + for destinations in all_destinations: + self._assert_values_equal( + cross_device_ops.broadcast(constant_op.constant(1.), destinations), + _fake_mirrored(1., destinations)) + + +class 
SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase): + # TODO(yuefengz): decouple the num_gpus check from distribution in + # combinations module so that we can pass in devices instead of a distribution + # strategy. + reduction_to_one_combinations = combinations.combine( + cross_device_ops=[ + combinations.NamedObject( + "DefaultReductionToOneDeviceCrossDeviceOps", + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()), + combinations.NamedObject( + "ReductionToCPUDeviceCrossDeviceOps", + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( + reduce_to_device=_cpu_device)), + combinations.NamedObject( + "AccumulateNCrossDeviceOp", + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( + accumulation_fn=math_ops.accumulate_n)), + ], + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus + ], + mode=["graph", "eager"]) + allreduce_combinations = combinations.combine( + cross_device_ops=[ + combinations.NamedObject( + "AllReduce", + cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 1, 0, 0)), + combinations.NamedObject( + "HierarchicalCopy", + cross_device_ops_lib.AllReduceCrossDeviceOps( + "hierarchical_copy", 8, 0, 0)), + combinations.NamedObject( + "AllReduceNoGradientRepacking", + cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 0, 0, 0)), + combinations.NamedObject( + "HierarchicalCopyAggregateSmallTensors", + cross_device_ops_lib.AllReduceCrossDeviceOps( + "hierarchical_copy", 0, 100, 10)) + ], + distribution=[combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=["graph", "eager"]) + + @combinations.generate(reduction_to_one_combinations + allreduce_combinations) + def testReductionAndBroadcast(self, cross_device_ops, distribution): + with distribution.scope(): + 
self._testReductionAndBroadcast(cross_device_ops, distribution) + + def testChooseAlgorithm(self): + device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], + [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) + self.assertEqual(result._all_reduce_alg, "hierarchical_copy") + self.assertEqual(result._num_packs, 8) + + # if there are only 4 devices + device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]] + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) + self.assertEqual(result._all_reduce_alg, "nccl") + self.assertEqual(result._num_packs, 1) + + # if devices links contain each device itself + device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6], + [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7], + [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]] + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) + self.assertEqual(result._all_reduce_alg, "hierarchical_copy") + self.assertEqual(result._num_packs, 8) + + # if not dgx1-like links + device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], + [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]] + result = cross_device_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertIsInstance(result, cross_device_ops_lib.AllReduceCrossDeviceOps) + self.assertEqual(result._all_reduce_alg, "nccl") + self.assertEqual(result._num_packs, 1) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testSimpleReduceWithIndexedSlices(self): + devices = ["/cpu:0", "/gpu:0"] + t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) + t1 = _make_indexed_slices([[3., 4.], [5., 
6.]], [1, 3], [5, 2], devices[1]) + per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) + result = cross_device_ops_lib._simple_reduce( + per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM) + + # Test that the result is semantically equal to both the concatenated + # IndexedSlices with and without duplicate indices. + total_with_dups = _make_indexed_slices( + [[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2], devices[0]) + total_without_dups = _make_indexed_slices( + [[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0]) + self._assert_indexed_slices_equal(total_with_dups, result) + self._assert_indexed_slices_equal(total_without_dups, result) + + @combinations.generate( + combinations.combine( + cross_device_ops_instance=[ + combinations.NamedObject( + "ReductionToOneDeviceCrossDeviceOps", + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps()), + combinations.NamedObject( + "AllReduceCrossDeviceOps", + cross_device_ops_lib.AllReduceCrossDeviceOps()) + ], + reduce_op=[reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN], + batch_reduce=[True, False], + mode=["graph", "eager"], + required_gpus=1)) + def testIndexedSlicesAllReduce(self, cross_device_ops_instance, reduce_op, + batch_reduce): + devices = ["/cpu:0", "/gpu:0"] + dense_shape = [5, 2] + t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) + t1 = _make_indexed_slices( + [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1]) + per_replica = value_lib.PerReplica({devices[0]: t0, devices[1]: t1}) + + if batch_reduce: + result = cross_device_ops_instance.batch_reduce( + reduce_op, [(per_replica, per_replica)]) + else: + result = cross_device_ops_instance.reduce( + reduce_op, per_replica, per_replica) + + total_indices_with_dups = [1, 1, 3] + total_indices_without_dups = [1, 3] + + if reduce_op == reduce_util.ReduceOp.SUM: + total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]] + total_values_without_dups = [[4., 6.], [5., 6.]] + else: + assert reduce_op == 
reduce_util.ReduceOp.MEAN + total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]] + total_values_without_dups = [[2., 3.], [2.5, 3.]] + + total_mirrored_with_dups = _make_mirrored_indexed_slices( + devices, total_values_with_dups, total_indices_with_dups, dense_shape) + total_mirrored_without_dups = _make_mirrored_indexed_slices( + devices, total_values_without_dups, total_indices_without_dups, + dense_shape) + + # Test that the result is semantically equal to both the concatenated + # IndexedSlices, as well as when the duplicate indices are summed up. + if batch_reduce: + total_mirrored_with_dups = [total_mirrored_with_dups] + total_mirrored_without_dups = [total_mirrored_without_dups] + + self._assert_values_equal(total_mirrored_with_dups, result) + self._assert_values_equal(total_mirrored_without_dups, result) + + +class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase, + CrossDeviceOpsTestBase): + + worker_devices = [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + multi_worker_allreduce_combinations = combinations.combine( + cross_device_ops=[ + combinations.NamedObject( + "MultiWorkerAllReduce", + cross_device_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)), + combinations.NamedObject( + "MultiWorkerAllReducePack", + cross_device_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)), + combinations.NamedObject( + "MultiWorkerAllReduceAggregation", + cross_device_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)), + combinations.NamedObject( + "MultiWorkerAllReduceMultipleSpecs", + cross_device_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, [("pscpu/pscpu", 2, 100), + ("xring", 2, -1)], 0, 0, 0)), + ], + distribution=[ + combinations.NamedDistribution( + "MirroredCPU", + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=0), + required_gpus=0), + combinations.NamedDistribution( + 
"Mirrored1GPU", + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=1), + required_gpus=1), + combinations.NamedDistribution( + "Mirrored2GPUs", + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=2), + required_gpus=2), + # pylint: disable=g-long-lambda + combinations.NamedDistribution( + "CoreMirroredCPU", + lambda: mirrored_strategy.CoreMirroredStrategy(["/device:CPU:0"]), + required_gpus=0), + combinations.NamedDistribution( + "CoreMirrored1GPU", + lambda: mirrored_strategy.CoreMirroredStrategy(["/device:GPU:0"]), + required_gpus=1), + combinations.NamedDistribution( + "CoreMirrored2GPUs", + lambda: mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]), + required_gpus=2), + ], + mode=["graph"]) + + @combinations.generate(multi_worker_allreduce_combinations) + def testReductionAndBroadcast(self, cross_device_ops, distribution): + distribution.configure(cluster_spec={ + "worker": + ["/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"] + }) + with distribution.scope(): + self._testReductionAndBroadcast(cross_device_ops, distribution) + + +class MultiWorkerCollectiveAllReduceTest( + multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): + + collective_key_base = 100000 + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0) + + def setUp(self): + super(MultiWorkerCollectiveAllReduceTest, self).setUp() + # Reusing keys are not supported well. So we have to give a different + # collective key base for different tests. 
+ MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000 + + def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False): + collective_keys = cross_device_utils.CollectiveKeys( + group_key_start=10 * num_gpus + + MultiWorkerCollectiveAllReduceTest.collective_key_base, + instance_key_start=num_gpus * 100 + + MultiWorkerCollectiveAllReduceTest.collective_key_base, + instance_key_with_id_start=num_gpus * 10000 + + MultiWorkerCollectiveAllReduceTest.collective_key_base) + if local_mode: + collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( + 1, num_gpus, collective_keys=collective_keys) + if num_gpus: + devices = ["/device:GPU:%d" % i for i in range(num_gpus)] + else: + devices = ["/device:CPU:0"] + return collective_all_reduce_ops, devices, "" + else: + collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce( + 3, num_gpus, collective_keys=collective_keys) + if num_gpus: + devices = [ + "/job:%s/task:%d/device:GPU:%d" % (task_type, task_id, i) + for i in range(num_gpus) + ] + else: + devices = ["/job:%s/task:%d" % (task_type, task_id)] + return (collective_all_reduce_ops, devices, + "grpc://" + self._cluster_spec[task_type][task_id]) + + def _assert_values_equal(self, left, right, sess): + if isinstance(left, list): + for l, r in zip(left, right): + self._assert_values_equal(l, r, sess) + else: + self.assertEqual(type(left), type(right)) + self.assertEqual(set(left.devices), set(right.devices)) + + run_options = config_pb2.RunOptions() + run_options.experimental.collective_graph_key = 6 + + left_values = np.array( + sess.run(list(left._index.values()), options=run_options)).flatten() + right_values = np.array(list(right._index.values())).flatten() + self.assertEqual(len(left_values), len(right_values)) + for l, r in zip(left_values, right_values): + self.assertEqual(l, r) + + def _test_reduction(self, task_type, task_id, num_gpus, local_mode=False): + collective_all_reduce, devices, master_target = 
self._get_test_objects( + task_type, task_id, num_gpus, local_mode=local_mode) + if local_mode: + num_workers = 1 + worker_device = None + else: + num_workers = len(self._cluster_spec.get("chief", [])) + len( + self._cluster_spec.get("worker", [])) + worker_device = "/job:%s/task:%d" % (task_type, task_id) + with ops.Graph().as_default(), \ + ops.device(worker_device), \ + self.cached_session(target=master_target) as sess: + # Collective ops doesn't support scalar tensors, so we have to construct + # 1-d tensors. + values = [constant_op.constant([float(d)]) for d in range(len(devices))] + per_replica = _make_per_replica(values, devices, regroup=True) + mean = np.array([(len(devices) - 1.) / 2.]) + + values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))] + per_replica_2 = _make_per_replica(values_2, devices) + mean_2 = np.array([mean[0] + 1.]) + + destination_mirrored = _fake_mirrored(1., devices) + destination_different = _fake_mirrored(1., _cpu_device) + destination_str = _cpu_device + + all_destinations = [ + destination_different, destination_mirrored, destination_str + ] + + # test reduce() + for destinations in all_destinations: + self._assert_values_equal( + collective_all_reduce.reduce( + reduce_util.ReduceOp.MEAN, + per_replica, + destinations=destinations), + _fake_mirrored(mean, destinations), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + reduce_util.ReduceOp.MEAN, + per_replica_2, + destinations=destinations), + _fake_mirrored(mean_2, destinations), sess) + self._assert_values_equal( + collective_all_reduce.reduce( + reduce_util.ReduceOp.SUM, + per_replica, + destinations=destinations), + _fake_mirrored(mean * len(devices) * num_workers, destinations), + sess) + self._assert_values_equal( + collective_all_reduce.reduce( + reduce_util.ReduceOp.SUM, + per_replica_2, + destinations=destinations), + _fake_mirrored(mean_2 * len(devices) * num_workers, destinations), + sess) + + # test batch_reduce() + for d1, d2 in 
itertools.product(all_destinations, all_destinations): + self._assert_values_equal( + collective_all_reduce.batch_reduce(reduce_util.ReduceOp.MEAN, + [(per_replica, d1), + (per_replica_2, d2)]), + [ + _fake_mirrored(mean, d1), + _fake_mirrored(mean_2, d2) + ], sess) + self._assert_values_equal( + collective_all_reduce.batch_reduce(reduce_util.ReduceOp.SUM, + [(per_replica, d1), + (per_replica_2, d2)]), + [ + _fake_mirrored(mean * len(devices) * num_workers, d1), + _fake_mirrored(mean_2 * len(devices) * num_workers, d2) + ], sess) + + return True + + @combinations.generate( + combinations.combine(mode=["graph"], num_gpus=[0, 1, 2], required_gpus=1)) + def testReductionDistributed(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients(self._test_reduction, self._cluster_spec, + num_gpus) + + # Collective ops doesn't support strategy with one device. + def testReductionLocal(self, num_gpus=2): + if context.num_gpus() < num_gpus: + return + self._test_reduction(None, None, num_gpus, local_mode=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/cross_device_utils_test.py b/tensorflow/contrib/distribute/python/cross_device_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2303a31677afbd12a0b8e7eea3ecf7c7736c46ad --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_device_utils_test.py @@ -0,0 +1,141 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cross_device_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import values as value_lib +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops + + +class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): + + def _assert_values_equal(self, left, right): + self.assertAllEqual( + self.evaluate(ops.convert_to_tensor(left)), + self.evaluate(ops.convert_to_tensor(right))) + + @test_util.run_in_graph_and_eager_modes + def testAggregateTensors(self): + t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]]) + total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) + result = cross_device_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + self._assert_values_equal(total, result) + + @test_util.run_in_graph_and_eager_modes + def testAggregateIndexedSlices(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) + result = cross_device_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(total, result) + + 
@test_util.run_in_graph_and_eager_modes + def testDivideTensor(self): + t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + n = 2 + expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) + result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(t, n) + self._assert_values_equal(expected, result) + + @test_util.run_in_graph_and_eager_modes + def testDivideIndexedSlices(self): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + n = 2 + expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) + result = cross_device_utils.divide_by_n_tensors_or_indexed_slices(t, n) + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(expected, result) + + @test_util.run_in_graph_and_eager_modes + def testIsIndexedSlices(self): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + self.assertTrue(cross_device_utils.contains_indexed_slices(t)) + + @test_util.run_in_graph_and_eager_modes + def testContainsIndexedSlices_List(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + self.assertTrue(cross_device_utils.contains_indexed_slices([t0, t1])) + + @test_util.run_in_graph_and_eager_modes + def testContainsIndexedSlices_Tuple(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + self.assertTrue(cross_device_utils.contains_indexed_slices((t0, t1))) + + @test_util.run_in_graph_and_eager_modes + def testContainsIndexedSlices_PerReplica(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + per_replica = value_lib.PerReplica({"/gpu:0": t0, "/cpu:0": t1}) + 
self.assertTrue(cross_device_utils.contains_indexed_slices(per_replica)) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testCopyTensor(self): + with ops.device("/cpu:0"): + t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + destination = "/gpu:0" + result = cross_device_utils.copy_tensor_or_indexed_slices_to_device( + t, destination) + + self._assert_values_equal(t, result) + self.assertEqual(device_util.resolve(destination), + device_util.resolve(result.device)) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testCopyIndexedSlices(self): + with ops.device("/cpu:0"): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + destination = "/gpu:0" + result = cross_device_utils.copy_tensor_or_indexed_slices_to_device( + t, destination) + + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(t, result) + self.assertEqual(device_util.resolve(destination), + device_util.resolve(result.device)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py deleted file mode 100644 index e08ba9c2a668cd675defb025d7ad060e1338506b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ /dev/null @@ -1,959 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Classes for different algorithms of reduction and broadcasting.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import six - -from tensorflow.contrib.distribute.python import cross_tower_utils -from tensorflow.contrib.distribute.python import values as value_lib -from tensorflow.python.client import device_lib -from tensorflow.python.eager import context -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import device_util - - -def check_destinations(destinations): - """Checks whether `destinations` is not empty. - - Args: - destinations: a DistributedValues, Variable, string or a list of strings. - - Returns: - Boolean which is True if `destinations` is not empty. - """ - # Calling bool() on a ResourceVariable is not allowed. 
- if isinstance(destinations, resource_variable_ops.ResourceVariable): - return bool(destinations.device) - return bool(destinations) - - -def validate_destinations(destinations): - if not isinstance( - destinations, - (value_lib.DistributedValues, resource_variable_ops.ResourceVariable, - value_lib.AggregatingVariable, six.string_types, list)): - raise ValueError("destinations must be one of a `DistributedValues` object," - " a tf.Variable object, a device string, a list of device " - "strings") - - if not check_destinations(destinations): - raise ValueError("destinations can not be empty") - - -def _make_tensor_into_per_device(input_tensor): - """Converts a single tensor into a PerDevice object.""" - if isinstance(input_tensor, (tuple, list)): - raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object, " - "got %r but expected a object that is not a tuple or list." - % (input_tensor,)) - if isinstance(input_tensor, value_lib.PerDevice): - return input_tensor - - try: - device = input_tensor.device - except AttributeError: - raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object " - "because it doesn't have device set.") - - return value_lib.PerDevice({device: input_tensor}) - - -def _normalize_value_destination_pairs(value_destination_pairs): - """Converts each tensor into a PerDevice object in the input list.""" - result = [] - if not isinstance(value_destination_pairs, (list, tuple)): - raise ValueError("`value_destination_pairs` should be a list or tuple") - for pair in value_destination_pairs: - if not isinstance(pair, tuple): - raise ValueError( - "Each element of `value_destination_pairs` should be a tuple.") - if len(pair) != 2: - raise ValueError("Each element of `value_destination_pairs` should be a " - "tuple of size 2.") - - per_device = _make_tensor_into_per_device(pair[0]) - result.append((per_device, pair[1])) - return result - - -def _validate_value_destination_pairs(value_destination_pairs): - # TODO(yuefengz): 
raise exceptions instead of returning False. - # pylint: disable=g-missing-docstring - if not value_destination_pairs: return False - if not isinstance(value_destination_pairs, (list, tuple)): return False - if not all([isinstance(pair, tuple) for pair in value_destination_pairs]): - return False - if not all([isinstance(v[0], value_lib.PerDevice) - for v in value_destination_pairs]): - return False - return True - - -# TODO(yuefengz): consider calling this function in the caller of CrossTowerOps. -def get_devices_from(destinations): - if isinstance(destinations, value_lib.DistributedValues): - return list(destinations.devices) - elif isinstance(destinations, (resource_variable_ops.ResourceVariable, - value_lib.AggregatingVariable)): - return [destinations.device] - elif isinstance(destinations, six.string_types): - return [device_util.resolve(destinations)] - elif isinstance(destinations, (list, tuple)): - return [device_util.resolve(destination) for destination in destinations] - else: - return [destinations.device] - - -def _devices_match(left, right): - return set(get_devices_from(left)) == set(get_devices_from(right)) - - -def _all_devices_match(value_destination_pairs): - if not all([_devices_match(v, d) for v, d in value_destination_pairs]): - return False - if not all([_devices_match(v, value_destination_pairs[0][0]) - for v, _ in value_destination_pairs[1:]]): - return False - return True - - -def _simple_broadcast(value, destinations): - index = {} - devices = get_devices_from(destinations) - for d in devices: - index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( - value, d) - return value_lib.Mirrored(index) - - -def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, - aggregation): - # pylint: disable=g-missing-docstring - all_values = [] - count = 0 - for v in per_device_value._index.values(): # pylint: disable=protected-access - if isinstance(v, value_lib.MapOutput): - v_list = v.get() - if not v_list: - continue - 
count += len(v_list) - # Sum within each device before aggregating across devices. - # TODO(yuefengz): Check whether it helps to use accumulation_fn here. - v = cross_tower_utils.aggregate_tensors_or_indexed_slices( - v_list, math_ops.add_n) - else: - count += 1 - all_values.append(v) - if not all_values: - raise ValueError("`per_device_value` must be non-empty") - - with ops.device(reduce_to_device): - with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices( - all_values, accumulation_fn) - if aggregation == vs.VariableAggregation.MEAN: - reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices( - reduced, count) - elif aggregation != vs.VariableAggregation.SUM: - raise ValueError("`aggregation` must be VariableAggregation.SUM " - "or VariableAggregation.MEAN.") - return reduced - - -class CrossTowerOps(object): - """Base class for cross-tower reduction and broadcasting algorithms.""" - - def __init__(self): - pass - - def reduce(self, aggregation, per_device_value, destinations): - """Reduce `per_device_value` to `destinations`. - - It runs the reduction operation defined by `aggregation` and put the - result on `destinations`. - - Args: - aggregation: Indicates how a variable will be aggregated. Accepted values - are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. - per_device_value: a PerDevice object or a tensor with device set. - destinations: the reduction destinations. - - Returns: - a Mirrored object. - - Raises: - ValueError: if per_device_value is not a PerDevice object. - """ - if not isinstance(per_device_value, value_lib.PerDevice): - per_device_value = _make_tensor_into_per_device(per_device_value) - - validate_destinations(destinations) - return self._reduce(aggregation, per_device_value, destinations) - - def batch_reduce(self, aggregation, value_destination_pairs): - """Reduce PerDevice objects in a batch. 
- - Reduce each first element in `value_destination_pairs` to each second - element which indicates the destinations. - - Args: - aggregation: Indicates how a variable will be aggregated. Accepted values - are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. - value_destination_pairs: a list or a tuple of tuples of PerDevice objects - (or tensors with device set if there is one tower) and destinations. - - Returns: - a list of Mirrored objects. - - Raises: - ValueError: if `value_destination_pairs` is not a list or a tuple of - tuples of PerDevice objects and destinations - """ - if not _validate_value_destination_pairs(value_destination_pairs): - # If the first element of each pair is a tensor, we try to turn it into a - # PerDevice object. - value_destination_pairs = _normalize_value_destination_pairs( - value_destination_pairs) - - for _, d in value_destination_pairs: - validate_destinations(d) - - return self._batch_reduce(aggregation, value_destination_pairs) - - def broadcast(self, tensor, destinations): - """Broadcast the `tensor` to destinations. - - Args: - tensor: the tensor to broadcast. - destinations: the broadcast destinations. - - Returns: - a Mirrored object. - """ - validate_destinations(destinations) - return self._broadcast(tensor, destinations) - - def _reduce(self, aggregation, per_device_value, destinations): - raise NotImplementedError( - "_reduce method must be implemented in descendants.") - - def _batch_reduce(self, aggregation, value_destination_pairs): - raise NotImplementedError( - "_batch_reduce method must be implemented in descendants.") - - def _broadcast(self, tensor, destinations): - return _simple_broadcast(tensor, destinations) - - -class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): - """Always do reduction to one device first and then do broadcasting. - - Batch reduction is done by reduction on each element one by one. 
- """ - - def __init__(self, reduce_to_device=None, accumulation_fn=math_ops.add_n): - """Constructor. - - Args: - reduce_to_device: the intermediate device to reduce to. If None, reduce - to the first device in `destinations` of the reduce() method. - accumulation_fn: a function that does accumulation. - """ - self.reduce_to_device = reduce_to_device - self.accumulation_fn = accumulation_fn - super(ReductionToOneDeviceCrossTowerOps, self).__init__() - - def _reduce(self, aggregation, per_device_value, destinations): - if check_destinations(destinations): - devices = get_devices_from(destinations) - else: - devices = get_devices_from(per_device_value) - reduce_to_device = self.reduce_to_device or devices[0] - reduced = _simple_reduce(per_device_value, reduce_to_device, - self.accumulation_fn, aggregation) - return self.broadcast(reduced, devices) - - def _batch_reduce(self, aggregation, value_destination_pairs): - return [ - self._reduce(aggregation, t, destinations=v) - for t, v in value_destination_pairs - ] - - -def _group_value_by_device(per_device_values): - """Group values into sublists by their devices. - - This grouping is needed to call the all-reduce library because it expects a - list of the following form: - [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...], - [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...], - [(grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ...], - ... - ] - - Args: - per_device_values: a list of PerDevice obejcts. - - Returns: - a list of lists, each sublist has components for its corresponding device of - PerDevice objects, paired with a None. 
- """ - destinations = per_device_values[0].devices - grouped = [[] for _ in range(len(destinations))] - for per_device_value in per_device_values: - # pylint: disable=protected-access - for i, v in enumerate(per_device_value._index.values()): - assert per_device_value.devices == destinations - grouped[i].append((v, None)) - return grouped - - -def _ungroup_and_make_mirrored(grouped_reduced, - destinations, - aggregation, - num_between_graph_workers=1): - """Ungroup results from all-reduce and make Mirrored objects. - - Each all-reduce result will be divided by the number of destinations before - Mirrored objects are created if aggregation is "mean". - - Args: - grouped_reduced: a list of lists, each sublist has components for each - device, paired with a None. It is the result from - cross_tower_utils.aggregate_gradients_using*. - destinations: a list of device strings for returned Mirrored objects. - aggregation: Indicates how a variable will be aggregated. Accepted values - are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. - num_between_graph_workers: number of workers in the between-graph - replication. - - Returns: - a list of Mirrored objects. - """ - index = [{} for _ in range(len(grouped_reduced[0]))] - for d, per_device_reduced in enumerate(grouped_reduced): - for i, (v, _) in enumerate(per_device_reduced): - if aggregation == vs.VariableAggregation.MEAN: - index[i][destinations[d]] = v / ( - len(destinations) * num_between_graph_workers) - else: - index[i][destinations[d]] = v - return [value_lib.Mirrored(v) for v in index] - - -class ConcatAndSplitPacker(object): - """Concatenate and split tensors for reduction.""" - - def __init__(self, num_packs=1): - """Initialize the ConcatAndSplitPacker object. - - Args: - num_packs: specifies the number of split packs that will be - formed. - - Raises: - ValueError: if num_packs is not greater than 0. 
- """ - if num_packs <= 0: - raise ValueError("num_packs must be greater than zero.") - self.num_packs = num_packs - - def pack(self, grouped_grads_and_vars): - """Pack tensors.""" - self.grouped_grads_and_vars = grouped_grads_and_vars - self.all_tower_shapes = [] - self.all_tower_sizes = [] - - device_grad_packs = [] - for tower_grads_and_vars in grouped_grads_and_vars: - with ops.colocate_with(tower_grads_and_vars[0][0]): - # Flatten all the grads. - flat_grads = [ - array_ops.reshape(g, [-1]) for g, _ in tower_grads_and_vars - ] - # Remember the original shape of all the grads. - tower_shapes = [array_ops.shape(g) for g, _ in tower_grads_and_vars] - # Remember the original sizes of all the grads. - tower_sizes = [array_ops.size(g) for g, _ in tower_grads_and_vars] - # Concat all the flat grads into a big flat tensor. - concat_grads = array_ops.concat(flat_grads, 0) - - # Split the big tensor into num_splits packs. In cases where the - # total size is not divisible num_splits, the last pack gets - # more elements. - # TODO(zhengxq): it is also possible to optimize away all the concat - # as well. - num_splits = self.num_packs - - # The array_ops.size function will sometimes remove static shapes. So if - # all gradient shapes are defined, we use another method to get the - # total size. - # TODO(yuefengz): move this logic to array_ops.size. - if all([g.shape.is_fully_defined() for g, _ in tower_grads_and_vars]): - total_grad_size = sum( - [g.shape.num_elements() for g, _ in tower_grads_and_vars]) - else: - total_grad_size = array_ops.size(concat_grads) - - split_size = total_grad_size // num_splits - split_size_last = total_grad_size - split_size * (num_splits - 1) - split_sizes = [split_size] * (num_splits - 1) + [split_size_last] - grad_packs = array_ops.split(concat_grads, split_sizes) - - # Ready to aggregate the repacked gradients, with fake variables. - # TODO(zhengxq): It is hacky to have to use fake variables. 
- # We should remove the need for variables in - # aggregate_gradients_using*. - device_grad_packs.append(zip(grad_packs, [None] * num_splits)) - self.all_tower_shapes.append(tower_shapes) - self.all_tower_sizes.append(tower_sizes) - - return device_grad_packs - - def unpack(self, summed_device_grad_packs): - """Reverse the pack.""" - aggregated_device_grads = [] - for (summed_tower_grad_packs, - tower_grads_and_vars, tower_shapes, tower_sizes) in zip( - summed_device_grad_packs, self.grouped_grads_and_vars, - self.all_tower_shapes, self.all_tower_sizes): - # pylint: enable=line-too-long - # Reverse the packing operations in the previous steps. Form the - # summed gradients back into their original shapes. - with ops.colocate_with(summed_tower_grad_packs[0][0]): - # Form a list of the summed grad packs. - device_grad_packs = [g for g, _ in summed_tower_grad_packs] - - # Concat them back into a big flat tensor. - device_grads_concat = array_ops.concat(device_grad_packs, 0) - - # Split the tensors back into their original sizes. - grads_with_sizes = array_ops.split(device_grads_concat, tower_sizes) - - # Reshape the tensors back into their original shapes. - grads_with_shapes = [ - array_ops.reshape(grad, shape) - for shape, grad in zip(tower_shapes, grads_with_sizes) - ] - - # Form the list with the original list of variables. - summed_tower_grads = [ - (g, v) for g, (_, v) in zip(grads_with_shapes, tower_grads_and_vars) - ] - aggregated_device_grads.append(summed_tower_grads) - return aggregated_device_grads - - -class AggregateSmallTensorPacker(object): - """Concatenate small gradient tensors together for reduction.""" - - def __init__(self, - agg_small_grads_max_bytes=1048576, - agg_small_grads_max_group=16): - """Initialize the AggregateSmallTensorPacker object. - - Args: - agg_small_grads_max_bytes: largest tensor eligible for aggregation, - in number of bytes. - agg_small_grads_max_group: largest permitted aggregation of small - tensors. 
- - Raises: - ValueError: if `agg_small_grads_max_bytes` or `agg_small_grads_max_group` - is not greater than 0. - """ - if agg_small_grads_max_bytes <= 0 or agg_small_grads_max_group <= 0: - raise ValueError("agg_small_grads_max_bytes and agg_small_grads_max_group" - " should both be greater than zero.") - self.agg_small_grads_max_bytes = agg_small_grads_max_bytes - self.agg_small_grads_max_group = agg_small_grads_max_group - - def pack(self, grouped_grads_and_vars): - """Aggregate small tensors.""" - if (self.agg_small_grads_max_bytes > 0 and - self.agg_small_grads_max_group > 0): - tower_grads, self.packing = cross_tower_utils.pack_small_tensors( - grouped_grads_and_vars, - max_bytes=self.agg_small_grads_max_bytes, - max_group=self.agg_small_grads_max_group) - return tower_grads - - def unpack(self, summed_device_grad_packs): - """Reverse the aggregation process.""" - return cross_tower_utils.unpack_small_tensors(summed_device_grad_packs, - self.packing) - - -def _pack_tensors(device_grads, - num_packs=0, - agg_small_grads_max_bytes=0, - agg_small_grads_max_group=0): - """Pack tensors if specified.""" - if num_packs > 0: - tensor_packer = ConcatAndSplitPacker(num_packs) - device_grad_packs = tensor_packer.pack(device_grads) - elif agg_small_grads_max_bytes > 0 and agg_small_grads_max_group > 0: - tensor_packer = AggregateSmallTensorPacker(agg_small_grads_max_bytes, - agg_small_grads_max_group) - device_grad_packs = tensor_packer.pack(device_grads) - else: - tensor_packer = None - device_grad_packs = device_grads - return device_grad_packs, tensor_packer - - -def _unpack_tensors(reduced, tensor_packer=None): - """Unpack tensors if they are packed before all-reduce.""" - if tensor_packer: - return tensor_packer.unpack(reduced) - return reduced - - -class AllReduceCrossTowerOps(CrossTowerOps): - """Reduction using all reduce.""" - - def __init__(self, - all_reduce_alg="nccl", - num_packs=1, - agg_small_grads_max_bytes=0, - agg_small_grads_max_group=10): - 
"""All-reduce implementation of CrossTowerOps. - - Before performing all-reduce, tensors will be repacked or aggregated for - more efficient cross-device transportation: - 1) If `num_packs` is non-zero, pack values into - `num_packs` splits. - 2) Otherwise, if `agg_small_grads_max_bytes` > 0 and - `agg_small_grads_max_group` > 0, aggregate values smaller than - `agg_small_grads_max_bytes` into groups with at most - `agg_small_grads_max_group` values. - 3) Otherwise, no repacking or grouping will happen. - - Args: - all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or - "hierarchical_copy" are supported. - num_packs: see above. - agg_small_grads_max_bytes: see above. - agg_small_grads_max_group: see above. - tensors. - """ - self._all_reduce_alg = all_reduce_alg - self._num_packs = num_packs - self._agg_small_grads_max_bytes = agg_small_grads_max_bytes - self._agg_small_grads_max_group = agg_small_grads_max_group - super(AllReduceCrossTowerOps, self).__init__() - - def _reduce(self, aggregation, per_device_value, destinations): - contains_indexed_slices = cross_tower_utils.contains_indexed_slices( - per_device_value) - if (_devices_match(per_device_value, destinations) - and not context.executing_eagerly() - and not contains_indexed_slices): - return self._batch_all_reduce(aggregation, [per_device_value])[0] - else: - if contains_indexed_slices: - logging.log_first_n( - logging.WARN, - "Efficient allreduce is not supported for IndexedSlices.", 10) - - if check_destinations(destinations): - devices = get_devices_from(destinations) - else: - devices = get_devices_from(per_device_value) - reduce_to_device = devices[0] - reduced = _simple_reduce(per_device_value, reduce_to_device, - math_ops.add_n, aggregation) - return self.broadcast(reduced, devices) - - def _batch_reduce(self, aggregation, value_destination_pairs): - all_devices_match = _all_devices_match(value_destination_pairs) - contains_indexed_slices = 
cross_tower_utils.contains_indexed_slices( - value_destination_pairs) - if (all_devices_match and not context.executing_eagerly() - and not contains_indexed_slices): - return self._batch_all_reduce(aggregation, - [v[0] for v in value_destination_pairs]) - else: - if not all_devices_match: - logging.log_first_n(logging.WARN, - "Efficient batch_reduce is not supported if " - "destinations are different.", - 10) - - return [ - self._reduce(aggregation, t, destinations=v) - for t, v in value_destination_pairs - ] - - def _batch_all_reduce(self, aggregation, per_device_values): - """All reduce algorithm in a batch.""" - logging.log_first_n( - logging.INFO, "batch_all_reduce invoked for batches size = %d with " - "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and " - "agg_small_grads_max_group = %d" % - (len(per_device_values), self._all_reduce_alg, self._num_packs, - self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) - destinations = per_device_values[0].devices - grouped = _group_value_by_device(per_device_values) - - device_grad_packs, tensor_packer = _pack_tensors( - grouped, self._num_packs, self._agg_small_grads_max_bytes, - self._agg_small_grads_max_group) - - # The actual aggregation of the repacked gradients. Note that they are - # sharded among different aggregation trees. So it is important to strike - # the balance on num_splits. - if self._all_reduce_alg == "nccl": - # TODO(yuefengz): merge this into the all-reduce library. - reduced = cross_tower_utils.aggregate_gradients_using_nccl( - device_grad_packs) - else: - # TODO(yuefengz): check that gpu ids in `destinations` are in ascending - # order. 
- reduced = ( - cross_tower_utils.aggregate_gradients_using_hierarchical_copy( - destinations, device_grad_packs)) - - reduced = _unpack_tensors(reduced, tensor_packer) - return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices, - aggregation) - - -AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple", - "alg shards limit") - - -class MultiWorkerAllReduce(AllReduceCrossTowerOps): - """All-reduce algorithms for distributed TensorFlow.""" - - def __init__(self, - worker_devices, - num_gpus_per_worker, - all_reduce_spec=("pscpu/pscpu", 2, -1), - num_packs=0, - agg_small_grads_max_bytes=0, - agg_small_grads_max_group=10): - """Initialize the all-reduce algorithm. - - Args: - worker_devices: a list of device strings for workers participating in - all-reduce. - num_gpus_per_worker: number of GPU devices per worker. - all_reduce_spec: a tuple or a named tuple or a list of tuples specifying - the all-reduce algorithm. - 1. The first element of a tuple is the name of the all-reduce algorithm. - Valid algorithm names are: "nccl", "nccl/xring", "nccl/rechd", - "nccl/pscpu", "xring", "pscpu", "psgpu", "pscpu/pscpu". Algorithms with - a "/" are hierarchical, so two all-reduces are executed, the first one - aggregates tensors within a worker and the second aggregates across - workers. - 2. The second element of a tuple is the number of shards when doing - all-reduce. Let's say its values is M, each tensor after packing will be - split into M shards and then M parallel all-reduces would be performed - before finally they are concatenated backed into a complete tensor. - 3. The third element is the maximum size of tensors that will be - applicable for the algorithm specified by the first element. For - example, if all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)], - tensors with size not larger than 1024 bytes will be applied a 2-shard - "nccl" all-reduce and other tensors will be applied a 2-shard - "pscpu/pscpu" algorithm. 
The third elements should be in increasing - order across tuples and end with -1 which indicates infinity. - num_packs: see AllReduceCrossTowerOps. - agg_small_grads_max_bytes: see AllReduceCrossTowerOps. - agg_small_grads_max_group: see AllReduceCrossTowerOps. - """ - self._worker_devices = worker_devices - self._num_gpus_per_worker = num_gpus_per_worker - super(MultiWorkerAllReduce, self).__init__( - num_packs=num_packs, - agg_small_grads_max_bytes=agg_small_grads_max_bytes, - agg_small_grads_max_group=agg_small_grads_max_group) - - def validate_and_complete_spec(spec): - """Validate and complete the all-reduce spec.""" - # TODO(yuefengz): support namedtuple. - if not isinstance(spec, tuple): - raise ValueError( - "A tuple is expected for all-reduce spec: %r" % all_reduce_spec) - if not spec or len(spec) > 3: - raise ValueError( - "Too many elements in the all-reduce spec tuple: %r" % spec) - if len(spec) == 1: - return AllReduceSpecTuple(spec[0], 1, -1) - elif len(spec) == 2: - return AllReduceSpecTuple(spec[0], spec[1], -1) - else: - return AllReduceSpecTuple(*spec) - - self._all_reduce_spec = [] - if isinstance(all_reduce_spec, six.string_types): - self._all_reduce_spec.append(AllReduceSpecTuple(all_reduce_spec, 1, -1)) - elif isinstance(all_reduce_spec, tuple): - self._all_reduce_spec.append(validate_and_complete_spec(all_reduce_spec)) - elif isinstance(all_reduce_spec, list): - self._all_reduce_spec = [ - validate_and_complete_spec(spec) for spec in all_reduce_spec - ] - - def _batch_all_reduce(self, aggregation, per_device_values): - """All reduce algorithm in a batch.""" - logging.log_first_n( - logging.INFO, - "distributed batch_all_reduce invoked for batches size = %d with " - "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d " - "and agg_small_grads_max_group = %d" % - (len(per_device_values), self._all_reduce_spec, self._num_packs, - self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10) - - destinations = 
sorted(per_device_values[0].devices) - device_grads = _group_value_by_device(per_device_values) - - # The all reduce library requires fully defined shapes. - # TODO(yuefengz): when tensor sharding is not needed, static shapes are not - # required as well. - for device_grad in device_grads: - for grad, _ in device_grad: - if not grad.shape.is_fully_defined(): - raise ValueError("Shape is unknown for node %r" % grad) - - remaining_grads = device_grads - aggregated_grads = [] - for spec_tuple in self._all_reduce_spec: - if spec_tuple.limit < 0: - this_grads = remaining_grads - remaining_grads = [] - else: - (this_grads, remaining_grads) = cross_tower_utils.split_grads_by_size( - spec_tuple.limit, remaining_grads) - if this_grads: - device_grad_packs, tensor_packer = _pack_tensors( - this_grads, self._num_packs, self._agg_small_grads_max_bytes, - self._agg_small_grads_max_group) - range_agg_grads = cross_tower_utils.sum_gradients_all_reduce( - self._worker_devices, device_grad_packs, len(self._worker_devices), - spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker)) - range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer) - - if not aggregated_grads: - aggregated_grads = range_agg_grads - else: - assert len(aggregated_grads) == len(range_agg_grads) - for i in range(len(aggregated_grads)): - aggregated_grads[i] += range_agg_grads[i] - assert not remaining_grads - - return _ungroup_and_make_mirrored(aggregated_grads, destinations, - aggregation) - - -# TODO(yuefengz): support in-graph collective all-reduce. -class CollectiveAllReduce(CrossTowerOps): - """All-reduce cross tower ops using collective ops. - - In the between-graph replicated training, it will still do all-reduces across - all workers and then put results on the right destinations. - """ - - def __init__(self, - num_workers=1, - num_gpus_per_worker=0, - all_reduce_merge_scope=32, - collective_keys=None): - """Initializes the object. 
- - Args: - num_workers: number of workers in the between-graph replicated training. - num_gpus_per_worker: number of GPUs per worker. - all_reduce_merge_scope: size of groups into which to partition consecutive - gradients grouped under a common 'allreduce' name scope. This is useful - for some optimization of collective ops. - collective_keys: an optional CollectiveKey object. - """ - self._num_workers = num_workers - self._num_gpus_per_worker = num_gpus_per_worker - self._all_reduce_merge_scope = all_reduce_merge_scope - self._collective_keys = collective_keys or cross_tower_utils.CollectiveKeys( - ) - super(CollectiveAllReduce, self).__init__() - - # TODO(yuefengz, tucker): is indexed slices supported by collective ops? - def _reduce(self, aggregation, per_device_value, destinations): - if cross_tower_utils.contains_indexed_slices(per_device_value): - raise ValueError( - "`IndexSlices` is not supported for Collective All-Reduce.") - if context.executing_eagerly(): - raise ValueError( - "Eager execution is not supported for Collective All-Reduce") - - all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0] - if _devices_match(per_device_value, destinations): - return all_reduced - else: - index = {} - for d in get_devices_from(destinations): - # pylint: disable=protected-access - if d in all_reduced._index: - index[d] = all_reduced._index[d] - else: - with ops.control_dependencies(list( - all_reduced._index.values())), ops.device(d): - index[d] = array_ops.identity(list(all_reduced._index.values())[0]) - - return value_lib.Mirrored(index) - - def _batch_reduce(self, aggregation, value_destination_pairs): - if cross_tower_utils.contains_indexed_slices(value_destination_pairs): - raise ValueError( - "`IndexSlices` is not supported for Collective All-Reduce.") - if context.executing_eagerly(): - raise ValueError( - "Eager execution is not supported for Collective All-Reduce") - - all_devices_match = _all_devices_match(value_destination_pairs) - if 
all_devices_match: - return self._batch_all_reduce(aggregation, - [v[0] for v in value_destination_pairs]) - else: - if not all_devices_match: - logging.log_first_n( - logging.WARN, "Efficient batch_reduce is not supported if " - "destinations are different.", 10) - - return [ - self._reduce(aggregation, t, destinations=v) - for t, v in value_destination_pairs - ] - - def _batch_all_reduce(self, aggregation, per_device_values): - """All-reduce across all workers in a batch.""" - if context.executing_eagerly(): - raise ValueError( - "Eager execution with collective ops is not supported yet.") - - logging.log_first_n( - logging.INFO, "Collective All-reduce invoked with batches size = %d, " - "num_workers = %d" % (len(per_device_values), self._num_workers), 10) - - grouped_by_tower = _group_value_by_device(per_device_values) - - grouped_by_var = list(zip(*grouped_by_tower)) - # grouped_by_var is grouped by variables and takes the following format: - # [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..), - # ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu0, v1_gpu2) ..), - # ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu0, v2_gpu2) ..), - # ... 
- # ] - chunked_gv = [ - grouped_by_var[x:x + self._all_reduce_merge_scope] - for x in range(0, len(grouped_by_var), self._all_reduce_merge_scope) - ] - - reduced_gv_list = [] - for chunk in chunked_gv: - with ops.name_scope("allreduce"): - for grad_and_vars in chunk: - scaled_grads = [g for g, _ in grad_and_vars] - collective_reduced = cross_tower_utils.build_collective_reduce( - scaled_grads, self._num_workers, self._collective_keys, "Add", - "Id") - result = [] - for (_, v), g in zip(grad_and_vars, collective_reduced): - result.append([g, v]) - reduced_gv_list.append(result) - - new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] - return _ungroup_and_make_mirrored( - new_tower_grads, - per_device_values[0].devices, - aggregation, - num_between_graph_workers=self._num_workers) - - -_dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], - [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] - - -def _has_dgx1_like_links(gpu_links): - if not gpu_links: - return False - # TODO(yuefengz): figure out the right topology for hierarchial copy if - # number of gpus are less than 8. - if len(gpu_links) < 8: - return False - for i, (gpu_link, dgx1_link) in enumerate(zip(gpu_links, _dgx1_links)): - if (set(gpu_link) != set(dgx1_link) and - set(gpu_link) != set(dgx1_link + [i])): - return False - return True - - -def _choose_all_reduce_algorithm(device_links): - if _has_dgx1_like_links(device_links): - logging.info("Configured hierarchical_copy with num_packs=%d", - len(device_links)) - return AllReduceCrossTowerOps( - "hierarchical_copy", num_packs=len(device_links)) - else: - logging.info("Configured nccl all-reduce.") - return AllReduceCrossTowerOps("nccl", num_packs=1) - - -def choose_the_best(devices, session_config=None): - """Find the best subclass of CrossTowerOps given a tensorflow session. - - Args: - devices: a list of devices passed for distribute strategy. - session_config: a tensorflow session config or None. 
If None, it will make - deciesion based on all local devices. - - Returns: - a subclass of CrossTowerOps. - """ - requested_devices = set([device_util.canonicalize(d) for d in devices]) - machine_devices = device_lib.list_local_devices(session_config=session_config) - using_devices = [] - for d in machine_devices: - if device_util.canonicalize(d.name) in requested_devices: - using_devices.append(d) - else: - logging.info( - "Device is available but not used by distribute strategy: %s", d.name) - - if len(using_devices) != len(requested_devices): - logging.warning("Not all devices in distribute strategy are visible by " - "TensorFlow sessions.") - return ReductionToOneDeviceCrossTowerOps() - - if any([d.device_type.lower() != "gpu" for d in using_devices]): - logging.warning("Not all devices in DistributionStrategy are visible to " - "TensorFlow session.") - return ReductionToOneDeviceCrossTowerOps() - - device_links = [[] for _ in range(len(using_devices))] - for i, device in enumerate(using_devices): - for link in device.locality.links.link: - device_links[i].append(link.device_id) - - return _choose_all_reduce_algorithm(device_links) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py deleted file mode 100644 index 490371477a1b43551c4b4d8768c96d60e5f2c6d8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ /dev/null @@ -1,564 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for CrossTowerOps.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import itertools - -from absl.testing import parameterized -import numpy as np - -from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import cross_tower_utils -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.contrib.distribute.python import values as value_lib -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.training import device_util - - -def _make_per_device(values, devices, regroup=False): - devices = cross_tower_ops_lib.get_devices_from(devices) - assert len(values) == len(devices) - - # We simulate the result of regroup called on PerDevice which strips the - # PerDevice wrapper if it has only one value. 
- if len(values) == 1 and regroup: - with ops.device(devices[0]): - placed_v = array_ops.identity(values[0]) - return placed_v - - index = {} - for d, v in zip(devices, values): - with ops.device(d): - placed_v = array_ops.identity(v) - index[d] = placed_v - return value_lib.PerDevice(index) - - -# pylint: disable=g-doc-args,g-doc-return-or-yield -def _fake_mirrored(value, devices): - """Create a faked Mirrored object for testing. - - All components of the returned Mirrored have the same objects, which is not - true in reality. - """ - devices = cross_tower_ops_lib.get_devices_from(devices) - return value_lib.Mirrored( - {d: v for d, v in zip(devices, [value] * len(devices))}) - - -def _make_indexed_slices(values, indices, dense_shape, device): - with ops.device(device): - tensor = ops.IndexedSlices( - values=constant_op.constant(values), - indices=constant_op.constant(indices), - dense_shape=constant_op.constant(dense_shape)) - return tensor - - -def _make_mirrored_indexed_slices(devices, values, indices, dense_shape): - return value_lib.Mirrored({ - d: _make_indexed_slices(values, indices, dense_shape, d) for d in devices - }) - - -_cpu_device = "/device:CPU:0" - - -class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase): - - def _assert_indexed_slices_equal(self, left, right): - self.assertIsInstance(left, ops.IndexedSlices) - self.assertIsInstance(right, ops.IndexedSlices) - self.assertEqual(device_util.resolve(left.device), - device_util.resolve(right.device)) - self.assertAllEqual( - self.evaluate(ops.convert_to_tensor(left)), - self.evaluate(ops.convert_to_tensor(right))) - - def _assert_values_equal(self, left, right): - if isinstance(left, list): - for l, r in zip(left, right): - self._assert_values_equal(l, r) - else: - self.assertEqual(type(left), type(right)) - self.assertEqual(set(left.devices), set(right.devices)) - if isinstance(list(left._index.values())[0], ops.IndexedSlices): - for (d, v) in left._index.items(): - 
self._assert_indexed_slices_equal(v, right._index[d]) - elif context.executing_eagerly(): - self.assertEqual([v.numpy() for v in left._index.values()], - list(right._index.values())) - else: - with self.test_session() as sess: - self.assertEqual( - sess.run(list(left._index.values())), list(right._index.values())) - - def _testReductionAndBroadcast(self, cross_tower_ops, distribution): - devices = distribution.worker_devices - - values = [constant_op.constant(float(d)) for d in range(len(devices))] - per_device = _make_per_device(values, devices) - mean = (len(devices) - 1.) / 2. - - values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))] - per_device_2 = _make_per_device(values_2, devices) - mean_2 = mean + 1. - - destination_mirrored = _fake_mirrored(1., devices) - destination_different = _fake_mirrored(1., _cpu_device) - destination_str = _cpu_device - destination_list = devices - - all_destinations = [ - destination_mirrored, destination_different, destination_str, - destination_list - ] - - # test reduce() - for destinations in all_destinations: - self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.MEAN, - per_device, - destinations=destinations), - _fake_mirrored(mean, destinations)) - self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.MEAN, - per_device_2, - destinations=destinations), - _fake_mirrored(mean_2, destinations)) - self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.SUM, per_device, - destinations=destinations), - _fake_mirrored(mean * len(devices), destinations)) - self._assert_values_equal( - cross_tower_ops.reduce( - vs.VariableAggregation.SUM, - per_device_2, - destinations=destinations), - _fake_mirrored(mean_2 * len(devices), destinations)) - - # test batch_reduce() - for d1, d2 in itertools.product(all_destinations, all_destinations): - self._assert_values_equal( - cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN, - [(per_device, d1), 
(per_device_2, d2)]), - [ - _fake_mirrored(mean, d1), - _fake_mirrored(mean_2, d2) - ]) - self._assert_values_equal( - cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM, - [(per_device, d1), (per_device_2, d2)]), - [ - _fake_mirrored(mean * len(devices), d1), - _fake_mirrored(mean_2 * len(devices), d2) - ]) - - # test broadcast() - for destinations in all_destinations: - self._assert_values_equal( - cross_tower_ops.broadcast(constant_op.constant(1.), destinations), - _fake_mirrored(1., destinations)) - - -class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase): - # TODO(yuefengz): decouple the num_gpus check from distribution in - # combinations module so that we can pass in devices instead of a distribution - # strategy. - reduction_to_one_combinations = combinations.combine( - cross_tower_ops=[ - combinations.NamedObject( - "DefaultReductionToOneDeviceCrossTowerOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), - combinations.NamedObject( - "ReductionToCPUDeviceCrossTowerOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( - reduce_to_device=_cpu_device)), - combinations.NamedObject( - "AccumulateNCrossTowerOp", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( - accumulation_fn=math_ops.accumulate_n)), - ], - distribution=[ - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus - ], - mode=["graph", "eager"]) - allreduce_combinations = combinations.combine( - cross_tower_ops=[ - combinations.NamedObject( - "AllReduce", - cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)), - combinations.NamedObject( - "HierarchicalCopy", - cross_tower_ops_lib.AllReduceCrossTowerOps( - "hierarchical_copy", 8, 0, 0)), - combinations.NamedObject( - "AllReduceNoGradientRepacking", - cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)), - combinations.NamedObject( - "HierarchicalCopyAggregateSmallTensors", - 
cross_tower_ops_lib.AllReduceCrossTowerOps( - "hierarchical_copy", 0, 100, 10)) - ], - distribution=[combinations.mirrored_strategy_with_two_gpus], - mode=["graph", "eager"]) - - @combinations.generate(reduction_to_one_combinations + allreduce_combinations) - def testReductionAndBroadcast(self, cross_tower_ops, distribution): - with distribution.scope(): - self._testReductionAndBroadcast(cross_tower_ops, distribution) - - def testChooseAlgorithm(self): - device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], - [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) - self.assertEqual(result._all_reduce_alg, "hierarchical_copy") - self.assertEqual(result._num_packs, 8) - - # if there are only 4 devices - device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) - self.assertEqual(result._all_reduce_alg, "nccl") - self.assertEqual(result._num_packs, 1) - - # if devices links contain each device itself - device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6], - [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7], - [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) - self.assertEqual(result._all_reduce_alg, "hierarchical_copy") - self.assertEqual(result._num_packs, 8) - - # if not dgx1-like links - device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], - [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]] - result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) - 
self.assertEqual(result._all_reduce_alg, "nccl") - self.assertEqual(result._num_packs, 1) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - required_gpus=1)) - def testSimpleReduceWithIndexedSlices(self): - devices = ["/cpu:0", "/gpu:0"] - t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) - t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) - per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) - result = cross_tower_ops_lib._simple_reduce( - per_device, devices[0], math_ops.add_n, vs.VariableAggregation.SUM) - - # Test that the result is semantically equal to both the concatenated - # IndexedSlices with and without duplicate indices. - total_with_dups = _make_indexed_slices( - [[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2], devices[0]) - total_without_dups = _make_indexed_slices( - [[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0]) - self._assert_indexed_slices_equal(total_with_dups, result) - self._assert_indexed_slices_equal(total_without_dups, result) - - @combinations.generate( - combinations.combine( - cross_tower_ops_instance=[ - combinations.NamedObject( - "ReductionToOneDeviceCrossTowerOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), - combinations.NamedObject( - "AllReduceCrossTowerOps", - cross_tower_ops_lib.AllReduceCrossTowerOps()) - ], - aggregation=[vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN], - batch_reduce=[True, False], - mode=["graph", "eager"], - required_gpus=1)) - def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, aggregation, - batch_reduce): - devices = ["/cpu:0", "/gpu:0"] - dense_shape = [5, 2] - t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) - t1 = _make_indexed_slices( - [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1]) - per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) - - if batch_reduce: - result = cross_tower_ops_instance.batch_reduce(aggregation, - [(per_device, 
devices)]) - else: - result = cross_tower_ops_instance.reduce(aggregation, per_device, devices) - - total_indices_with_dups = [1, 1, 3] - total_indices_without_dups = [1, 3] - - if aggregation == vs.VariableAggregation.SUM: - total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]] - total_values_without_dups = [[4., 6.], [5., 6.]] - else: - assert aggregation == vs.VariableAggregation.MEAN - total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]] - total_values_without_dups = [[2., 3.], [2.5, 3.]] - - total_mirrored_with_dups = _make_mirrored_indexed_slices( - devices, total_values_with_dups, total_indices_with_dups, dense_shape) - total_mirrored_without_dups = _make_mirrored_indexed_slices( - devices, total_values_without_dups, total_indices_without_dups, - dense_shape) - - # Test that the result is semantically equal to both the concatenated - # IndexedSlices, as well as when the duplicate indices are summed up. - if batch_reduce: - total_mirrored_with_dups = [total_mirrored_with_dups] - total_mirrored_without_dups = [total_mirrored_without_dups] - - self._assert_values_equal(total_mirrored_with_dups, result) - self._assert_values_equal(total_mirrored_without_dups, result) - - -class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase, - CrossTowerOpsTestBase): - - worker_devices = [ - "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" - ] - multi_worker_allreduce_combinations = combinations.combine( - cross_tower_ops=[ - combinations.NamedObject( - "MultiWorkerAllReduce", - cross_tower_ops_lib.MultiWorkerAllReduce( - worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)), - combinations.NamedObject( - "MultiWorkerAllReducePack", - cross_tower_ops_lib.MultiWorkerAllReduce( - worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)), - combinations.NamedObject( - "MultiWorkerAllReduceAggregation", - cross_tower_ops_lib.MultiWorkerAllReduce( - worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)), - combinations.NamedObject( - 
"MultiWorkerAllReduceMultipleSpecs", - cross_tower_ops_lib.MultiWorkerAllReduce( - worker_devices, 2, [("pscpu/pscpu", 2, 100), - ("xring", 2, -1)], 0, 0, 0)), - ], - distribution=[ - combinations.NamedDistribution( - "MirroredCPU", - lambda: mirrored_strategy.MirroredStrategy(num_gpus=0), - required_gpus=0), - combinations.NamedDistribution( - "Mirrored1GPU", - lambda: mirrored_strategy.MirroredStrategy(num_gpus=1), - required_gpus=1), - combinations.NamedDistribution( - "Mirrored2GPUs", - lambda: mirrored_strategy.MirroredStrategy(num_gpus=2), - required_gpus=2), - ], - mode=["graph"]) - - @combinations.generate(multi_worker_allreduce_combinations) - def testReductionAndBroadcast(self, cross_tower_ops, distribution): - distribution.configure(cluster_spec={ - "worker": - ["/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"] - }) - with distribution.scope(): - self._testReductionAndBroadcast(cross_tower_ops, distribution) - - -class MultiWorkerCollectiveAllReduceTest( - multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): - - collective_key_base = 100000 - - @classmethod - def setUpClass(cls): - """Create a local cluster with 2 workers.""" - cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( - num_workers=3, num_ps=0) - - def setUp(self): - super(MultiWorkerCollectiveAllReduceTest, self).setUp() - # Reusing keys are not supported well. So we have to give a different - # collective key base for different tests. 
- MultiWorkerCollectiveAllReduceTest.collective_key_base += 100000 - - def _get_test_objects(self, task_type, task_id, num_gpus=0, local_mode=False): - collective_keys = cross_tower_utils.CollectiveKeys( - group_key_start=10 * num_gpus + - MultiWorkerCollectiveAllReduceTest.collective_key_base, - instance_key_start=num_gpus * 100 + - MultiWorkerCollectiveAllReduceTest.collective_key_base, - instance_key_with_id_start=num_gpus * 10000 + - MultiWorkerCollectiveAllReduceTest.collective_key_base) - if local_mode: - collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( - 1, num_gpus, collective_keys=collective_keys) - if num_gpus: - devices = ["/device:GPU:%d" % i for i in range(num_gpus)] - else: - devices = ["/device:CPU:0"] - return collective_all_reduce_ops, devices, "" - else: - collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce( - 3, num_gpus, collective_keys=collective_keys) - if num_gpus: - devices = [ - "/job:%s/task:%d/device:GPU:%d" % (task_type, task_id, i) - for i in range(num_gpus) - ] - else: - devices = ["/job:%s/task:%d" % (task_type, task_id)] - return (collective_all_reduce_ops, devices, - "grpc://" + self._cluster_spec[task_type][task_id]) - - def _assert_values_equal(self, left, right, sess): - if isinstance(left, list): - for l, r in zip(left, right): - self._assert_values_equal(l, r, sess) - else: - self.assertEqual(type(left), type(right)) - self.assertEqual(set(left.devices), set(right.devices)) - - run_options = config_pb2.RunOptions() - run_options.experimental.collective_graph_key = 6 - - left_values = np.array( - sess.run(list(left._index.values()), options=run_options)).flatten() - right_values = np.array(list(right._index.values())).flatten() - self.assertEqual(len(left_values), len(right_values)) - for l, r in zip(left_values, right_values): - self.assertEqual(l, r) - - def _test_reduction(self, task_type, task_id, num_gpus, local_mode=False): - collective_all_reduce, devices, master_target = 
self._get_test_objects( - task_type, task_id, num_gpus, local_mode=local_mode) - if local_mode: - num_workers = 1 - worker_device = None - else: - num_workers = len(self._cluster_spec.get("chief", [])) + len( - self._cluster_spec.get("worker", [])) - worker_device = "/job:%s/task:%d" % (task_type, task_id) - with ops.Graph().as_default(), \ - ops.device(worker_device), \ - self.test_session(target=master_target) as sess: - # Collective ops doesn't support scalar tensors, so we have to construct - # 1-d tensors. - values = [constant_op.constant([float(d)]) for d in range(len(devices))] - per_device = _make_per_device(values, devices, regroup=True) - mean = np.array([(len(devices) - 1.) / 2.]) - - values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))] - per_device_2 = _make_per_device(values_2, devices) - mean_2 = np.array([mean[0] + 1.]) - - destination_mirrored = _fake_mirrored(1., devices) - destination_different = _fake_mirrored(1., _cpu_device) - destination_str = _cpu_device - destination_list = devices - - all_destinations = [ - destination_different, destination_mirrored, destination_str, - destination_list - ] - - # test reduce() - for destinations in all_destinations: - self._assert_values_equal( - collective_all_reduce.reduce( - vs.VariableAggregation.MEAN, - per_device, - destinations=destinations), - _fake_mirrored(mean, destinations), sess) - self._assert_values_equal( - collective_all_reduce.reduce( - vs.VariableAggregation.MEAN, - per_device_2, - destinations=destinations), - _fake_mirrored(mean_2, destinations), sess) - self._assert_values_equal( - collective_all_reduce.reduce( - vs.VariableAggregation.SUM, - per_device, - destinations=destinations), - _fake_mirrored(mean * len(devices) * num_workers, destinations), - sess) - self._assert_values_equal( - collective_all_reduce.reduce( - vs.VariableAggregation.SUM, - per_device_2, - destinations=destinations), - _fake_mirrored(mean_2 * len(devices) * num_workers, destinations), - 
sess) - - # test batch_reduce() - for d1, d2 in itertools.product(all_destinations, all_destinations): - self._assert_values_equal( - collective_all_reduce.batch_reduce(vs.VariableAggregation.MEAN, - [(per_device, d1), - (per_device_2, d2)]), - [ - _fake_mirrored(mean, d1), - _fake_mirrored(mean_2, d2) - ], sess) - self._assert_values_equal( - collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM, - [(per_device, d1), - (per_device_2, d2)]), - [ - _fake_mirrored(mean * len(devices) * num_workers, d1), - _fake_mirrored(mean_2 * len(devices) * num_workers, d2) - ], sess) - - return True - - @combinations.generate( - combinations.combine(mode=["graph"], num_gpus=[0, 1, 2], required_gpus=1)) - def testReductionDistributed(self, num_gpus): - if context.num_gpus() < num_gpus: - return - self._run_between_graph_clients(self._test_reduction, self._cluster_spec, - num_gpus) - - # Collective ops doesn't support strategy with one device. - def testReductionLocal(self, num_gpus=2): - if context.num_gpus() < num_gpus: - return - self._test_reduction(None, None, num_gpus, local_mode=True) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py deleted file mode 100644 index 9fc1b8895516f64a956accd9290e7bf42ccef330..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ /dev/null @@ -1,671 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities for cross_tower_ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections as pycoll -import threading - -from tensorflow.contrib import nccl -from tensorflow.contrib.all_reduce.python import all_reduce -from tensorflow.contrib.distribute.python import values as value_lib -from tensorflow.python.framework import device as pydev -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import collective_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import math_ops - - -def aggregate_gradients_using_nccl(tower_grads): - """Aggregate gradients using nccl allreduce.""" - agg_all_g_and_v = [] - for single_g_and_v in zip(*tower_grads): - single_grads = [g for g, _ in single_g_and_v] - agg_grads = nccl.all_sum(single_grads) - agg_all_g_and_v.append( - [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)]) - - agg_all_g_and_v = list(zip(*agg_all_g_and_v)) - - return agg_all_g_and_v - - -def aggregate_gradients_using_hierarchical_copy(avail_devices, tower_grads): - """Aggregate gradients using hierarchical copies. - - Args: - avail_devices: available GPU devices. - tower_grads: List of lists of (gradient, variable) tuples. The outer list - is over towers. The inner list is over individual gradients. - - Returns: - The list of (aggregated_gradient, variable), where the gradient has been - summed across all towers and the variable is chosen from the first tower. 
- """ - # This only works for DGX-1 type of machine topology - # Device peer to peer matrix - # DMA: 0 1 2 3 4 5 6 7 - # 0: Y Y Y Y Y N N N - # 1: Y Y Y Y N Y N N - # 2: Y Y Y Y N N Y N - # 3: Y Y Y Y N N N Y - # 4: Y N N N Y Y Y Y - # 5: N Y N N Y Y Y Y - # 6: N N Y N Y Y Y Y - # 7: N N N Y Y Y Y Y - agg_grads = [] - num_devices = len(avail_devices) - # In the special case of DGX-1 machine topology, the two groups have equal - # size. - group_size = num_devices // 2 - for i, single_grads in enumerate(zip(*tower_grads)): - group_0_main_device = i % num_devices - group_1_main_device = (group_0_main_device + group_size) % num_devices - if group_0_main_device < group_size: - group_0_begin = 0 - group_1_begin = group_size - else: - group_0_begin = group_size - group_1_begin = 0 - - # Aggregate the first group. - group_0_device_grads = single_grads[group_0_begin: - group_0_begin + group_size] - with ops.device(avail_devices[group_0_main_device]): - group_0_agg_grads, _ = aggregate_single_gradient_using_copy( - group_0_device_grads, False, False) - - # Aggregate the second group. - group_1_device_grads = single_grads[group_1_begin: - group_1_begin + group_size] - with ops.device(avail_devices[group_1_main_device]): - group_1_agg_grads, _ = aggregate_single_gradient_using_copy( - group_1_device_grads, False, False) - - # Aggregate between the groups. - with ops.device(avail_devices[group_0_main_device]): - (agg_total_grads, _), _ = aggregate_single_gradient_using_copy( - [group_0_agg_grads, group_1_agg_grads], False, False) - - # Broadcast the result back into the root of each group. 
- with ops.device(avail_devices[group_0_main_device]): - group_0_agg_grads_bcast = array_ops.identity(agg_total_grads) - with ops.device(avail_devices[group_1_main_device]): - group_1_agg_grads_bcast = array_ops.identity(agg_total_grads) - - agg_grads_bcast = [] - for j in range(len(single_grads)): - with ops.device(avail_devices[j]): - # Broadcast the result back to each member in the group from the root. - if (group_0_main_device < group_size) == (j < group_size): - src_device_grad = group_0_agg_grads_bcast - else: - src_device_grad = group_1_agg_grads_bcast - agg_grads_bcast.append(array_ops.identity(src_device_grad)) - - agg_grads.append( - [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)]) - - agg_grads = list(zip(*agg_grads)) - - return agg_grads - - -def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, - check_inf_nan): - """Calculate the average gradient for a shared variable across all towers. - - Note that this function provides a synchronization point across all towers. - - Args: - grad_and_vars: A list or tuple of (gradient, variable) tuples. Each - (gradient, variable) pair within the outer list represents the gradient - of the variable calculated for a single tower, and the number of pairs - equals the number of towers. - use_mean: if True, mean is taken, else sum of gradients is taken. - check_inf_nan: check grads for nans and infs. - - Returns: - The tuple ([(average_gradient, variable),], has_nan_or_inf) where the - gradient has been averaged across all towers. The variable is chosen from - the first tower. The has_nan_or_inf indicates the grads has nan or inf. 
- """ - grads = [g for g, _ in grad_and_vars] - grad = math_ops.add_n(grads) - - if use_mean and len(grads) > 1: - grad = array_ops.multiply(grad, 1.0 / len(grads)) - - v = grad_and_vars[0][1] - if check_inf_nan: - has_nan_or_inf = array_ops.logical_not( - array_ops.reduce_all(array_ops.is_finite(grads))) - return (grad, v), has_nan_or_inf - else: - return (grad, v), None - - -def group_device_names(devices, group_size): - """Group device names into groups of group_size. - - Args: - devices: a list of canonical device strings. - group_size: integer which is equal to or greater than 1. - - Returns: - list of lists of devices, where each inner list is group_size long, - and each device appears at least once in an inner list. If - len(devices) % group_size == 0 then each device will appear exactly once. - - Raises: - ValueError: if group_size > len(devices) - """ - num_devices = len(devices) - if group_size > num_devices: - raise ValueError( - 'only %d devices, but group_size=%d' % (num_devices, group_size)) - num_groups = ( - num_devices // group_size + (1 if (num_devices % group_size != 0) else 0)) - groups = [[] for i in range(num_groups)] - for i in range(num_groups * group_size): - groups[i % num_groups].append(devices[i % num_devices]) - return groups - - -def split_grads_by_size(threshold_size, device_grads): - """Break gradients into two sets according to tensor size. - - Args: - threshold_size: int size cutoff for small vs large tensor. - device_grads: List of lists of (gradient, variable) tuples. The outer - list is over devices. The inner list is over individual gradients. - - Returns: - small_grads: Subset of device_grads where shape is <= threshold_size - elements. - large_grads: Subset of device_grads where shape is > threshold_size - elements. 
- """ - small_grads = [] - large_grads = [] - for dl in device_grads: - small_dl = [] - large_dl = [] - for (g, v) in dl: - tensor_size = g.get_shape().num_elements() - if tensor_size <= threshold_size: - small_dl.append([g, v]) - else: - large_dl.append([g, v]) - if small_dl: - small_grads.append(small_dl) - if large_dl: - large_grads.append(large_dl) - return small_grads, large_grads - - -# threading.Lock() and threading.local() cannot be pickled and therefore cannot -# be a field of CollectiveKeys. Right now _thread_local is not necessary to be -# an instance member of CollectiveKeys since we always create a new thread for -# each tower. -_lock = threading.Lock() -_thread_local = threading.local() - - -# TODO(yuefengz): use random key starts to avoid reusing keys? -class CollectiveKeys(object): - """Class that manages collective keys. - - We need to manage three different keys for collective: - - *Group key*: an integer key to identify the set of cooperative devices. - Collective ops work under the same set of devices must using the same group - key. - - *Instance key*: an integer key to identify the set of same counterpart of - tensors on different devices in a device group that need to be all-reduced. - - "Graph key": an integer key that is unique key graph. This is used to support - multiple graphs per client session. It must be non-zero and set in the - `config` argument of each call to `session.run`. - """ - - def __init__(self, - group_key_start=1, - instance_key_start=100, - instance_key_with_id_start=10000): - """Initializes the object. - - Args: - group_key_start: the starting integer of group key. - instance_key_start: the starting integer of instance key. - instance_key_with_id_start: the starting integer of instance key that is - recorded with an id. 
- """ - self._group_key = group_key_start - self._group_key_table = dict() - - # For instance keys with ids - self._instance_key_id_to_key_table = dict() - self._instance_key_with_id_counter = instance_key_with_id_start - - # For instance keys without ids - self._instance_key_start = instance_key_start - - def _get_thread_local_object(self): - # We make instance key without key ids thread local so that it will work - # with MirroredStrategy and distribute coordinator. - if not hasattr(_thread_local, 'instance_key'): - _thread_local.instance_key = self._instance_key_start - return _thread_local - - def get_group_key(self, devices): - """Returns a group key for the set of devices. - - Args: - devices: list of strings naming devices in a collective group. - - Returns: - int key uniquely identifying the set of device names. - """ - parsed = [pydev.DeviceSpec.from_string(d) for d in devices] - # In the between-graph replicated training, different workers need to get - # the same device key. So we remove the task_type and task_id from the - # devices. - # TODO(yuefengz): in the in-graph replicated training, we need to include - # task_type and task_id. - names = sorted(['%s:%d' % (d.device_type, d.device_index) for d in parsed]) - key_id = ','.join(names) - with _lock: - if key_id not in self._group_key_table: - new_key = self._group_key - self._group_key += 1 - self._group_key_table[key_id] = new_key - return self._group_key_table[key_id] - - def get_instance_key(self, key_id=None): - """Returns a new instance key for use in defining a collective op. - - Args: - key_id: optional string. If set, key will be recorded and the same key - will be returned when the same key_id is provided. If not, an increasing - instance key will be returned. 
- """ - if key_id: - with _lock: - if key_id not in self._instance_key_id_to_key_table: - self._instance_key_with_id_counter += 1 - self._instance_key_id_to_key_table[key_id] = ( - self._instance_key_with_id_counter) - return self._instance_key_id_to_key_table[key_id] - else: - v = self._get_thread_local_object().instance_key - self._get_thread_local_object().instance_key += 1 - return v - - -def build_collective_reduce(input_tensors, - num_workers, - collective_keys, - reduction_op='Add', - unary_op='Id'): - """Build a subgraph that does one full all-reduce, using the collective Op. - - Args: - input_tensors: tensors within a single worker graph that are to be reduced - together; must be one per device. - num_workers: total number of workers with identical independent graphs that - will be doing this same reduction. The reduction will actually include - the corresponding tensors at all these workers. - collective_keys: a CollectiveKeys object. - reduction_op: string naming the reduction op. - unary_op: string naming the unary final op. - - Returns: - An array of final tensors, one per device, computed by the full reduction. - - Raises: - ValueError: There must be at least two tensors over all the workers. 
- """ - group_size = len(input_tensors) * num_workers - if group_size < 2: - raise ValueError('num_workers * len(input_tensors) must be 2 or greater') - devices = [t.device for t in input_tensors] - num_devices = len(devices) - group_key = collective_keys.get_group_key(devices) - instance_key = collective_keys.get_instance_key() - out_tensors = [] - subdiv_offsets = [0] # TODO(tucker): maybe support non-default subdiv spec - for d in range(num_devices): - with ops.device(devices[d]): - reduce_op = collective_ops.all_reduce( - input_tensors[d], group_size, group_key, instance_key, reduction_op, - unary_op, subdiv_offsets) - out_tensors.append(reduce_op) - return out_tensors - - -def sum_grad_and_var_all_reduce(grad_and_vars, - num_workers, - alg, - gpu_indices, - aux_devices=None, - num_shards=1): - """Apply all-reduce algorithm over specified gradient tensors.""" - with ops.name_scope('allreduce'): - # Note that each grad_and_vars looks like the following: - # ((grad0_gpu0, var0_gpu0), ... 
, (grad0_gpuN, var0_gpuN)) - scaled_grads = [g for g, _ in grad_and_vars] - if alg == 'nccl': - summed_grads = nccl.all_sum(scaled_grads) - elif alg == 'xring': - summed_grads = all_reduce.build_ring_all_reduce( - scaled_grads, num_workers, num_shards, gpu_indices, math_ops.add) - elif alg == 'nccl/xring': - summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards, - math_ops.add) - elif alg == 'nccl/rechd': - summed_grads = all_reduce.build_nccl_then_recursive_hd( - scaled_grads, math_ops.add) - elif alg == 'nccl/pscpu': - summed_grads = all_reduce.build_nccl_then_shuffle( - scaled_grads, aux_devices, math_ops.add, math_ops.add_n) - elif alg == 'pscpu/pscpu': - second_gather_devices = aux_devices[:num_shards] - summed_grads = all_reduce.build_shuffle_then_shuffle( - scaled_grads, aux_devices, second_gather_devices, math_ops.add_n) - elif alg in ['pscpu', 'psgpu']: - summed_grads = all_reduce.build_shuffle_all_reduce( - scaled_grads, aux_devices, math_ops.add_n) - else: - raise ValueError('unsupported all_reduce alg: ', alg) - - result = [] - for (_, v), g in zip(grad_and_vars, summed_grads): - result.append([g, v]) - return result - - -def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg, - num_shards, gpu_indices): - """Apply all-reduce algorithm over specified gradient tensors. - - Args: - dev_prefixes: list of prefix strings to use to generate PS device names. - tower_grads: the gradients to reduce. - num_workers: number of worker processes across entire job. - alg: the all-reduce algorithm to apply. - num_shards: alg-specific sharding factor. - gpu_indices: indices of local GPUs in order usable for ring-reduce. 
- - Returns: - list of reduced tensors - """ - alg_contains_shuffle = any([n in alg for n in ['pscpu', 'psgpu']]) - is_hierarchical = '/' in alg - if 'pscpu' in alg: - aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes] - elif 'psgpu' in alg: - aux_devices = [ - prefix + '/gpu:%d' % i - for i in range(len(gpu_indices)) - for prefix in dev_prefixes - ] - else: - aux_devices = ['/job:localhost/cpu:0'] - # Auxiliary devices for hierarchical all-reduces. - aux_device_groups = group_device_names( - aux_devices, num_shards if alg_contains_shuffle else 1) - group_index = 0 - reduced_gv_list = [] - for grad_and_vars in zip(*tower_grads): - reduced_gv_list.append( - sum_grad_and_var_all_reduce( - grad_and_vars, num_workers, alg, gpu_indices, aux_devices - if is_hierarchical else aux_device_groups[group_index], num_shards)) - group_index = (group_index + 1) % len(aux_device_groups) - new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] - return new_tower_grads - - -def extract_ranges(index_list, range_size_limit=32): - """Extract consecutive ranges and singles from index_list. - - Args: - index_list: List of monotone increasing non-negative integers. - range_size_limit: Largest size range to return. If a larger - consecutive range exists, it will be returned as multiple - ranges. - - Returns: - (ranges, singles) where ranges is a list of [first, last] pairs of - consecutive elements in index_list, and singles is all of the - other elements, in original order. 
- """ - if not index_list: - return [], [] - first = index_list[0] - last = first - ranges = [] - singles = [] - for i in index_list[1:]: - if i == last + 1 and (last - first) <= range_size_limit: - last = i - else: - if last > first: - ranges.append([first, last]) - else: - singles.append(first) - first = i - last = i - if last > first: - ranges.append([first, last]) - else: - singles.append(first) - return ranges, singles - - -GradPackTuple = pycoll.namedtuple('GradPackTuple', 'indices vars shapes') - - -def pack_range(key, packing, grad_vars, rng): - """Form the concatenation of a specified range of gradient tensors. - - Args: - key: Value under which to store meta-data in packing that will be used - later to restore the grad_var list structure. - packing: Dict holding data describing packed ranges of small tensors. - grad_vars: List of (grad, var) pairs for one tower. - rng: A pair of integers giving the first, last indices of a consecutive - range of tensors to be packed. - - Returns: - A tensor that is the concatenation of all the specified small tensors. - """ - to_pack = grad_vars[rng[0]:rng[1] + 1] - members = [] - variables = [] - restore_shapes = [] - with ops.name_scope('pack'): - for g, v in to_pack: - variables.append(v) - restore_shapes.append(g.shape) - with ops.device(g.device): - members.append(array_ops.reshape(g, [-1])) - packing[key] = GradPackTuple( - indices=range(rng[0], rng[1] + 1), - vars=variables, - shapes=restore_shapes) - with ops.device(members[0].device): - return array_ops.concat(members, 0) - - -def unpack_grad_tuple(gv, gpt): - """Unpack a previously packed collection of gradient tensors. - - Args: - gv: A (grad, var) pair to be unpacked. - gpt: A GradPackTuple describing the packing operation that produced gv. - - Returns: - A list of (grad, var) pairs corresponding to the values that were - originally packed into gv, maybe following subsequent operations like - reduction. 
- """ - elt_widths = [x.num_elements() for x in gpt.shapes] - with ops.device(gv[0][0].device): - with ops.name_scope('unpack'): - splits = array_ops.split(gv[0], elt_widths) - unpacked_gv = [] - for idx, s in enumerate(splits): - unpacked_gv.append((array_ops.reshape(s, gpt.shapes[idx]), - gpt.vars[idx])) - return unpacked_gv - - -def pack_small_tensors(tower_grads, max_bytes=0, max_group=0): - """Concatenate small gradient tensors together for reduction. - - Args: - tower_grads: List of lists of (gradient, variable) tuples. - max_bytes: Int giving max number of bytes in a tensor that - may be considered small. - max_group: Int giving max number of small tensors that may be - concatenated into one new tensor. - - Returns: - new_tower_grads, packing where new_tower_grads is identical to - tower_grads except that all feasible small_tensors have been removed - from their places and concatenated into larger tensors that are - now in the front of the list for each tower, and packing contains - the data necessary to restore the tower_grads structure. - - Look through the first tower for gradients of the same type (float), - and small size, that are all sequential. For each such group, - replace by a new tensor that is a flattened concatenation. Note - that the corresponding variable will be absent, which doesn't matter - because it isn't used during all-reduce. - - Requires: - Every gv_list in towers must have isomorphic structure including identical - tensor sizes and types. 
- """ - small_indices = [] - large_indices = [] - for idx, (g, _) in enumerate(tower_grads[0]): - if g.dtype == dtypes.float32 and (4 * g.shape.num_elements()) <= max_bytes: - small_indices.append(idx) - else: - large_indices.append(idx) - small_ranges, small_singles = extract_ranges( - small_indices, range_size_limit=max_group) - large_indices = sorted(large_indices + small_singles) - num_gv = len(tower_grads[0]) - packing = {} - if small_ranges: - new_tower_grads = [] - for dev_idx, gv_list in enumerate(tower_grads): - assert len(gv_list) == num_gv - new_gv_list = [] - for r in small_ranges: - key = '%d:%d' % (dev_idx, len(new_gv_list)) - new_gv_list.append((pack_range(key, packing, gv_list, r), - 'packing_var_placeholder')) - for i in large_indices: - new_gv_list.append(gv_list[i]) - new_tower_grads.append(new_gv_list) - return new_tower_grads, packing - else: - return tower_grads, None - - -def unpack_small_tensors(tower_grads, packing): - """Undo the structure alterations to tower_grads done by pack_small_tensors. - - Args: - tower_grads: List of List of (grad, var) tuples. - packing: A dict generated by pack_small_tensors describing the changes - it made to tower_grads. - - Returns: - new_tower_grads: identical to tower_grads except that concatenations - of small tensors have been split apart and returned to their original - positions, paired with their original variables. 
- """ - if not packing: - return tower_grads - new_tower_grads = [] - num_devices = len(tower_grads) - num_packed = len(packing.keys()) // num_devices - for dev_idx, gv_list in enumerate(tower_grads): - gv_list = list(gv_list) - new_gv_list = gv_list[num_packed:] - for i in range(num_packed): - k = '%d:%d' % (dev_idx, i) - gpt = packing[k] - gv = unpack_grad_tuple(gv_list[i], gpt) - for gi, idx in enumerate(gpt.indices): - assert idx == gpt.indices[gi] - new_gv_list.insert(idx, gv[gi]) - new_tower_grads.append(new_gv_list) - return new_tower_grads - - -def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n): - """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat.""" - if any(isinstance(v, ops.IndexedSlices) for v in values): - return gradients_impl._AggregateIndexedSlicesGradients(values) # pylint: disable=protected-access - else: - return accumulation_fn(values) - - -def divide_by_n_tensors_or_indexed_slices(value, n): - if isinstance(value, ops.IndexedSlices): - value = gradients_impl._HandleNestedIndexedSlices(value) # pylint: disable=protected-access - return ops.IndexedSlices( - value.values / n, value.indices, value.dense_shape) - else: - return value / n - - -def copy_tensor_or_indexed_slices_to_device(value, device): - with ops.device(device): - if isinstance(value, ops.IndexedSlices): - copied_values = array_ops.identity(value.values) - copied_indices = array_ops.identity(value.indices) - copied_shape = array_ops.identity(value.dense_shape) - result = ops.IndexedSlices(copied_values, copied_indices, copied_shape) - else: - result = array_ops.identity(value) - return result - - -def contains_indexed_slices(value): - """Check whether the value is `IndexedSlices` or contains `IndexedSlices`.""" - if isinstance(value, ops.IndexedSlices): - return True - elif isinstance(value, (list, tuple)) and value: - return any(contains_indexed_slices(v) for v in value) - elif isinstance(value, value_lib.DistributedValues): - 
return contains_indexed_slices(list(value._index.values())) # pylint: disable=protected-access - elif isinstance(value, value_lib.MapOutput): - return contains_indexed_slices(value.get()) - else: - return False diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py deleted file mode 100644 index d25964fa41adc7b1c9164a4ffe49c4c5532f76ac..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for cross_tower_utils.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl.testing import parameterized - -from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import cross_tower_utils -from tensorflow.contrib.distribute.python import values as value_lib -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import math_ops -from tensorflow.python.training import device_util - - -class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): - - def _assert_values_equal(self, left, right): - self.assertAllEqual( - self.evaluate(ops.convert_to_tensor(left)), - self.evaluate(ops.convert_to_tensor(right))) - - @test_util.run_in_graph_and_eager_modes - def testAggregateTensors(self): - t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) - t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]]) - total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) - result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) - self._assert_values_equal(total, result) - - @test_util.run_in_graph_and_eager_modes - def testAggregateIndexedSlices(self): - t0 = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - t1 = math_ops._as_indexed_slices( - constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) - result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) - self.assertIsInstance(result, ops.IndexedSlices) - self._assert_values_equal(total, result) - - @test_util.run_in_graph_and_eager_modes - def testDivideTensor(self): - t = constant_op.constant([[1., 2.], [0, 0], [3., 
4.]]) - n = 2 - expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) - result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) - self._assert_values_equal(expected, result) - - @test_util.run_in_graph_and_eager_modes - def testDivideIndexedSlices(self): - t = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - n = 2 - expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) - result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) - self.assertIsInstance(result, ops.IndexedSlices) - self._assert_values_equal(expected, result) - - @test_util.run_in_graph_and_eager_modes - def testIsIndexedSlices(self): - t = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - self.assertTrue(cross_tower_utils.contains_indexed_slices(t)) - - @test_util.run_in_graph_and_eager_modes - def testContainsIndexedSlices_List(self): - t0 = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - t1 = math_ops._as_indexed_slices( - constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1])) - - @test_util.run_in_graph_and_eager_modes - def testContainsIndexedSlices_Tuple(self): - t0 = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - t1 = math_ops._as_indexed_slices( - constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1))) - - @test_util.run_in_graph_and_eager_modes - def testContainsIndexedSlices_PerDevice(self): - t0 = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - t1 = math_ops._as_indexed_slices( - constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1}) - self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) - - @test_util.run_in_graph_and_eager_modes - def 
testContainsIndexedSlices_PerDeviceMapOutput(self): - t0 = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - t1 = math_ops._as_indexed_slices( - constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) - per_device = value_lib.PerDevice({ - "/gpu:0": value_lib.MapOutput([t0]), - "/cpu:0": value_lib.MapOutput([t1])}) - self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - required_gpus=1)) - def testCopyTensor(self): - with ops.device("/cpu:0"): - t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) - destination = "/gpu:0" - result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( - t, destination) - - self._assert_values_equal(t, result) - self.assertEqual(device_util.resolve(destination), - device_util.resolve(result.device)) - - @combinations.generate(combinations.combine( - mode=["graph", "eager"], - required_gpus=1)) - def testCopyIndexedSlices(self): - with ops.device("/cpu:0"): - t = math_ops._as_indexed_slices( - constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) - destination = "/gpu:0" - result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( - t, destination) - - self.assertIsInstance(result, ops.IndexedSlices) - self._assert_values_equal(t, result) - self.assertEqual(device_util.resolve(destination), - device_util.resolve(result.device)) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index cc626c33bf8e282736f8e6e0c151e5a3d3f3244b..e17085628ba6d1dfc79839fd824801723f07a518 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -34,7 +34,7 @@ from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import 
prediction_keys from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.framework import ops from tensorflow.python.platform import gfile from tensorflow.python.summary.writer import writer_cache @@ -63,7 +63,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, distribution=[ combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus ], use_train_and_evaluate=[True, False])) def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): @@ -75,12 +77,12 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, train_input_fn = self.dataset_input_fn( x={'x': data}, y=data, - batch_size=batch_size // len(distribution.worker_devices), + batch_size=batch_size // distribution.num_replicas_in_sync, shuffle=True) eval_input_fn = self.dataset_input_fn( x={'x': data}, y=data, - batch_size=batch_size // len(distribution.worker_devices), + batch_size=batch_size // distribution.num_replicas_in_sync, shuffle=False) predict_input_fn = numpy_io.numpy_input_fn( x={'x': data}, batch_size=batch_size, shuffle=False) @@ -126,8 +128,8 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, feature_spec = feature_column.make_parse_example_spec(feature_columns) serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( feature_spec) - export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), - serving_input_receiver_fn) + export_dir = estimator.export_saved_model(tempfile.mkdtemp(), + serving_input_receiver_fn) self.assertTrue(gfile.Exists(export_dir)) def tearDown(self): diff 
--git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py index 157618f72ff2ea6dde171e7edb62ccaf7e1de516..b369a7fefe6f35cf5a9b64451419cf4f72a99471 100644 --- a/tensorflow/contrib/distribute/python/estimator_training_test.py +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -18,15 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy import glob import json import os import sys import tempfile -import threading from absl.testing import parameterized import numpy as np +from tensorflow.contrib.distribute.python import collective_all_reduce_strategy from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base @@ -43,11 +44,13 @@ from tensorflow.python.estimator import training as estimator_training from tensorflow.python.estimator.canned import dnn_linear_combined from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export as export_lib -from tensorflow.python.feature_column import feature_column +from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary import summary_iterator from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import session_manager + BATCH_SIZE = 10 LABEL_DIMENSION = 2 @@ -66,57 +69,19 @@ PS = dc._TaskType.PS original_run_std_server = dc._run_std_server -class MockOsEnv(dict): - - def __init__(self, *args): - self._thread_local = threading.local() - super(MockOsEnv, self).__init__(*args) - - def get(self, key, default): - if not hasattr(self._thread_local, "dict"): - self._thread_local.dict = 
dict() - if key == "TF_CONFIG": - return dict.get(self._thread_local.dict, key, default) - else: - return dict.get(self, key, default) - - def __getitem__(self, key): - if not hasattr(self._thread_local, "dict"): - self._thread_local.dict = dict() - if key == "TF_CONFIG": - return dict.__getitem__(self._thread_local.dict, key) - else: - return dict.__getitem__(self, key) - - def __setitem__(self, key, val): - if not hasattr(self._thread_local, "dict"): - self._thread_local.dict = dict() - if key == "TF_CONFIG": - return dict.__setitem__(self._thread_local.dict, key, val) - else: - return dict.__setitem__(self, key, val) - - -class DistributeCoordinatorIntegrationTest(test.TestCase, - parameterized.TestCase): +class DistributeCoordinatorIntegrationTest( + multi_worker_test_base.IndependentWorkerTestBase, parameterized.TestCase): @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" + super(DistributeCoordinatorIntegrationTest, cls).setUpClass() cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=2, has_eval=True) def setUp(self): self._model_dir = tempfile.mkdtemp() - self._mock_os_env = MockOsEnv() - self._mock_context = test.mock.patch.object(os, "environ", - self._mock_os_env) super(DistributeCoordinatorIntegrationTest, self).setUp() - self._mock_context.__enter__() - - def tearDown(self): - self._mock_context.__exit__(None, None, None) - super(DistributeCoordinatorIntegrationTest, self).tearDown() def dataset_input_fn(self, x, y, batch_size, shuffle): @@ -139,6 +104,8 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, def _extract_loss_and_global_step(self, event_folder): """Returns the loss and global step in last event.""" event_paths = glob.glob(os.path.join(event_folder, "events*")) + self.assertNotEmpty( + event_paths, msg="Event file not found in dir %s" % event_folder) loss = None global_step_count = None @@ -189,7 +156,8 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, 
def _complete_flow(self, train_distribute, eval_distribute, - remote_cluster=None): + remote_cluster=None, + use_train_and_evaluate=True): estimator = self._get_estimator(train_distribute, eval_distribute, remote_cluster) @@ -197,10 +165,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, train_input_fn = self.dataset_input_fn( x={"x": DATA}, y=DATA, - batch_size=BATCH_SIZE // len(train_distribute.worker_devices), + batch_size=BATCH_SIZE // train_distribute.num_replicas_in_sync, shuffle=True) if eval_distribute: - eval_batch_size = BATCH_SIZE // len(eval_distribute.worker_devices) + eval_batch_size = BATCH_SIZE // eval_distribute.num_replicas_in_sync else: eval_batch_size = BATCH_SIZE eval_input_fn = self.dataset_input_fn( @@ -214,16 +182,37 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, ] feature_columns = linear_feature_columns + dnn_feature_columns - estimator_training.train_and_evaluate( - estimator, - estimator_training.TrainSpec(train_input_fn, max_steps=MAX_STEPS), - estimator_training.EvalSpec( - name=EVAL_NAME, - input_fn=eval_input_fn, - steps=None, - exporters=self._get_exporter(EXPORTER_NAME, feature_columns), - start_delay_secs=0, - throttle_secs=1)) + eval_spec = estimator_training.EvalSpec( + name=EVAL_NAME, + input_fn=eval_input_fn, + steps=None, + exporters=self._get_exporter(EXPORTER_NAME, feature_columns), + start_delay_secs=0, + throttle_secs=1) + + if use_train_and_evaluate: + estimator_training.train_and_evaluate( + estimator, + estimator_training.TrainSpec(train_input_fn, max_steps=MAX_STEPS), + eval_spec) + else: + estimator.train(train_input_fn, max_steps=MAX_STEPS) + + latest_ckpt_path = estimator.latest_checkpoint() + metrics = estimator.evaluate(eval_input_fn, + checkpoint_path=latest_ckpt_path, + name=EVAL_NAME) + + # Export the eval result to files. 
+ eval_result = estimator_training._EvalResult( + status=estimator_training._EvalStatus.EVALUATED, + metrics=metrics, + checkpoint_path=latest_ckpt_path) + evaluator = estimator_training._TrainingExecutor._Evaluator(estimator, + eval_spec, + None) + evaluator._export_eval_result(eval_result, True) + return estimator def _inspect_train_and_eval_events(self, estimator): @@ -259,32 +248,74 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, ]) self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape) + def _get_strategy_object(self, strategy_cls): + if strategy_cls == mirrored_strategy.CoreMirroredStrategy: + return strategy_cls(mirrored_strategy.all_local_devices()) + else: + return strategy_cls(num_gpus_per_worker=context.num_gpus()) + @combinations.generate( combinations.combine( mode=["graph"], train_distribute_cls=[ + collective_all_reduce_strategy.CollectiveAllReduceStrategy, mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, parameter_server_strategy.ParameterServerStrategy ], eval_distribute_cls=[ - None, mirrored_strategy.MirroredStrategy, - parameter_server_strategy.ParameterServerStrategy + None, + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, + parameter_server_strategy.ParameterServerStrategy, ], - required_gpus=1)) + required_gpus=[0, 1])) def test_complete_flow_standalone_client(self, train_distribute_cls, eval_distribute_cls): - try: - train_distribute = train_distribute_cls(num_gpus=context.num_gpus()) - except TypeError: - train_distribute = train_distribute_cls(num_gpus_per_worker=2) + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls() + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None + cluster_spec = copy.deepcopy(self._cluster_spec) + if (train_distribute_cls != + parameter_server_strategy.ParameterServerStrategy): + cluster_spec.pop("ps", 
None) + estimator = self._complete_flow(train_distribute, eval_distribute, + cluster_spec) + self._inspect_train_and_eval_events(estimator) + + @combinations.generate( + combinations.combine( + mode=["graph"], + train_distribute_cls=[ + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, + ], + eval_distribute_cls=[ + None, + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy, + ], + required_gpus=[0, 1])) + def test_estimator_standalone_client(self, train_distribute_cls, + eval_distribute_cls): + train_distribute = self._get_strategy_object(train_distribute_cls) + + if eval_distribute_cls: + eval_distribute = self._get_strategy_object(eval_distribute_cls) + else: + eval_distribute = None + + # We use the whole cluster for evaluation. + cluster = copy.deepcopy(self._cluster_spec) + cluster.pop("evaluator", None) + estimator = self._complete_flow( - train_distribute, eval_distribute, remote_cluster=self._cluster_spec) + train_distribute, eval_distribute, remote_cluster=cluster, + use_train_and_evaluate=False) self._inspect_train_and_eval_events(estimator) def _mock_run_std_server(self, *args, **kwargs): @@ -294,75 +325,56 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, self._barrier.wait() return ret - def _task_thread(self, train_distribute, eval_distribute, tf_config): - os.environ["TF_CONFIG"] = json.dumps(tf_config) + def _independent_worker_fn( + self, + train_distribute, + eval_distribute, + ): with test.mock.patch.object(dc, "_run_std_server", self._mock_run_std_server): self._complete_flow(train_distribute, eval_distribute) - def _run_task_in_thread(self, cluster_spec, task_type, task_id, - train_distribute, eval_distribute): - if task_type: - tf_config = { - "cluster": cluster_spec, - "task": { - "type": task_type, - "index": task_id - } - } - else: - tf_config = { - "cluster": cluster_spec, - "task": { - "type": task_type, - "index": task_id - } - } - t = threading.Thread( - 
target=self._task_thread, - args=(train_distribute, eval_distribute, tf_config)) - t.start() - return t - - def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute, - eval_distribute): - threads = {} - for task_type in cluster_spec.keys(): - threads[task_type] = [] - for task_id in range(len(cluster_spec[task_type])): - t = self._run_task_in_thread(cluster_spec, task_type, task_id, - train_distribute, eval_distribute) - threads[task_type].append(t) - return threads - @combinations.generate( combinations.combine( mode=["graph"], train_distribute_cls=[ + collective_all_reduce_strategy.CollectiveAllReduceStrategy, parameter_server_strategy.ParameterServerStrategy, ], eval_distribute_cls=[ None, mirrored_strategy.MirroredStrategy, - parameter_server_strategy.ParameterServerStrategy + mirrored_strategy.CoreMirroredStrategy, + parameter_server_strategy.ParameterServerStrategy, ], - required_gpus=1)) + required_gpus=[0, 1])) def test_complete_flow_indepedent_worker_between_graph( self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls( - num_gpus_per_worker=context.num_gpus()) + if (context.num_gpus() < 2 and eval_distribute_cls == + collective_all_reduce_strategy.CollectiveAllReduceStrategy): + self.skipTest("`CollectiveAllReduceStrategy` needs at least two towers.") + + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls() + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None - cluster_spec = multi_worker_test_base.create_cluster_spec( - num_workers=3, num_ps=2, has_eval=True) - # 3 workers, 2 ps and 1 evaluator. 
- self._barrier = dc._Barrier(6) - - threads = self._run_multiple_tasks_in_threads( - cluster_spec, train_distribute, eval_distribute) + if (train_distribute_cls == parameter_server_strategy + .ParameterServerStrategy): + cluster_spec = multi_worker_test_base.create_cluster_spec( + num_workers=3, num_ps=2, has_eval=True) + # 3 workers, 2 ps and 1 evaluator. + self._barrier = dc._Barrier(6) + else: + cluster_spec = multi_worker_test_base.create_cluster_spec( + num_workers=3, num_ps=0, has_eval=True) + # 3 workers and 1 evaluator. + self._barrier = dc._Barrier(4) + + threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, + cluster_spec, train_distribute, + eval_distribute) for task_type, ts in threads.items(): if task_type == PS: continue @@ -375,15 +387,22 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, @combinations.generate( combinations.combine( mode=["graph"], - train_distribute_cls=[mirrored_strategy.MirroredStrategy], - eval_distribute_cls=[None, mirrored_strategy.MirroredStrategy], - required_gpus=1)) + train_distribute_cls=[ + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy + ], + eval_distribute_cls=[ + None, + mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy + ], + required_gpus=[0, 1])) def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls, eval_distribute_cls): - train_distribute = train_distribute_cls(num_gpus=context.num_gpus()) + train_distribute = self._get_strategy_object(train_distribute_cls) if eval_distribute_cls: - eval_distribute = eval_distribute_cls() + eval_distribute = self._get_strategy_object(eval_distribute_cls) else: eval_distribute = None @@ -391,8 +410,9 @@ class DistributeCoordinatorIntegrationTest(test.TestCase, num_workers=3, num_ps=0, has_eval=True) # 3 workers and 1 evaluator. 
self._barrier = dc._Barrier(4) - threads = self._run_multiple_tasks_in_threads( - cluster_spec, train_distribute, eval_distribute) + threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn, + cluster_spec, train_distribute, + eval_distribute) threads[WORKER][0].join() threads[EVALUATOR][0].join() @@ -430,7 +450,8 @@ class RunConfigTest(test.TestCase): "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITHOUT_TASK)}): run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + train_distribute=mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]))) def test_should_run_distribute_coordinator(self): """Tests that should_run_distribute_coordinator return a correct value.""" @@ -453,10 +474,12 @@ class RunConfigTest(test.TestCase): {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): config_with_train_distribute = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + train_distribute=mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]))) config_with_eval_distribute = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - eval_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + eval_distribute=mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]))) self.assertTrue( dc_training.should_run_distribute_coordinator( config_with_train_distribute)) @@ -469,26 +492,27 @@ class RunConfigTest(test.TestCase): {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}): config = run_config_lib.RunConfig( experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + train_distribute=mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]))) self.assertFalse(dc_training.should_run_distribute_coordinator(config)) def 
test_init_run_config_duplicate_distribute(self): with self.assertRaises(ValueError): run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy(), + train_distribute=mirrored_strategy.CoreMirroredStrategy(), experimental_distribute=DistributeConfig( - train_distribute=mirrored_strategy.MirroredStrategy())) + train_distribute=mirrored_strategy.CoreMirroredStrategy())) with self.assertRaises(ValueError): run_config_lib.RunConfig( - eval_distribute=mirrored_strategy.MirroredStrategy(), + eval_distribute=mirrored_strategy.CoreMirroredStrategy(), experimental_distribute=DistributeConfig( - eval_distribute=mirrored_strategy.MirroredStrategy())) + eval_distribute=mirrored_strategy.CoreMirroredStrategy())) def test_init_run_config_none_distribute_coordinator_mode(self): # We don't use distribute coordinator for local training. config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy()) + train_distribute=mirrored_strategy.CoreMirroredStrategy()) dc_training.init_run_config(config, {}) self.assertIsNone(config._distribute_coordinator_mode) @@ -496,7 +520,7 @@ class RunConfigTest(test.TestCase): with test.mock.patch.dict("os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}): config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy()) + train_distribute=mirrored_strategy.CoreMirroredStrategy()) self.assertIsNone(config._distribute_coordinator_mode) # When `train_distribute` is not specified, don't use distribute @@ -512,7 +536,7 @@ class RunConfigTest(test.TestCase): with test.mock.patch.dict("os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy()) + train_distribute=mirrored_strategy.CoreMirroredStrategy()) self.assertEqual(config._distribute_coordinator_mode, dc.CoordinatorMode.INDEPENDENT_WORKER) @@ -521,7 +545,7 @@ class RunConfigTest(test.TestCase): # `experimental.remote_cluster` 
is set use distribute coordinator with # STANDALONE_CLIENT mode. config = run_config_lib.RunConfig( - train_distribute=mirrored_strategy.MirroredStrategy(), + train_distribute=mirrored_strategy.CoreMirroredStrategy(), experimental_distribute=DistributeConfig( remote_cluster={"chief": ["fake_worker"]})) self.assertEqual(config._distribute_coordinator_mode, @@ -529,5 +553,15 @@ class RunConfigTest(test.TestCase): if __name__ == "__main__": + # Reduce `recovery_wait_secs` from 30 seconds so the test completes quickly. + orig_init = session_manager.SessionManager.__init__ + + def new_init(*args, **kwargs): + kwargs.pop("recovery_wait_secs", None) + kwargs["recovery_wait_secs"] = 0.5 + orig_init(*args, **kwargs) + + session_manager.SessionManager.__init__ = new_init + with test.mock.patch.object(sys, "exit", os._exit): test.main() diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index a84ef041960e389c08246fc8a16df2300856d968..60fda996642464135fe1fb8c314bcf7f04d19362 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -20,18 +20,26 @@ from __future__ import print_function import tensorflow as tf +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.keras.optimizer_v2 import rmsprop + + NUM_CLASSES = 10 -def get_input_datasets(): +def get_input_datasets(use_bfloat16=False): """Downloads the MNIST dataset and creates train and eval dataset objects. + Args: + use_bfloat16: Boolean to determine if input should be cast to bfloat16 + Returns: Train dataset, eval dataset and input shape. 
""" # input image dimensions img_rows, img_cols = 28, 28 + cast_dtype = tf.bfloat16 if use_bfloat16 else tf.float32 # the data, split between train and test sets (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() @@ -57,12 +65,13 @@ def get_input_datasets(): # train dataset train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_ds = train_ds.repeat() - train_ds = train_ds.shuffle(100) + train_ds = train_ds.map(lambda x, y: (tf.cast(x, cast_dtype), y)) train_ds = train_ds.batch(64, drop_remainder=True) # eval dataset eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)) eval_ds = eval_ds.repeat() + eval_ds = eval_ds.map(lambda x, y: (tf.cast(x, cast_dtype), y)) eval_ds = eval_ds.batch(64, drop_remainder=True) return train_ds, eval_ds, input_shape @@ -97,23 +106,28 @@ def main(_): # Build the train and eval datasets from the MNIST data. Also return the # input shape which is constructed based on the `image_data_format` # i.e channels_first or channels_last. + tf.enable_eager_execution() + train_ds, eval_ds, input_shape = get_input_datasets() model = get_model(input_shape) # Instantiate the MirroredStrategy object. If we don't specify `num_gpus` or # the `devices` argument then all the GPUs available on the machine are used. - strategy = tf.contrib.distribute.MirroredStrategy() + # TODO(priyag): Use `tf.distribute.MirroredStrategy` once available. + strategy = mirrored_strategy.MirroredStrategy(['/gpu:0', '/cpu:0']) + + optimizer = rmsprop.RMSProp(learning_rate=0.001) # Compile the model by passing the distribution strategy object to the # `distribute` argument. `fit`, `evaluate` and `predict` will be distributed # based on the strategy instantiated. model.compile(loss=tf.keras.losses.categorical_crossentropy, - optimizer=tf.train.RMSPropOptimizer(learning_rate=0.001), + optimizer=optimizer, metrics=['accuracy'], distribute=strategy) # Train the model with the train dataset. 
- model.fit(x=train_ds, epochs=20, steps_per_epoch=310) + model.fit(x=train_ds, epochs=20, steps_per_epoch=468) # Evaluate the model with the eval dataset. score = model.evaluate(eval_ds, steps=10, verbose=0) diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6dfd85bcc4f3784e2744fd876a7190cc9581d96a --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py @@ -0,0 +1,285 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests that show that DistributionStrategy works with canned Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile +from absl.testing import parameterized +import numpy as np +import six + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.estimator import run_config +from tensorflow.python.estimator import training +from tensorflow.python.estimator.canned import dnn_linear_combined +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column_lib as feature_column +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.keras.optimizer_v2 import adam +from tensorflow.python.keras.optimizer_v2 import gradient_descent +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache + + +class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def dataset_input_fn(self, x, y, batch_size): + + def input_fn(): + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(1).batch(batch_size) + return dataset + + 
return input_fn + + @combinations.generate( + combinations.combine( + mode=['graph'], + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus + ], + use_train_and_evaluate=[True, False])) + def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate): + label_dimension = 2 + input_dimension = label_dimension + batch_size = 10 + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + train_input_fn = self.dataset_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size // distribution.num_replicas_in_sync) + eval_input_fn = self.dataset_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size // distribution.num_replicas_in_sync) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, batch_size=batch_size, shuffle=False) + + linear_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,)) + ] + dnn_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,)) + ] + feature_columns = linear_feature_columns + dnn_feature_columns + session_config = config_pb2.ConfigProto( + log_device_placement=True, allow_soft_placement=True) + estimator = dnn_linear_combined.DNNLinearCombinedRegressor( + linear_feature_columns=linear_feature_columns, + dnn_hidden_units=(2, 2), + dnn_feature_columns=dnn_feature_columns, + label_dimension=label_dimension, + model_dir=self._model_dir, + dnn_optimizer=adam.Adam(0.001), + linear_optimizer=adam.Adam(0.001), + config=run_config.RunConfig( + train_distribute=distribution, + eval_distribute=distribution, + session_config=session_config)) + + num_steps = 2 + if use_train_and_evaluate: + scores, _ = training.train_and_evaluate( + estimator, training.TrainSpec(train_input_fn, max_steps=num_steps), 
+ training.EvalSpec(eval_input_fn)) + else: + estimator.train(train_input_fn, steps=num_steps) + scores = estimator.evaluate(eval_input_fn) + + self.assertIn('loss', six.iterkeys(scores)) + + predictions = np.array([ + x[prediction_keys.PredictionKeys.PREDICTIONS] + for x in estimator.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, label_dimension), predictions.shape) + + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + +def get_model(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + return model + + +class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase): + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph'])) + def testKerasOptimizerWithUnequalInput(self, distribution): + def create_fn(): + var = variables.Variable( + 2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM) + # grad for cpu is 1, grad for gpu is 2, avg grad is 1.5. 
+ loss = math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) * var + optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2) + train_op = optimizer.minimize(loss, var_list=[var]) + m = optimizer.get_slot(var, 'm') + v = optimizer.get_slot(var, 'v') + return (var, m, v, train_op, optimizer.iterations) + + devices = ['/device:GPU:0', '/device:CPU:0'] + with distribution.scope(): + (var, m, v, op, counter) = distribution.call_for_each_replica(create_fn) + self.evaluate(variables.global_variables_initializer()) + var_val = [2.0, 2.0, 2.0] + self.assertAllClose( + var_val, + self.evaluate( + [distribution.read_var(var), + var.get(devices[0]), + var.get(devices[1])])) + self.assertAllClose([0, 0, 0], + self.evaluate([ + distribution.read_var(counter), + counter.get(devices[0]), + counter.get(devices[1]) + ])) + + train_op = distribution.unwrap(op) + self.evaluate(train_op) + # m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2 + m_val = [1.2, 1.2, 1.2] + # assert slot variables in both replicas are the same. 
+ self.assertAllClose( + m_val, + self.evaluate( + [distribution.read_var(m), + m.get(devices[0]), + m.get(devices[1])])) + # v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25 + v_val = [1.8, 1.8, 1.8] + self.assertAllClose( + v_val, + self.evaluate( + [distribution.read_var(v), + v.get(devices[0]), + v.get(devices[1])])) + # var(1) = var(0) - lr * m(1) * sqrt(1 - beta2) / sqrt(v(1)) / (1 - beta1) + # = 2.0 - 0.01 * 1.2 * sqrt(0.8) / sqrt(1.8) / 0.8 + var_val = [1.99, 1.99, 1.99] + self.assertAllClose( + var_val, + self.evaluate( + [distribution.read_var(var), + var.get(devices[0]), + var.get(devices[1])])) + self.assertAllClose([1, 1, 1], + self.evaluate([ + distribution.read_var(counter), + counter.get(devices[0]), + counter.get(devices[1]) + ])) + + self.evaluate(train_op) + # m(2) = beta1 * m(1) + (1-beta1) * grad = 0.2 * 1.2 + 0.8 * 1.5 + m_val = [1.44, 1.44, 1.44] + self.assertAllClose( + m_val, + self.evaluate( + [distribution.read_var(m), + m.get(devices[0]), + m.get(devices[1])])) + # v(2) = beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25 + v_val = [2.16, 2.16, 2.16] + self.assertAllClose( + v_val, + self.evaluate( + [distribution.read_var(v), + v.get(devices[0]), + v.get(devices[1])])) + self.assertAllClose([2, 2, 2], + self.evaluate([ + distribution.read_var(counter), + counter.get(devices[0]), + counter.get(devices[1]) + ])) + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph'])) + def testOptimizerWithKerasModelAndNumpyArrays(self, distribution): + + with self.cached_session(): + model = get_model() + optimizer = gradient_descent.SGD(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + model.fit( + inputs, + targets, + epochs=1, + 
batch_size=2, + verbose=0, + validation_data=(inputs, targets)) + model.evaluate(inputs, targets) + model.predict(inputs) + + +def _replica_id(): + replica_id = ds_context.get_replica_context().replica_id_in_sync_group + if not isinstance(replica_id, ops.Tensor): + replica_id = constant_op.constant(replica_id) + return replica_id + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index 3511b7761ff4d8c995bfa40a1098b8e803f2a1b3..683cc89bfbae9c877ea6794d311ffc00c96c6937 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -24,24 +24,25 @@ import numpy as np from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import tpu_strategy -from tensorflow.contrib.distribute.python import values from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import values +from tensorflow.python.eager import test from tensorflow.python.estimator import keras as keras_lib from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import distributed_training_utils +from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile -from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import gradient_descent from 
tensorflow.python.training import rmsprop - _RANDOM_SEED = 1337 _TRAIN_SIZE = 200 _INPUT_SIZE = (10,) @@ -164,7 +165,9 @@ def get_multi_inputs_multi_outputs_data(): return (train_data, test_data) -def batch_wrapper(dataset, batch_size, distribution): +def batch_wrapper(dataset, batch_size, distribution, repeat=None): + if repeat: + dataset = dataset.repeat(repeat) # TPUs currently require fully defined input shapes, drop_remainder ensures # the input will have fully defined shapes. if isinstance(distribution, tpu_strategy.TPUStrategy): @@ -197,30 +200,166 @@ def get_predict_dataset(distribution): return dataset -strategies = [combinations.default_strategy, - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.tpu_strategy_one_step] +def multi_input_output_model(): + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(5,), name='input_b') + # TODO(anjalisridhar): Change the output dimension of the second Dense layer + # once the iterator output validation issue has been fixed. + dense_1 = keras.layers.Dense(7, name='dense_1') + dense_2 = keras.layers.Dense(7, name='dense_2') + c = dense_1(a) + d = dense_2(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + model = keras.models.Model([a, b], [d, e]) + return model + + +def get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, + x_train, y_train, x_predict): + """Generates the inputs for correctness check when enable Keras with DS.""" + training_epochs = 2 + global_batch_size = 64 + batch_size = global_batch_size + # TODO(b/118776054): Use global batch size for Keras/DS support. 
+ use_per_core_batch_size = ( + with_distribution and + not distributed_training_utils.global_batch_size_supported( + with_distribution)) + if use_per_core_batch_size: + batch_size //= with_distribution.num_replicas_in_sync + + if use_numpy: + training_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + 'epochs': training_epochs, + 'shuffle': False, + } + + if use_validation_data: + eval_inputs = None + training_inputs['validation_data'] = (x_train, y_train) + else: + eval_inputs = { + 'batch_size': batch_size, + 'x': x_train, + 'y': y_train, + } + predict_inputs = { + 'x': np.array(x_predict, dtype=np.float32), + } + else: + # For dataset inputs, we do not pass batch_size to + # keras.fit/evaluate/predict. The batch size is part of the dataset. + train_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper( + train_dataset, batch_size, with_distribution, repeat=training_epochs) + + training_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'epochs': training_epochs, + 'shuffle': False, + 'steps_per_epoch': len(x_train) // global_batch_size, + } + if use_validation_data: + eval_inputs = None # Remove the eval_inputs + eval_dataset = dataset_ops.Dataset.from_tensor_slices( + (x_train, y_train)) + x = batch_wrapper(eval_dataset, batch_size, with_distribution) + training_inputs['validation_data'] = x + training_inputs['validation_steps'] = 5 + else: + eval_inputs = { + 'batch_size': None, + 'x': x, + 'y': None, + 'steps': 20, + } + + predict_batch_size = len(x_predict) + if use_per_core_batch_size: + predict_batch_size //= with_distribution.num_replicas_in_sync + predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) + predict_dataset = batch_wrapper(predict_dataset, + predict_batch_size, with_distribution) + predict_inputs = { + 'steps': 1, + 'x': predict_dataset, + } + + return training_inputs, eval_inputs, predict_inputs -def strategy_combinations(): +strategies_minus_tpu = [ + 
combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus] + +tpu_strategies = [ + combinations.tpu_strategy, # steps_per_run=2 + combinations.tpu_strategy_one_step] + + +def strategy_minus_tpu_combinations(): return combinations.combine( - distribution=strategies, + distribution=strategies_minus_tpu, + mode=['graph', 'eager']) + + +def tpu_strategy_combinations(): + return combinations.combine( + distribution=tpu_strategies, mode=['graph']) +def all_strategy_combinations(): + return strategy_minus_tpu_combinations() + tpu_strategy_combinations() + + +# TODO(priyag): Add v2 optimizers here. def strategy_and_optimizer_combinations(): + return combinations.times( + all_strategy_combinations(), + combinations.combine( + optimizer=[combinations.adagrad_optimizer_v1_fn, + combinations.adam_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.rmsprop_optimizer_v1_fn])) + + +def strategy_and_input_combinations(): + return ( + combinations.times( + combinations.combine(distribution=strategies_minus_tpu), + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]) + + combinations.combine(mode=['eager'], + use_numpy=[False], + use_validation_data=[False])) + + combinations.times( + combinations.combine(distribution=tpu_strategies), + combinations.combine(mode=['graph'], + use_numpy=[True, False], + use_validation_data=[True, False]))) + + +def strategy_for_numpy_input_combinations(): return combinations.combine( - distribution=strategies, - optimizer=[combinations.adagrad_optimizer_v1_fn, - combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn, - combinations.rmsprop_optimizer_v1_fn], + distribution=strategies_minus_tpu + tpu_strategies, mode=['graph']) -class 
TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): +class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase, + parameterized.TestCase): def setUp(self): self._base_dir = os.path.join(self.get_temp_dir(), @@ -228,17 +367,18 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): gfile.MakeDirs(self._base_dir) self._config = run_config_lib.RunConfig( tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir) - self._dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) def tearDown(self): writer_cache.FileWriterCache.clear() if os.path.isdir(self._base_dir): gfile.DeleteRecursively(self._base_dir) - def test_train_functional_with_distribution_strategy(self): - dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_train_functional_with_distribution_strategy(self, distribution): keras_model = simple_functional_model() keras_model.compile( loss='categorical_crossentropy', @@ -246,8 +386,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist, - eval_distribute=dist) + train_distribute=distribution, + eval_distribute=distribution) with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=config) @@ -261,9 +401,12 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) - def test_train_sequential_with_distribution_strategy(self): - dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) + 
@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_train_sequential_with_distribution_strategy(self, distribution): keras_model = simple_sequential_model() keras_model.compile( loss='categorical_crossentropy', @@ -271,7 +414,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist) + train_distribute=distribution) with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=config) @@ -285,7 +428,12 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) - def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self, distribution): train_data, test_data = get_multi_inputs_multi_outputs_data() def train_input_fn(): @@ -315,14 +463,14 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): output_dict)).batch(16) self.do_test_multi_inputs_multi_outputs_with_input_fn( - train_input_fn, eval_input_fn) + distribution, train_input_fn, eval_input_fn) - def do_test_multi_inputs_multi_outputs_with_input_fn(self, train_input_fn, - eval_input_fn): + def do_test_multi_inputs_multi_outputs_with_input_fn( + self, distribution, train_input_fn, eval_input_fn): config = run_config_lib.RunConfig( tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=self._dist) + train_distribute=distribution) with self.cached_session(): model = 
multi_inputs_multi_outputs_model() est_keras = keras_lib.model_to_estimator(keras_model=model, config=config) @@ -332,9 +480,12 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1) self.assertLess(eval_results['loss'], baseline_eval_results['loss']) - def test_keras_optimizer_with_distribution_strategy(self): - dist = mirrored_strategy.MirroredStrategy( - devices=['/device:GPU:0', '/device:GPU:1']) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph'])) + def test_keras_optimizer_with_distribution_strategy(self, distribution): keras_model = simple_sequential_model() keras_model.compile( loss='categorical_crossentropy', @@ -342,7 +493,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, - train_distribute=dist) + train_distribute=distribution) with self.cached_session(): est_keras = keras_lib.model_to_estimator(keras_model=keras_model, config=config) @@ -358,7 +509,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): class TestDistributionStrategyWithNumpyArrays(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_creating_var_with_numpy_arrays(self, distribution): with self.cached_session(): x = np.asarray(np.random.random((64, 3)), dtype=np.float32) @@ -367,7 +518,135 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # Verify that the numpy value is copied to the variable. 
self.assertAllEqual(x, val) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_no_steps_no_batch_size(self, distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + + with self.cached_session(): + # Input samples of different sizes + input_20_samples = np.zeros((20, 3), dtype=np.float32) + input_63_samples = np.zeros((63, 3), dtype=np.float32) + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # Default global batch size 32 for input with 64 samples run in 2 steps + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=None) + self.assertEqual(batch_size, 32 // replica_scale_factor) + self.assertEqual(steps, 2) + + # Computed global batch size 20 is lower than 32 if we pass less samples. + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_20_samples, steps=None, batch_size=None) + self.assertEqual(batch_size, 20 // replica_scale_factor) + self.assertEqual(steps, 1) + + # Default global batch size 32 cannot be used with 63 samples. 
+ with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=None, batch_size=None) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_with_steps_no_batch_size(self, + distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + + with self.cached_session(): + # Input samples of different sizes + input_63_samples = np.zeros((63, 3), dtype=np.float32) + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # Computed global batch size is correct for number of specified 1 step + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=1, batch_size=None) + self.assertEqual(batch_size, 64 // replica_scale_factor) + self.assertEqual(steps, 1) + + # Computed global batch size is correct for number of specified 2 steps + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=2, batch_size=None) + self.assertEqual(batch_size, 32 // replica_scale_factor) + self.assertEqual(steps, 2) + + # All samples can not be consumed in specified number of steps + with self.assertRaisesRegexp(ValueError, 'not divisible by steps'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=2, batch_size=None) + + # This cases is different for different strategies due to the + # difference in supported batch size being global or per-replica. 
+ if replica_scale_factor == 1: + # Computed global batch size is correct even if not sharadable + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=3, batch_size=None) + self.assertEqual(batch_size, 21) + self.assertEqual(steps, 3) + else: + # Computed global batch size can not be sharded across replicas + with self.assertRaisesRegexp(ValueError, 'could not be sharded evenly ' + 'across the sync replicas'): + distributed_training_utils.get_input_params( + distribution, input_63_samples, steps=1, batch_size=None) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_no_steps_with_batch_size(self, + distribution): + # Calculate the per_replica_batch_size scaling factor for strategies + # that use per_core_batch_size + replica_scale_factor = 1.0 + if not distributed_training_utils.global_batch_size_supported(distribution): + replica_scale_factor = distribution.num_replicas_in_sync + + with self.cached_session(): + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # Computed steps is correct for specified batch size + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=16) + self.assertEqual(batch_size, 16) + self.assertEqual(steps, 4 // replica_scale_factor) + + # Computed steps is correct for specified batch size + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=32) + self.assertEqual(batch_size, 32) + self.assertEqual(steps, 2 // replica_scale_factor) + + # Number of samples is not divisible by the global batch size + with self.assertRaisesRegexp(ValueError, 'not divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=20) + + # Number of samples is not divisible by the global batch size + with self.assertRaisesRegexp(ValueError, 'not 
divisible by batch size'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=None, batch_size=3) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_calculating_input_params_with_steps_with_batch_size(self, + distribution): + with self.cached_session(): + input_64_samples = np.zeros((64, 3), dtype=np.float32) + + # No change to steps and batch size if both specified and feasible + steps, batch_size = distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=5, batch_size=3) + self.assertEqual(batch_size, 3) + self.assertEqual(steps, 5) + + # Number of samples is less than global batch size * steps + with self.assertRaisesRegexp(ValueError, 'less than samples required'): + distributed_training_utils.get_input_params( + distribution, input_64_samples, steps=10, batch_size=13) + + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_numpy_arrays(self, distribution): with self.cached_session(): model = get_model() @@ -398,29 +677,21 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # with batch_size model.predict(inputs, batch_size=8) - @combinations.generate(strategy_combinations()) + @combinations.generate(strategy_for_numpy_input_combinations()) def test_calling_model_with_nested_numpy_arrays(self, distribution): with self.cached_session(): - a = keras.layers.Input(shape=(3,), name='input_a') - b = keras.layers.Input(shape=(3,), name='input_b') - - dense = keras.layers.Dense(4, name='dense') - c = dense(a) - d = dense(b) - e = keras.layers.Dropout(0.5, name='dropout')(c) - - model = keras.models.Model([a, b], [d, e]) + model = multi_input_output_model() optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) loss = 'mse' model.compile(optimizer, loss, distribute=distribution) input_a_np = np.asarray(np.random.random((64, 3)), dtype=np.float32) - input_b_np = np.asarray(np.random.random((64, 3)), 
dtype=np.float32) + input_b_np = np.asarray(np.random.random((64, 5)), dtype=np.float32) inputs = [input_a_np, input_b_np] - output_d_np = np.asarray(np.random.random((64, 4)), dtype=np.float32) - output_e_np = np.asarray(np.random.random((64, 4)), dtype=np.float32) + output_d_np = np.asarray(np.random.random((64, 7)), dtype=np.float32) + output_e_np = np.asarray(np.random.random((64, 7)), dtype=np.float32) targets = [output_d_np, output_e_np] # Call fit with validation data @@ -440,11 +711,50 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, # with batch_size model.predict(inputs, batch_size=8) + @combinations.generate(combinations.combine( + distribution=strategies_minus_tpu, mode=['graph'])) + def test_numpy_with_sample_weights(self, distribution): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + inputs = np.zeros((20, 3), np.float32) + targets = np.zeros((20, 4), np.float32) + sample_weights = np.ones((20), np.float32) + + model.fit(inputs, targets, sample_weight=sample_weights, epochs=1, + steps_per_epoch=2, verbose=1) + + @combinations.generate(strategy_for_numpy_input_combinations()) + def test_flatten_predict_outputs(self, distribution): + with self.cached_session(): + model = multi_input_output_model() + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + # We take 6 input samples with each input having a dimension of 3 or 5. + input_a_np = np.asarray(np.random.random((6, 3)), dtype=np.float32) + input_b_np = np.asarray(np.random.random((6, 5)), dtype=np.float32) + inputs = [input_a_np, input_b_np] + + outs = model.predict(inputs, steps=1) + # `predict` a list that is equal in length to the number of model outputs. 
+ # In this test our model has two outputs and each element of `outs` + # corresponds to all the samples of one of the model outputs. + self.assertLen(outs, 2) + # Each of the output samples have a dimension of 7. We should process all + # the available input samples(6). + self.assertAllEqual([6, 7], outs[0].shape) + self.assertAllEqual([6, 7], outs[1].shape) + class TestDistributionStrategyWithDatasets(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_calling_model_on_same_dataset(self, distribution): with self.cached_session(): model = get_model() @@ -463,32 +773,68 @@ class TestDistributionStrategyWithDatasets(test.TestCase, validation_data=dataset, validation_steps=2) model.predict(get_predict_dataset(distribution), steps=2) + @combinations.generate(all_strategy_combinations()) + def test_model_interleaved_eval_same_as_direct_eval(self, distribution): + with self.cached_session(): + user_controlled_model = get_model() + user_controlled_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()], + distribute=distribution) + + interleaved_model = get_model() + interleaved_model.set_weights(user_controlled_model.get_weights()) + interleaved_model.compile( + gradient_descent.GradientDescentOptimizer(0.001), + loss='mse', + metrics=['mae', keras.metrics.CategoricalAccuracy()], + distribute=distribution) + + dataset = get_dataset(distribution) + + # Call fit with validation interleaved + interleaved_output = interleaved_model.fit( + dataset, epochs=2, steps_per_epoch=2, verbose=1, + validation_data=dataset, validation_steps=2, shuffle=False) + + # Manually control the validation running after each epoch. 
+ user_controlled_output = [] + for _ in range(2): + user_controlled_model.fit( + dataset, epochs=1, steps_per_epoch=2, verbose=1, shuffle=False) + user_controlled_output.append( + user_controlled_model.evaluate(dataset, steps=2)) + + self.assertEqual(interleaved_output.history['val_loss'], + [x[0] for x in user_controlled_output]) + self.assertEqual(interleaved_output.history['val_mean_absolute_error'], + [x[1] for x in user_controlled_output]) + self.assertEqual(interleaved_output.history['val_categorical_accuracy'], + [x[2] for x in user_controlled_output]) + # TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work # as clone_model's input_tensors argument only seems to accept list and not # tuples or dict. - def test_fit_with_tuple_and_dict_dataset_inputs(self): - with self.cached_session(): - a = keras.layers.Input(shape=(3,), name='input_a') - b = keras.layers.Input(shape=(3,), name='input_b') - - dense = keras.layers.Dense(4, name='dense') - c = dense(a) - d = dense(b) - e = keras.layers.Dropout(0.5, name='dropout')(c) - model = keras.models.Model([a, b], [d, e]) + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_fit_with_tuple_and_dict_dataset_inputs(self, distribution): + with self.cached_session(): + model = multi_input_output_model() optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) loss = 'mse' metrics = ['mae', keras.metrics.CategoricalAccuracy()] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) input_a_np = np.random.random((10, 3)) - input_b_np = np.random.random((10, 3)) - output_d_np = np.random.random((10, 4)) - output_e_np = np.random.random((10, 4)) + input_b_np = 
np.random.random((10, 5)) + output_d_np = np.random.random((10, 7)) + output_e_np = np.random.random((10, 7)) # Test with tuples dataset_tuple = dataset_ops.Dataset.from_tensor_slices(( @@ -507,7 +853,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_fit_eval_and_predict_methods_on_dataset(self, distribution): with self.cached_session(): model = get_model() @@ -537,35 +883,67 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model.evaluate(dataset, steps=2, verbose=1) model.predict(get_predict_dataset(distribution), steps=2) - def test_dataset_input_shape_validation(self): + @combinations.generate(strategy_minus_tpu_combinations()) + def test_dataset_with_sample_weights(self, distribution): + model = get_model() + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + inputs = np.zeros((10, 3), np.float32) + targets = np.zeros((10, 4), np.float32) + sample_weights = np.ones((10), np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets, + sample_weights)) + dataset = dataset.repeat() + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(dataset, steps=2) + + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_dataset_wrong_input_shape(self, distribution): with self.cached_session(): model = get_model() optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) loss = 'mse' - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - - model.compile(optimizer, loss, distribute=strategy) + 
model.compile(optimizer, loss, distribute=distribution) - # User forgets to batch the dataset - inputs = np.zeros((10, 3), dtype=np.float32) + # Wrong input shape + inputs = np.zeros((10, 5), dtype=np.float32) targets = np.zeros((10, 4), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) + dataset = dataset.batch(10) - with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): + with self.assertRaisesRegexp(ValueError, + 'expected input to have shape'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - # Wrong input shape - inputs = np.zeros((10, 5), dtype=np.float32) + @combinations.generate(combinations.combine( + distribution=[combinations.mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_dataset_no_batch_input_validation(self, distribution): + with self.cached_session(): + model = get_model() + + optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + # User forgets to batch the dataset + inputs = np.zeros((10, 3), dtype=np.float32) targets = np.zeros((10, 4), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - dataset = dataset.batch(10) - with self.assertRaisesRegexp(ValueError, - 'expected input to have shape'): + with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) @combinations.generate(combinations.combine( @@ -587,51 +965,91 @@ class TestDistributionStrategyWithDatasets(test.TestCase, with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) - def test_learning_phase_value(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + 
mode=['graph', 'eager'])) + def test_learning_phase_value(self, distribution): # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare # meaningful values. Currently we don't pass the learning phase if the # Lambda layer uses the learning phase. with self.cached_session(): - x = keras.layers.Input(shape=(16,), name='input') - y = keras.layers.Dense(16)(x) + x = keras.layers.Input(shape=(1,), name='input') + y = keras.layers.Dense(1, kernel_initializer='ones')(x) z = keras.layers.Dropout(0.9999)(y) model = keras.Model(x, z) + initial_weights = model.get_weights() optimizer = gradient_descent.GradientDescentOptimizer(0.005) loss = 'mse' metrics = ['acc'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + batch_size = 8 + if isinstance(distribution, mirrored_strategy.CoreMirroredStrategy): + # CoreMirroredStrategy uses global batch size. + batch_size = 8 * distribution.num_replicas_in_sync - inputs = np.random.rand(10, 16) - targets = np.ones((10, 16), dtype=np.float32) + inputs = np.ones((10, 1), dtype=np.float32) + targets = np.ones((10, 1), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - dataset = dataset.batch(8) + dataset = dataset.repeat().batch(batch_size) + hist = model.fit(dataset, epochs=1, steps_per_epoch=20, verbose=1) + self.assertAlmostEqual(hist.history['acc'][0], 0, 0) + + model.set_weights(initial_weights) + # TODO(psv/anjalisridhar): Enable these lines after we fix b/117431185. 
+ # evaluate_output = model.evaluate(dataset, steps=20) + # self.assertAlmostEqual(evaluate_output[1], 1, 0) + + inputs = np.ones((10, 1), dtype=np.float32) + predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + + predict_dataset = predict_dataset.repeat().batch(batch_size) + output = model.predict(predict_dataset, steps=10) + # `predict` runs for 10 steps + ref_output = np.ones((160, 1), dtype=np.float32) + self.assertArrayNear(output, ref_output, 1e-1) + + @combinations.generate(strategy_minus_tpu_combinations()) + def testOptimizerWithCallbacks(self, distribution): + with self.cached_session(): + model = get_model() - hist = model.fit(dataset, epochs=5, steps_per_epoch=20, verbose=1) - self.assertEqual(hist.history['acc'][0], 1) + optimizer = gradient_descent_keras.SGD(0.01) + loss = 'mse' + model.compile(optimizer, loss, distribute=distribution) + + dataset = get_dataset(distribution) - evaluate_output = model.evaluate(dataset, steps=20) - self.assertEqual(evaluate_output[1], 0) + def schedule(_): + return 0.001 - predict_output = model.predict(dataset, steps=1) - self.assertNotEqual(np.mean(predict_output), 0) + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, + callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) + grouped_models = distribution.unwrap(model._grouped_model) + with distribution.scope(): + for m in grouped_models: + self.assertAllClose(0.001, keras.backend.get_value( + m.optimizer.lr), atol=1e-05, rtol=1e-05) class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): - def test_validating_dataset_input_tensors_with_shape_mismatch(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_validating_dataset_input_tensors_with_shape_mismatch(self, + distribution): with self.cached_session(): - strategy = 
mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) a = constant_op.constant([1, 2], shape=(1, 2)) b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2)) x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with strategy.scope(): + with distribution.scope(): # Removed device and input tensor shape details from the error message # since the order of the device and the corresponding input tensor shape # is not deterministic over different runs. @@ -640,17 +1058,21 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'distributed tensor inputs ' 'DistributedValues:.+'): distributed_training_utils.validate_distributed_dataset_inputs( - strategy, x, y) + distribution, x, y) - def test_validating_dataset_input_tensors_with_dtype_mismatch(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_validating_dataset_input_tensors_with_dtype_mismatch(self, + distribution): with self.cached_session(): - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64) x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b}) y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a}) - with strategy.scope(): + with distribution.scope(): # Removed device and input tensor dtype details from the error message # since the order of the device and the corresponding input tensor dtype # is not deterministic over different runs. 
@@ -659,21 +1081,23 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'distributed tensor inputs ' 'DistributedValues:.+'): distributed_training_utils.validate_distributed_dataset_inputs( - strategy, x, y) + distribution, x, y) - def test_unsupported_features(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_unsupported_features(self, distribution): with self.cached_session(): model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - dataset = get_dataset(strategy) + dataset = get_dataset(distribution) # Test with validation split with self.assertRaisesRegexp( @@ -687,8 +1111,8 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): # Test with sample weight. 
sample_weight = np.random.random((10,)) with self.assertRaisesRegexp( - NotImplementedError, '`sample_weight` is currently not supported ' - 'when using DistributionStrategy.'): + ValueError, '`sample_weight` argument is not supported when input ' + '`x` is a dataset or a dataset iterator.'): model.fit( dataset, epochs=1, @@ -708,45 +1132,48 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): 'you should specify the `steps` argument'): model.predict(dataset, verbose=0) - def test_calling_with_unsupported_predefined_callbacks(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=['graph', 'eager'])) + def test_calling_with_unsupported_predefined_callbacks(self, distribution): with self.cached_session(): model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - dataset = get_dataset(strategy) + dataset = get_dataset(distribution) def schedule(_): return 0.001 with self.assertRaisesRegexp(ValueError, - 'LearningRateScheduler callback is not ' - 'supported with DistributionStrategy.'): + 'You must specify a Keras Optimizer V2 when ' + 'using'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.LearningRateScheduler(schedule)]) with self.assertRaisesRegexp(ValueError, - 'ReduceLROnPlateau callback is not ' - 'supported with DistributionStrategy.'): + 'You must specify a Keras Optimizer V2 when ' + 'using'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, callbacks=[keras.callbacks.ReduceLROnPlateau()]) - with self.assertRaisesRegexp(ValueError, - 'histogram_freq in the TensorBoard 
callback ' - 'is not supported when using ' - 'DistributionStrategy.'): - model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, - callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) -class TestDistributionStrategyWithLossMasking(test.TestCase): +class TestDistributionStrategyWithLossMasking(test.TestCase, + parameterized.TestCase): # TODO(priyag): Enable all strategies for this test. Currently it does not # work for TPU due to some invalid datatype. - def test_masking(self): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_two_gpus], + mode=['graph', 'eager'])) + def test_masking(self, distribution): with self.cached_session(): np.random.seed(1337) x = np.array([[[1], [1]], [[0], [0]]]) @@ -755,12 +1182,9 @@ class TestDistributionStrategyWithLossMasking(test.TestCase): model.add( keras.layers.TimeDistributed( keras.layers.Dense(1, kernel_initializer='one'))) - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', - '/device:GPU:0']) - model.compile(loss='mse', optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=strategy) + distribute=distribution) y = np.array([[[1], [1]], [[1], [1]]]) dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) dataset = dataset.repeat(100) @@ -772,7 +1196,7 @@ class TestDistributionStrategyWithLossMasking(test.TestCase): class TestDistributionStrategyWithNormalizationLayer( test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def test_batchnorm_correctness(self, distribution): with self.cached_session(): model = keras.models.Sequential() @@ -804,7 +1228,7 @@ class TestDistributionStrategyWithNormalizationLayer( class TestDistributionStrategyCorrectness(test.TestCase, parameterized.TestCase): - @combinations.generate(strategy_combinations()) + @combinations.generate(all_strategy_combinations()) def 
test_metric_correctness(self, distribution): with self.cached_session(): keras.backend.set_image_data_format('channels_last') @@ -827,78 +1251,152 @@ class TestDistributionStrategyCorrectness(test.TestCase, distribute=distribution) batch_size = 64 - batch_size //= distribution.num_towers + if not distributed_training_utils.global_batch_size_supported( + distribution): + batch_size //= distribution.num_replicas_in_sync train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = batch_wrapper(train_dataset, batch_size, distribution) - history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) - self.assertEqual(history.history['binary_accuracy'], [1.0]) + history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10) + self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0]) - @combinations.generate(strategy_combinations()) - def test_correctness(self, distribution): + @combinations.generate(all_strategy_combinations()) + def test_eval_metrics_correctness(self, distribution): with self.cached_session(): - keras.backend.set_image_data_format('channels_last') - num_samples = 10000 + model = keras.Sequential() + model.add( + keras.layers.Dense( + 3, activation='relu', input_dim=4, kernel_initializer='ones')) + model.add( + keras.layers.Dense( + 1, activation='sigmoid', kernel_initializer='ones')) + model.compile( + loss='mae', + metrics=['accuracy', keras.metrics.BinaryAccuracy()], + optimizer=gradient_descent.GradientDescentOptimizer(0.001), + distribute=distribution) + + # verify correctness of stateful and stateless metrics. + x = np.ones((100, 4)).astype('float32') + y = np.ones((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = batch_wrapper(dataset, 4, distribution) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 1.) + self.assertEqual(outs[2], 1.) 
+ + y = np.zeros((100, 1)).astype('float32') + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat() + dataset = batch_wrapper(dataset, 4, distribution) + outs = model.evaluate(dataset, steps=10) + self.assertEqual(outs[1], 0.) + self.assertEqual(outs[2], 0.) + + @combinations.generate(strategy_and_input_combinations()) + def test_correctness(self, distribution, use_numpy, use_validation_data): - # Train and predict datasets are created with the same input numpy arrays. + with self.cached_session(): + default_tolerance = 1e-5 + tol_table = {} + + if isinstance(distribution, (mirrored_strategy.MirroredStrategy, + mirrored_strategy.CoreMirroredStrategy)): + # TODO(b/119257215): Weights are not exactly the same, so use larger + # tolerance for now. Predict should be related to weights. + tol_table = { + 'weights_1': 1e-4, + 'weights_2': 1e-4, + 'predict_result_1': 1e-4, + } + + keras.backend.set_image_data_format('channels_last') + np.random.seed(_RANDOM_SEED) + random_seed.set_random_seed(_RANDOM_SEED) + + # Train, eval, and predict datasets are created with the same input numpy + # arrays. + # TODO(xiejw): Change this back to 10000, once we support final partial + # batch. + num_samples = 9984 x_train = np.random.rand(num_samples, 1) y_train = 3 * x_train x_train = x_train.astype('float32') y_train = y_train.astype('float32') + x_predict = [[1.], [2.], [3.], [4.]] # The model is built once and the initial weights are saved. # This is used to initialize the model for both the distribution and - # non-distribution run. - model = keras.Sequential() - model.add(keras.layers.Dense(1, input_shape=(1,))) + # non-distribution run. In addition, we add few non-linear layers to make + # it non-trivial. 
+ def _create_model(): + model = keras.Sequential() + model.add(keras.layers.Dense(10, activation='relu', input_shape=(1,))) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(10, activation='relu')) + model.add(keras.layers.Dense(1)) + return model + + model = _create_model() initial_weights = model.get_weights() + del model # avoid accident usage. - def fit_and_predict(with_distribution=None): + def fit_eval_and_predict(with_distribution=None): + model = _create_model() + # We have initialized the model to the same weight for the distribution + # and non-distribution run. model.set_weights(initial_weights) model.compile( loss=keras.losses.mean_squared_error, - optimizer=gradient_descent.GradientDescentOptimizer(0.5), + optimizer=gradient_descent_keras.SGD(0.5), + metrics=['mse'], distribute=with_distribution) - batch_size = 64 - if with_distribution: - batch_size //= with_distribution.num_towers - train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, - y_train)) - train_dataset = batch_wrapper(train_dataset, batch_size, distribution) - # We have initialized the model to the same weight for the distribution - # and non-distribution run. If you want to initialize the model to - # random weights for each run, you need to run the model through the - # entire dataset at least once to ensure that the weights converge to - # the same value. 
- model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) - - weights = model.get_weights() - x_predict = [[1.], [2.], [3.], [4.]] - predict_batch_size = 4 - if with_distribution: - predict_batch_size //= with_distribution.num_towers - predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict) - predict_dataset = batch_wrapper(predict_dataset, - predict_batch_size, distribution) - predict_result = model.predict(predict_dataset, steps=1) - predict_result = np.reshape(predict_result, (4, 1)) - - return weights, predict_result - - wts_with_ds, predict_with_ds = fit_and_predict( - with_distribution=distribution) - wts_without_ds, predict_without_ds = fit_and_predict( - with_distribution=None) - - # Verify that the weights are the same within some limits of tolerance. - np.testing.assert_allclose(wts_with_ds[0], wts_without_ds[0], rtol=1e-3) - # Verify that the predicted outputs are the same within some limits of - # tolerance. - np.testing.assert_allclose(predict_with_ds, predict_without_ds, rtol=1e-3) - - -# TODO(priyag): Add a test for TPUStrategy with steps_per_run > 1. + training_inputs, eval_inputs, predict_inputs = ( + get_correctness_test_inputs(use_numpy, use_validation_data, + with_distribution, + x_train, y_train, x_predict)) + + result = {} + result['training_history_1'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_1'] = model.evaluate(**eval_inputs) + + result['weights_1'] = model.get_weights() + result['predict_result_1'] = model.predict(**predict_inputs) + + # Train and eval again to mimic user's flow. 
+ + result['training_history_2'] = model.fit(**training_inputs).history + + if eval_inputs is not None: + result['eval_result_2'] = model.evaluate(**eval_inputs) + + result['weights_2'] = model.get_weights() + + return result + + results_with_ds = fit_eval_and_predict(with_distribution=distribution) + results_without_ds = fit_eval_and_predict(with_distribution=None) + + # Verify that the weights, training history, eval results, predict outputs + # are the same within some limits of tolerance. + for key in results_with_ds: + if (key.startswith('training_history') and + isinstance(distribution, tpu_strategy.TPUStrategy) and + distribution.extended.steps_per_run > 1): + # TODO(b/119894254): Enable this test for all cases once the + # underlying bug is fixed. + continue + + tolerance = tol_table.get(key, default_tolerance) + + self.assertAllClose( + results_with_ds[key], + results_without_ds[key], + atol=tolerance, + rtol=tolerance, + msg='Fail to assert {}.'.format(key)) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index ae4189eb1cb217f8a209b57f91a0ddb82e63dcd9..8ac659abe96370b751ed1556cc699fe20788a0fd 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -72,14 +72,14 @@ def _regression_dataset_fn(): "predictions": [1., .75, .25, 0.]}).repeat() -# TODO(priyag): Add TPU Strategy to this once metrics aggregate correctly using -# TowerLocalVariables on TPUs. Submit http://cl/208914352. 
def all_combinations(): return combinations.combine( distribution=[combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], mode=["graph"]) @@ -96,30 +96,32 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): with ops.Graph().as_default(), distribution.scope(): iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + dataset_fn).make_initializable_iterator() if isinstance(distribution, tpu_strategy.TPUStrategy): def step_fn(ctx, inputs): - value, update = distribution.call_for_each_tower( - metric_fn, inputs) + value, update = distribution.call_for_each_replica( + metric_fn, args=inputs) ctx.set_non_tensor_output(name="value", output=value) return distribution.group(update) ctx = distribution.run_steps_on_dataset( - step_fn, iterator, iterations=distribution.steps_per_run) + step_fn, iterator, iterations=distribution.extended.steps_per_run) update = ctx.run_op value = ctx.non_tensor_outputs["value"] # In each run, we run multiple steps, and each steps consumes as many - # batches as number of towers. + # batches as number of replicas. batches_per_update = ( - distribution.num_towers * distribution.steps_per_run) + distribution.num_replicas_in_sync * + distribution.extended.steps_per_run) else: - value, update = distribution.call_for_each_tower( + value, update = distribution.call_for_each_replica( metric_fn, iterator.get_next()) update = distribution.group(update) # TODO(josh11b): Once we switch to using a global batch size for input, - # replace "distribution.num_towers" with "1". - batches_per_update = distribution.num_towers + # replace "distribution.num_replicas_in_sync" with "1". 
+ batches_per_update = distribution.num_replicas_in_sync + self.evaluate(iterator.initializer) self.evaluate(distribution.initialize()) self.evaluate(variables.local_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index ba147e78241e5ab45809e498e00debd45a2c49b4..f09483cb56b66fd4720ee71085203c14f1ccadc3 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -22,10 +22,10 @@ from absl.testing import parameterized import numpy from tensorflow.contrib.distribute.python import combinations -from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op @@ -41,6 +41,14 @@ from tensorflow.python.ops.losses import losses_impl class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): + def _get_iterator(self, ds): + if context.executing_eagerly(): + iterator = ds.make_one_shot_iterator() + else: + iterator = ds.make_initializable_iterator() + self.evaluate(iterator.initializer) + return iterator + @combinations.generate( combinations.times( combinations.distributions_and_v1_optimizers(), @@ -56,14 +64,12 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_tower( - model_fn, *inputs, 
run_concurrently=layer.built)) + distribution.call_for_each_replica(model_fn, args=inputs)) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) def run_step(): return distribution.run_steps_on_dataset( @@ -93,19 +99,18 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.distributions_and_v1_optimizers(), combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + combinations.combine(mode=["eager"], use_callable_loss=[True]))) - def testTrainNetworkByCallForEachTower(self, distribution, optimizer_fn, - use_callable_loss): + def testTrainNetworkByCallForEachReplica(self, distribution, optimizer_fn, + use_callable_loss): with distribution.scope(): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) def run_step(): return distribution.group( - distribution.call_for_each_tower( - model_fn, iterator.get_next(), run_concurrently=layer.built)) + distribution.call_for_each_replica( + model_fn, args=(iterator.get_next(),))) if not context.executing_eagerly(): with self.cached_session() as sess: @@ -153,14 +158,12 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): use_callable_loss=True, create_optimizer_inside_model_fn=True) - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_tower( - model_fn, *inputs, run_concurrently=layer.built)) + distribution.call_for_each_replica(model_fn, args=inputs)) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) def run_step(): return 
distribution.run_steps_on_dataset( @@ -179,11 +182,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def get_expected_variables(optimizer_fn, num_parameter_devices): variables_map = { "GradientDescent": ["dense/kernel", "dense/bias"], - "Adam": [ - "dense/kernel", "dense/bias", "beta1_power", "beta2_power", - "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam", - "dense/bias/Adam_1" - ], "Adagrad": [ "dense/kernel/Adagrad", "dense/kernel", "dense/bias/Adagrad", "dense/bias" @@ -210,42 +208,34 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.combine( mode=["graph", "eager"], # TODO(isaprykin): Allow False here. Currently subsequent - # towers will re-execute UPDATE_OPS of previous towers. - update_ops_in_cross_tower_mode=[True])) + + # replicas will re-execute UPDATE_OPS of previous replicas. + update_ops_in_cross_replica_mode=[True])) + combinations.combine( distribution=[combinations.tpu_strategy], optimizer_fn=combinations.optimizers_v1, mode=["graph"], - update_ops_in_cross_tower_mode=[False]))) + update_ops_in_cross_replica_mode=[False]))) def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum, - renorm, update_ops_in_cross_tower_mode): - """Verifies that moving mean updates are reduced across towers.""" + renorm, update_ops_in_cross_replica_mode): + """Verifies that moving mean updates are reduced across replicas.""" with distribution.scope(): - num_towers = len(distribution.worker_devices) + num_replicas = distribution.num_replicas_in_sync model_fn, dataset_fn, batchnorm = batchnorm_example( optimizer_fn, - batch_per_epoch=num_towers, + batch_per_epoch=num_replicas, momentum=momentum, renorm=renorm, - update_ops_in_tower_mode=not update_ops_in_cross_tower_mode) - - # Make sure prefetching is disabled since that makes the - # specific input on each device to be non deterministic, and - # this test relies on specific input being on each device. 
- if isinstance(distribution, mirrored_strategy.MirroredStrategy): - self.assertFalse(distribution._prefetch_on_device) + update_ops_in_replica_mode=not update_ops_in_cross_replica_mode) - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): del ctx # Unused fetches = distribution.unwrap( - distribution.call_for_each_tower( - model_fn, *inputs, run_concurrently=batchnorm.built)) - if update_ops_in_cross_tower_mode: - fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) + distribution.call_for_each_replica(model_fn, args=inputs)) + if update_ops_in_cross_replica_mode: + fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS)) return control_flow_ops.group(fetches) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) def run_step(): return distribution.run_steps_on_dataset( @@ -261,17 +251,17 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def averaged_batch_mean(i): # Each batch has shape [16, 8] where the ith element in jth list is - # (8 * j + i + tower_id * 100). So the batch mean in each tower is - # (60 + i + tower_id * 100). So here comes its batch mean over all - # towers: - return 60. + i + (num_towers - 1.) / 2. * 100. + # (8 * j + i + replica_id * 100). So the batch mean in each replica is + # (60 + i + replica_id * 100). So here comes its batch mean over all + # replicas: + return 60. + i + (num_replicas - 1.) / 2. * 100. for _ in range(10): run_step() moving_means = self.evaluate(batchnorm.moving_mean) # We make sure that the moving_mean is updated as if the sample mean is - # calculated over all towers. + # calculated over all replicas. 
for i, expected_moving_mean in enumerate(expected_moving_means): expected_moving_means[i] -= (( expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) @@ -296,7 +286,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution=[ combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus ]), combinations.combine( mode=["graph"], use_callable_loss=[True, False]) + @@ -332,14 +324,12 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): labels = dataset_ops.Dataset.from_tensors([[6.], [21.]]) return dataset_ops.Dataset.zip((features, labels)).repeat() - def step_fn(ctx, x, y): + def step_fn(ctx, inputs): del ctx # Unused return distribution.group( - distribution.call_for_each_tower( - model_fn, x, y, run_concurrently=False)) + distribution.call_for_each_replica(model_fn, args=inputs)) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) def run_step(): return distribution.run_steps_on_dataset( @@ -354,7 +344,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): run_step() v = all_vars[0] - self.assertTrue(all([v is vi for vi in all_vars[1:]])) + self.assertTrue(all(v is vi for vi in all_vars[1:])) weight = numpy.squeeze(self.evaluate(v)) # Our model is: # predict = x * w @@ -371,10 +361,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # So unreplicated the update to w with lr=0.2 is -0.2 * -106 = 21.2 # with sum loss reduction, or 10.6 with mean. 
if loss_reduction == losses_impl.Reduction.SUM: - # Note that the "distribution.num_towers" factor will go away once - # we split the input across towers, instead of pulling a complete - # batch of input per tower. - self.assertNear(weight, 2 + 21.2 * distribution.num_towers, 0.0001) + # Note that the "distribution.num_replicas_in_sync" factor will go away + # once we split the input across replicas, instead of pulling a complete + # batch of input per replica. + self.assertNear(weight, 2 + 21.2 * distribution.num_replicas_in_sync, + 0.0001) else: # One of the mean loss reductions. self.assertNear(weight, 2 + 10.6, 0.0001) @@ -414,59 +405,58 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): train_op = optimizer.minimize(loss_fn) loss = loss_fn() output_context.set_last_step_output( - name="tower_loss_agg", + name="replica_loss_reduced", output=loss, - aggregation=variables_lib.VariableAggregation.MEAN) + reduce_op=reduce_util.ReduceOp.MEAN) output_context.set_non_tensor_output(key1, value1) return (train_op, loss) - def step_fn(output_context, *inputs): - (train_op, loss) = distribution.call_for_each_tower( - model_fn, output_context, *inputs, run_concurrently=False) + def step_fn(output_context, inputs): + (train_op, loss) = distribution.call_for_each_replica( + model_fn, args=(output_context,) + inputs) output_context.set_last_step_output( - name="cross_tower_loss_agg", + name="cross_replica_loss_reduced", output=loss, - aggregation=variables_lib.VariableAggregation.MEAN) + reduce_op=reduce_util.ReduceOp.MEAN) output_context.set_last_step_output( - name="cross_tower_loss_noagg", + name="cross_replica_loss_not_reduced", output=loss) return distribution.group(train_op) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) def run_step(): initial_loss = lambda: constant_op.constant(1e7) - # Initial values corresponding to aggregated losses 
are just single - # tensors. But for non aggregated losses, we need to have initial + # Initial values corresponding to reduced losses are just single + # tensors. But for non reduced losses, we need to have initial # values that are of the same structure as non reduced losses. In # MirroredStrategy, this will be a list of losses, in TPUStrategy # it will be single tensor. Using `broadcast` followed by `unwrap` # gives us the desired initial value structure. initial_loop_values = { - "tower_loss_agg": initial_loss(), - "cross_tower_loss_agg": initial_loss(), - "cross_tower_loss_noagg": + "replica_loss_reduced": initial_loss(), + "cross_replica_loss_reduced": initial_loss(), + "cross_replica_loss_not_reduced": distribution.unwrap(distribution.broadcast(initial_loss())) } ctx = distribution.run_steps_on_dataset( step_fn, iterator, iterations=2, initial_loop_values=initial_loop_values) - self.assertEqual({key1: [value1]}, ctx.non_tensor_outputs) + self.assertEqual({key1: (value1,)}, ctx.non_tensor_outputs) self._verify_loss_output( initial_loss(), - loss_output=ctx.last_step_outputs["tower_loss_agg"], - aggregated=True, distribution=distribution) + loss_output=ctx.last_step_outputs["replica_loss_reduced"], + reduced=True, distribution=distribution) self._verify_loss_output( initial_loss(), - loss_output=ctx.last_step_outputs["cross_tower_loss_agg"], - aggregated=True, distribution=distribution) + loss_output=ctx.last_step_outputs["cross_replica_loss_reduced"], + reduced=True, distribution=distribution) self._verify_loss_output( initial_loss(), - loss_output=ctx.last_step_outputs["cross_tower_loss_noagg"], - aggregated=False, distribution=distribution) - return (ctx.run_op, ctx.last_step_outputs["tower_loss_agg"]) + loss_output=ctx.last_step_outputs["cross_replica_loss_not_reduced"], + reduced=False, distribution=distribution) + return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"]) self.evaluate(distribution.initialize()) if not 
context.executing_eagerly(): @@ -491,18 +481,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): error_is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(error_is_not_increasing) - def _verify_loss_output(self, initial_loss, loss_output, aggregated, + def _verify_loss_output(self, initial_loss, loss_output, reduced, distribution): - if not aggregated: - self.assertEqual(distribution.num_towers, - len(distribution.unwrap(loss_output))) - loss_output = distribution.reduce( - aggregation=variables_lib.VariableAggregation.MEAN, - value=loss_output, destinations="/device:CPU:0") - - unwrapped_output = distribution.unwrap(loss_output) - self.assertEqual(1, len(unwrapped_output)) - loss_tensor = unwrapped_output[0] + if not reduced: + self.assertLen(distribution.unwrap(loss_output), + distribution.num_replicas_in_sync) + loss_tensor = distribution.reduce(reduce_util.ReduceOp.MEAN, loss_output) + else: + unwrapped_output = distribution.unwrap(loss_output) + self.assertLen(unwrapped_output, 1) + loss_tensor = unwrapped_output[0] self.assertEqual(initial_loss.dtype, loss_tensor.dtype) self.assertEqual(initial_loss.shape, loss_tensor.shape) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index a32424b316b003cc58ccf28fd968acb6a764a542..20f1a08d4261b931a9353738147fba7d7dff9225 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -12,300 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Class MirroredStrategy implementing DistributionStrategy.""" +"""Contrib version of MirroredStrategy.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib -from functools import partial -import threading +import functools -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import shared_variable_creator -from tensorflow.contrib.distribute.python import values -from tensorflow.python import pywrap_tensorflow -from tensorflow.python.distribute import multi_worker_util -from tensorflow.python.eager import context -from tensorflow.python.eager import tape -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import device as tf_device -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.training import coordinator -from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib -from tensorflow.python.util import nest +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.distribute import values -# TODO(josh11b): Replace asserts in this file with if ...: raise ... 
- - -@contextlib.contextmanager -def _enter_graph(g): - if context.executing_eagerly(): - with g.as_default(), context.eager_mode(): - yield - else: - with g.as_default(): - yield - - -def _cpu_device(device): - cpu_device = tf_device.DeviceSpec.from_string(device) - cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0)) - return cpu_device.to_string() - - -class _RequestedStop(Exception): - pass - - -# _call_for_each_tower and _reduce_non_distributed_value are not members of -# MirroredStrategy so that they are generally not allowed to use anything -# specific to MirroredStrategy and thus can be shared with other distribution -# strategies. - - -# TODO(yuefengz): maybe create a common class for those who need to call this -# _call_for_each_tower. -def _call_for_each_tower(distribution, fn, *args, **kwargs): - """Run `fn` in separate threads, once per tower/worker device. - - Args: - distribution: the DistributionStrategy object. - fn: function to run (will be run once per device, each in its own thread). - *args: positional arguments for `fn` - **kwargs: keyword arguments for `fn`. - `"run_concurrently"`: Boolean indicating whether executions of `fn` - can be run concurrently (under eager execution only), defaults to - `True`. - - Returns: - Merged return value of `fn` across all towers. - - Raises: - RuntimeError: If fn() calls get_tower_context().merge_call() a different - number of times from the available devices. - """ - run_concurrently = kwargs.pop("run_concurrently", True) - if not context.executing_eagerly(): - # Lots of TF library code isn't thread-safe in graph mode, and - # there is little to be gained by turning on multithreading when - # constructing a graph. - run_concurrently = False - # Needed for per-thread device, etc. contexts in graph mode. 
- ops.get_default_graph().switch_to_thread_local() - elif run_concurrently is None: - run_concurrently = True - - coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,)) - - shared_variable_store = {} - - # TODO(isaprykin): Create these threads once instead of during every run() - # call. - threads = [] - for index, d in enumerate(distribution.worker_devices): - variable_creator_fn = shared_variable_creator.make_fn( - shared_variable_store, index) - t = MirroredStrategy._MirroredTowerThread( # pylint: disable=protected-access - distribution, coord, d, variable_creator_fn, fn, - *values.select_device(d, args), **values.select_device(d, kwargs)) - threads.append(t) - - for t in threads: - t.start() - - # When `fn` starts `should_run` event is set on _MirroredTowerThread - # (`MTT`) threads. The execution waits until - # `MTT.has_paused` is set, which indicates that either `fn` is - # complete or a `get_tower_context().merge_call()` is called. If `fn` is - # complete, then `MTT.done` is set to True. Otherwise, arguments - # of `get_tower_context().merge_call` from all paused threads are grouped - # and the `merge_fn` is performed. Results of the - # `get_tower_context().merge_call` are then set to `MTT.merge_result`. - # Each such `get_tower_context().merge_call` call returns the - # `MTT.merge_result` for that thread when `MTT.should_run` event - # is reset again. Execution of `fn` resumes. 
- - try: - with coord.stop_on_exception(): - all_done = False - while not all_done and not coord.should_stop(): - done = [] - if run_concurrently: - for t in threads: - t.should_run.set() - for t in threads: - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - else: - for t in threads: - t.should_run.set() - t.has_paused.wait() - t.has_paused.clear() - if coord.should_stop(): - return None - done.append(t.done) - if coord.should_stop(): - return None - all_done = all(done) - if not all_done: - if any(done): - raise RuntimeError("Some towers made a different number of " - "tower_context().merge_call() calls.") - # get_tower_context().merge_call() case - merge_args = values.regroup({t.device: t.merge_args for t in threads}) - merge_kwargs = values.regroup( - {t.device: t.merge_kwargs for t in threads}) - # We capture the name_scope of the MTT when we call merge_fn - # to ensure that if we have opened a name scope in the MTT, - # it will be respected when executing the merge function. We only - # capture the name_scope from the first MTT and assume it is - # the same for all other MTTs. 
- mtt_captured_name_scope = threads[0].captured_name_scope - with ops.name_scope(mtt_captured_name_scope): - merge_result = threads[0].merge_fn(distribution, *merge_args, - **merge_kwargs) - for t in threads: - t.merge_result = values.select_device(t.device, merge_result) - finally: - for t in threads: - t.should_run.set() - coord.join(threads) - - return values.regroup({t.device: t.main_result for t in threads}) - - -def _reduce_non_distributed_value(distribution, aggregation, value, - destinations): - """Reduce a non-DistributedValue `value` to `destinations`.""" - if isinstance(value, values.DistributedValues): - raise ValueError("You are passing a `DistributedValue` to " - "`_reduce_non_distributed_value`, which is not allowed.") - - # If the same value is present on all towers then the PerDevice value will - # be a single value. We also handle the case when `value` is a single value - # and equal to 0. - if value == 0: - return 0 - # If the aggregation type is MEAN or ONLY_FIRST_TOWER, then this - # essentially means that the same value should be on all destinations. - if aggregation in ( - variable_scope.VariableAggregation.MEAN, - variable_scope.VariableAggregation.ONLY_FIRST_TOWER): - return value - - cross_tower_ops_lib.validate_destinations(destinations) - # We do not support an aggregation type of SUM if the value is the same across - # all towers. We call this as part of assign functions for MirroredVariables - # and summing up identical values across towers is not clearly defined. - if (len(distribution.worker_devices) != 1 or - not cross_tower_ops_lib.check_destinations(destinations)): - raise ValueError("A non-DistributedValues value %s cannot be reduced with " - "the given aggregation %s." % (value, aggregation)) - # TODO(anjalisridhar): Moves these methods to a device utility file? 
- devices = cross_tower_ops_lib.get_devices_from(destinations) - if len(devices) == 1: - with ops.device(devices[0]): - return array_ops.identity(value) - else: - value_updates = {} - for d in devices: - with ops.device(d): - value_updates[d] = array_ops.identity(value) - return values.Mirrored(value_updates) - - -def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # pylint: disable=g-missing-docstring - # Figure out what collections this variable should be added to. - # We'll add the MirroredVariable to those collections instead. - collections = kwargs.pop("collections", None) - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - kwargs["collections"] = [] - - # Get synchronization value - synchronization = kwargs.get("synchronization", - variable_scope.VariableSynchronization.ON_WRITE) - if synchronization == variable_scope.VariableSynchronization.NONE: - raise ValueError("`NONE` variable synchronization mode is not " - "supported with `Mirrored` distribution strategy. Please" - " change the `synchronization` for variable: " + - kwargs["name"]) - elif synchronization == variable_scope.VariableSynchronization.ON_READ: - # Variables that are to be synced on read are tower local. - is_tower_local = True - kwargs["trainable"] = False - elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or - synchronization == variable_scope.VariableSynchronization.AUTO): - # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`. 
- is_tower_local = False - else: - raise ValueError("Invalid variable synchronization mode: " + - synchronization + " for variable: " + kwargs["name"]) - - # Get aggregation value - aggregation = kwargs.pop("aggregation", - variable_scope.VariableAggregation.NONE) - if aggregation not in ( - variable_scope.VariableAggregation.NONE, - variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN, - variable_scope.VariableAggregation.ONLY_FIRST_TOWER - ): - raise ValueError("Invalid variable aggregation mode: " + aggregation + - " for variable: " + kwargs["name"]) - - # Ignore user-specified caching device, not needed for mirrored variables. - kwargs.pop("caching_device", None) - - # TODO(josh11b,apassos): It would be better if variable initialization - # was never recorded on the tape instead of having to do this manually - # here. - with tape.stop_recording(): - index = real_mirrored_creator(devices, *args, **kwargs) - - if is_tower_local: - result = values.TowerLocalVariable(index, index[devices[0]], aggregation) - else: - result = values.MirroredVariable(index, index[devices[0]], aggregation) - - # Add the wrapped variable to the requested collections. - # The handling of eager mode and the global step matches - # ResourceVariable._init_from_args(). - if not context.executing_eagerly(): - g = ops.get_default_graph() - # If "trainable" is True, next_creator() will add the member variables - # to the TRAINABLE_VARIABLES collection, so we manually remove - # them and replace with the MirroredVariable. We can't set - # "trainable" to False for next_creator() since that causes functions - # like implicit_gradients to skip those variables. 
- if kwargs.get("trainable", True): - collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) - l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) - for v in index.values(): - l.remove(v) - g.add_to_collections(collections, result) - elif ops.GraphKeys.GLOBAL_STEP in collections: - ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result) - - return result +# pylint: disable=protected-access,invalid-name +_call_for_each_replica = mirrored_strategy._call_for_each_replica +_reduce_non_distributed_value = mirrored_strategy._reduce_non_distributed_value +_create_mirrored_variable = mirrored_strategy._create_mirrored_variable +all_local_devices = mirrored_strategy.all_local_devices +CoreMirroredStrategy = mirrored_strategy.MirroredStrategy +CoreMirroredExtended = mirrored_strategy.MirroredExtended +# pylint: enable=protected-access,invalid-name class MirroredStrategy(distribute_lib.DistributionStrategy): """Mirrors vars to distribute across multiple devices and machines. - This strategy uses one tower per device and sync replication for its multi-GPU - version. + *** contrib version *** + + This strategy uses one replica per device and sync replication for its + multi-GPU version. When `cluster_spec` is given by the `configure` method., it turns into the mulit-worker version that works on multiple workers with in-graph replication. @@ -329,12 +66,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): index. They all do similar things except for one worker checkpointing model variables, writing summaries, etc. in addition to its ordinary work. - The multi-worker version of this class maps one tower to one device on a - worker. It mirrors all model variables on all towers. For example, if you have - two `worker`s and each `worker` has 4 GPUs, it will create 8 copies of the - model variables on these 8 GPUs. 
Then like in MirroredStrategy, each tower - performs their computation with their own copy of variables unless in - cross-tower model where variable or tensor reduction happens. + The multi-worker version of this class maps one replica to one device on a + worker. It mirrors all model variables on all replicas. For example, if you + have two `worker`s and each `worker` has 4 GPUs, it will create 8 copies of + the model variables on these 8 GPUs. Then like in MirroredStrategy, each + replica performs their computation with their own copy of variables unless in + cross-replica model where variable or tensor reduction happens. Args: devices: a list of device strings. @@ -344,489 +81,80 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): num_gpus_per_worker: number of GPUs per worker. This is the same as `num_gpus` and only one of `num_gpus` and `num_gpus_per_worker` can be specified. - cross_tower_ops: optional, a descedant of `CrossTowerOps`. If this is not + cross_device_ops: optional, a descedant of `CrossDeviceOps`. If this is not set, the `configure` method will try to find the best one. - prefetch_on_device: optional boolean to specify whether to prefetch input - data to devices. auto_shard_dataset: whether to auto-shard the dataset when there are multiple workers. + cross_tower_ops: Deprecated alias for `cross_device_ops`. """ def __init__(self, devices=None, num_gpus=None, num_gpus_per_worker=None, - cross_tower_ops=None, - prefetch_on_device=None, - auto_shard_dataset=False): - super(MirroredStrategy, self).__init__() - - self._cross_tower_ops = cross_tower_ops - self._prefetch_on_device = prefetch_on_device - self._auto_shard_dataset = auto_shard_dataset - # Rememeber num GPUs which might be needed by `configure` method. 
+ cross_device_ops=None, + auto_shard_dataset=False, + cross_tower_ops=None): + assert not (cross_device_ops and cross_tower_ops) if num_gpus is not None and num_gpus_per_worker is not None: raise ValueError( "You cannot specify both `num_gpus` and `num_gpus_per_worker`.") - if num_gpus is not None: - self._num_gpus = num_gpus - else: - self._num_gpus = num_gpus_per_worker - - self._initialize_local(self._num_gpus, devices) - - def _initialize_local(self, num_gpus, devices): - """Initializes the object for local training.""" - self._cluster_spec = None - # Convert `num_gpus` into `devices`, shouldn't specify both. - if devices is None: - if num_gpus is None: - num_gpus = context.num_gpus() - if num_gpus == 0: - devices = ["/device:CPU:0"] - else: - devices = ["/device:GPU:%d" % d for d in range(num_gpus)] - elif num_gpus is not None: - raise ValueError("Must only specify one of `devices` and `num_gpus`.") - self._num_gpus = num_gpus - # TODO(yuefengz): consider setting the default device. - - assert devices, "Must specify at least one device." - assert len(set(devices)) == len(devices), ( - "No duplicates allowed in `devices` argument.") - # TODO(josh11b): Require at least 2 devices? 
- self._devices = [device_util.resolve(d) for d in devices] - self._canonical_device_set = set(self._devices) - self._device_index = values.PerDevice({d: i for i, d in enumerate(devices)}) - - def _initialize_multi_worker(self, num_gpus, cluster_spec): - """Initializes the object for multi-worker training.""" - cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) - self._cluster_spec = cluster_spec - - self._workers = [] - for job in ["chief", "worker"]: - for task in range(len(cluster_spec.as_dict().get(job, []))): - self._workers.append("/job:%s/task:%d" % (job, task)) - if num_gpus is None: - raise ValueError("`num_gpus` is required if `cluster_spec` is given.") - if num_gpus > 0: - self._worker_device_map = { - worker: [ - device_util.canonicalize(worker + "/device:GPU:%d" % gpu) - for gpu in range(num_gpus) - ] for worker in self._workers - } - else: - self._worker_device_map = { - worker: [device_util.canonicalize(worker, "/device:CPU:0")] - for worker in self._workers - } + num_gpus = num_gpus_per_worker + extended = MirroredExtended(self, devices, num_gpus, + cross_device_ops or cross_tower_ops, + auto_shard_dataset) + super(MirroredStrategy, self).__init__(extended) - devices = nest.flatten(self._worker_device_map) - # Setting `_default_device` will add a device scope in the - # distribution.scope. We set the default device to the first worker. When - # users specify device under distribution.scope by - # with tf.device("/cpu:0"): - # ... - # their ops will end up on the cpu device of its first worker, e.g. - # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode. - self._default_device = self._workers[0] +class MirroredExtended(CoreMirroredExtended): + """Implementation of (contrib) MirroredStrategy.""" - assert devices, "Must specify at least one device." - assert len(set(devices)) == len(devices), ( - "No duplicates allowed in `devices` argument.") - # TODO(josh11b): Require at least 2 devices? 
- self._devices = [device_util.resolve(d) for d in devices] - self._canonical_device_set = set(self._devices) - self._device_index = values.PerDevice( - {d: i for i, d in enumerate(devices)}) - - def _create_variable(self, next_creator, *args, **kwargs): - """Create a mirrored variable. See `DistributionStrategy.scope`.""" - colocate_with = kwargs.pop("colocate_with", None) - devices = self._get_devices_from(colocate_with) - - def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring - index = {} - for i, d in enumerate(devices): - with ops.device(d): - if i > 0: - # Give replicas meaningful distinct names: - var0name = index[devices[0]].name.split(":")[0] - # We append a / to variable names created on towers with id > 0 to - # ensure that we ignore the name scope and instead use the given - # name as the absolute name of the variable. - kwargs["name"] = "%s/replica_%d/" % (var0name, i) - # Initialize replicas with the same value: - if context.executing_eagerly(): - kwargs["initial_value"] = array_ops.identity( - index[devices[0]].value()) - else: - def initial_value_fn(device=d): - with ops.device(device): - return array_ops.identity(index[devices[0]].initial_value) - kwargs["initial_value"] = initial_value_fn - with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - v = next_creator(*args, **kwargs) - assert not isinstance(v, values.DistributedVariable) - index[d] = v - return index - - return _create_mirrored_variable(devices, _real_mirrored_creator, *args, - **kwargs) + def __init__(self, + container_strategy, + devices=None, + num_gpus_per_worker=None, + cross_device_ops=None, + auto_shard_dataset=False): + if devices is None: + devices = mirrored_strategy.all_local_devices(num_gpus_per_worker) + elif num_gpus_per_worker is not None: + raise ValueError( + "Must only specify one of `devices` and `num_gpus_per_worker`.") + super(MirroredExtended, self).__init__(container_strategy, devices, + cross_device_ops) + 
self._auto_shard_dataset = auto_shard_dataset - def distribute_dataset(self, dataset_fn): - if self._cluster_spec: - return values.MultiWorkerDataset( - partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, - self._prefetch_on_device, self._auto_shard_dataset) + def _make_dataset_iterator(self, dataset): + """Make iterator from dataset without splitting the batch. + + This implementation is different than the one in + `tf.distribute.MirroredStrategy` for purposes of backward compatibility. + We treat the incoming dataset's batch size as per replica batch size. + + Args: + dataset: `tf.data.Dataset` for input. + Returns: + An `InputIterator` which returns inputs for each step of the computation. + """ + if self._local_mode: + worker = device_util.canonicalize("/device:CPU:0") + worker_device_pairs = [(worker, self._devices)] else: - return values.PerDeviceDataset( - self._call_dataset_fn(dataset_fn), self._devices, - self._prefetch_on_device) - - # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. - def _run_steps_on_dataset(self, fn, iterator, iterations, - initial_loop_values=None): - if initial_loop_values is None: - initial_loop_values = {} - initial_loop_values = nest.flatten(initial_loop_values) - - ctx = values.MultiStepContext() - def body(i, *args): - """A wrapper around `fn` to create the while loop body.""" - del args - fn_inputs = iterator.get_next() - if not isinstance(fn_inputs, tuple): - fn_inputs = (fn_inputs,) - fn_result = fn(ctx, *fn_inputs) - for (name, output) in ctx.last_step_outputs.items(): - # Convert all outputs to tensors, potentially from `DistributedValues`. - ctx.last_step_outputs[name] = self.unwrap(output) - flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) - with ops.control_dependencies([fn_result]): - return [i + 1] + flat_last_step_outputs - - # We capture the control_flow_context at this point, before we run `fn` - # inside a while_loop. 
This is useful in cases where we might need to exit - # these contexts and get back to the outer context to do some things, for - # e.g. create an op which should be evaluated only once at the end of the - # loop on the host. One such usage is in creating metrics' value op. - self._outer_control_flow_context = ( - ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - - cond = lambda i, *args: i < iterations - i = constant_op.constant(0) - loop_result = control_flow_ops.while_loop( - cond, body, [i] + initial_loop_values, name="", - parallel_iterations=1, back_prop=False, swap_memory=False, - return_same_structure=True) - del self._outer_control_flow_context - - ctx.run_op = control_flow_ops.group(loop_result) - - # Convert the last_step_outputs from a list to the original dict structure - # of last_step_outputs. - last_step_tensor_outputs = loop_result[1:] - last_step_tensor_outputs_dict = nest.pack_sequence_as( - ctx.last_step_outputs, last_step_tensor_outputs) - - for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access - output = last_step_tensor_outputs_dict[name] - # For outputs that have already been aggregated, wrap them in a Mirrored - # container, else in a PerDevice container. - if aggregation is variables_lib.VariableAggregation.NONE: - last_step_tensor_outputs_dict[name] = values.regroup( - {d: t for d, t in zip(self._devices, output)}, values.PerDevice) - else: - assert len(output) == 1 - last_step_tensor_outputs_dict[name] = output[0] - - ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access - return ctx - - def _broadcast(self, tensor, destinations): - # TODO(josh11b): In eager mode, use one thread per device, or async mode. 
- return self._get_cross_tower_ops().broadcast(tensor, destinations or - self._devices) - - def _call_for_each_tower(self, fn, *args, **kwargs): - return _call_for_each_tower(self, fn, *args, **kwargs) - - def map(self, map_over, fn, *args, **kwargs): - # TODO(josh11b): In eager mode, use one thread per device. - index = {} - for i, m in enumerate(map_over): - d = self._devices[i % len(self._devices)] - with ops.device(d): - l = index.get(d, []) - l.append(fn(m, - *values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs))) - index[d] = l - # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput - # in addition to PerDevice data. - return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()}) - - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): - del task_type, task_id - - if session_config: - session_config.isolate_session_state = True - - if cluster_spec: - self._initialize_multi_worker(self._num_gpus, cluster_spec) - - if self._cross_tower_ops is None: - if self._cluster_spec: - # It currently cannot detect the toplogy of remote workers. So we - # hard-code the multi-worker all-reduce algorithm for now. - if len(self._workers) == 1: - # The default is "nccl". - self._cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps() - else: - # The default is hierarchical reduce and broadcast. 
- self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce( - self._workers, self._num_gpus) - else: - self._cross_tower_ops = cross_tower_ops_lib.choose_the_best( - self._devices, session_config=session_config) - - def _get_cross_tower_ops(self): - if self._cross_tower_ops is None: - self._cross_tower_ops = ( - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()) - return self._cross_tower_ops - - def _reduce(self, aggregation, value, destinations): - assert not isinstance(value, values.Mirrored) - if not isinstance(value, values.DistributedValues): - # This function handles reducing values that are not PerDevice or Mirrored - # values. For example, the same value could be present on all towers in - # which case `value` would be a single value or value could be 0. - return _reduce_non_distributed_value(self, aggregation, value, - destinations) - if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER: - value = value.get(self._devices[0]) - if isinstance(value, (int, float)): - return value - return self.broadcast(value, destinations) - return self._get_cross_tower_ops().reduce( - aggregation, value, destinations=destinations) - - def _batch_reduce(self, aggregation, value_destination_pairs): - if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER: - return [self.broadcast(v.get(self._devices[0]), d) - for v, d in value_destination_pairs] - return self._get_cross_tower_ops().batch_reduce(aggregation, - value_destination_pairs) - - def _update(self, var, options, fn, *args, **kwargs): - # TODO(josh11b): In eager mode, use one thread per device. - assert isinstance(var, values.DistributedVariable) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. 
- updates = {} - for d, v in var._index.items(): # pylint: disable=protected-access - name = "update_%d" % self._device_index.get(d) - with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): - # If args and kwargs are not mirrored, the value is returned as is. - updates[d] = fn(v, - *values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, should_group) - - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): - assert isinstance(colocate_with, list) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. - # TODO(josh11b): In eager mode, use one thread per device. - updates = {} - for d in colocate_with: - name = "update_%d" % self._device_index.get(d) - with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): - updates[d] = fn(*values.select_device_mirrored(d, args), - **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, should_group) - - def read_var(self, tower_local_var): - """Read the aggregate value of a tower-local variable.""" - if isinstance(tower_local_var, values.TowerLocalVariable): - return tower_local_var._get_cross_tower() # pylint: disable=protected-access - assert isinstance(tower_local_var, values.Mirrored) - return array_ops.identity(tower_local_var.get()) - - def _unwrap(self, val): - if isinstance(val, values.DistributedValues): - # Return in a deterministic order. 
- if set(val.devices) == self._canonical_device_set: - return [val.get(device=d) for d in self._devices] - return [val.get(device=d) for d in sorted(val.devices)] - return [val] - - def value_container(self, val): - return values.value_container(val) - - @property - def is_single_tower(self): - return len(self._devices) == 1 - - @property - def num_towers(self): - return len(self._devices) - - def _worker_device_index(self): - return self._device_index - - @property - def worker_devices(self): - # Make a copy to prevent users from accidentally mutating our copy. - return list(self._devices) - - @property - def parameter_devices(self): - return list(self._devices) - - @property - def between_graph(self): - return False - - @property - def should_init(self): - return True - - @property - def should_checkpoint(self): - return True - - @property - def should_save_summary(self): - return True - - def non_slot_devices(self, var_list): - del var_list - return list(self._devices) + worker_device_pairs = self._worker_devices + return values.DatasetIterator(dataset, worker_device_pairs) - def _get_devices_from(self, colocate_with=None): - if colocate_with is None: - return self._devices + def _distribute_dataset(self, dataset_fn): + if self._local_mode: + return values.PerReplicaDataset( + self._call_dataset_fn(dataset_fn), self._devices) else: - return cross_tower_ops_lib.get_devices_from(colocate_with) - - class _MirroredTowerThread(threading.Thread): - """A thread that runs() a function on a device.""" - - def __init__(self, dist, coord, device, variable_creator_fn, fn, *args, - **kwargs): - super(MirroredStrategy._MirroredTowerThread, self).__init__() # pylint: disable=protected-access - self.coord = coord - self.distribution = dist - self.device = device - self.tower_id = dist.worker_devices.index(device) - self.variable_creator_fn = variable_creator_fn - # State needed to run and return the results of `fn`. 
- self.main_fn = fn - self.main_args = args - self.main_kwargs = kwargs - self.main_result = None - self.done = False - # State needed to run the next merge_call() (if any) requested via - # TowerContext. - self.merge_fn = None - self.merge_args = None - self.merge_kwargs = None - self.merge_result = None - self.captured_name_scope = None - # We use a thread.Event for the main thread to signal when this - # thread should start running (`should_run`), and another for - # this thread to transfer control back to the main thread - # (`has_paused`, either when it gets to a - # `get_tower_context().merge_call` or when `fn` returns). In - # either case the event starts cleared, is signaled by calling - # set(). The receiving thread waits for the signal by calling - # wait() and then immediately clearing the event using clear(). - self.should_run = threading.Event() - self.has_paused = threading.Event() - # These fields have to do with inheriting various contexts from the - # parent thread: - # pylint: disable=protected-access - self.context_mode = context.context()._eager_context.mode - if not context.context()._context_handle: - context.context()._initialize_handle_and_devices() - self.context_device_policy = ( - pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy( - context.context()._context_handle)) - self.graph = ops.get_default_graph() - self._variable_creator_stack = self.graph._variable_creator_stack[:] - self._captured_var_scope = variable_scope.get_variable_scope() - # Adding a "/" at end lets us re-enter this scope later. 
- self._name_scope = self.graph.get_name_scope() - if self._name_scope: - self._name_scope += "/" - if self.tower_id > 0: - if not self._name_scope: - self._name_scope = "" - self._name_scope += "tower_%d/" % self.tower_id - - def run(self): - # pylint: disable=protected-access - self.graph._variable_creator_stack = self._variable_creator_stack - self.should_run.wait() - self.should_run.clear() - try: - if self.coord.should_stop(): - return - with self.coord.stop_on_exception(), \ - context.context()._mode(self.context_mode), \ - context.context().device_policy(self.context_device_policy), \ - _enter_graph(self.graph), \ - MirroredTowerContext(self.distribution, self.tower_id), \ - ops.device(self.device), \ - ops.name_scope(self._name_scope), \ - variable_scope.variable_scope( - self._captured_var_scope, reuse=self.tower_id > 0), \ - variable_scope.variable_creator_scope(self.variable_creator_fn): - self.main_result = self.main_fn(*self.main_args, **self.main_kwargs) - self.done = True - finally: - self.has_paused.set() - - -class MirroredTowerContext(distribute_lib.TowerContext): - """TowerContext used in MirroredStrategy.call_for_each_tower(). - - Opened in `_MirroredTowerThread`, to allow the user to invoke - `MirroredStrategy`'s specific implementation of `merge_call()`, - which works by delegating the function and its arguments to - the main thread (the one that invoked - `MirroredStrategy.call_for_each_tower()`). - """ - - def _merge_call(self, fn, *args, **kwargs): - """Delegate to the main thread to actually perform merge_call().""" - t = threading.current_thread() # a _MirroredTowerThread - t.merge_fn = fn - t.merge_args = args - t.merge_kwargs = kwargs - t.captured_name_scope = t.graph.get_name_scope() - # Adding a "/" at end lets us re-enter this scope later. 
- if t.captured_name_scope: - t.captured_name_scope += "/" - t.has_paused.set() - t.should_run.wait() - t.should_run.clear() - if t.coord.should_stop(): - raise _RequestedStop() - return t.merge_result + return values.MultiWorkerDataset( + functools.partial(self._call_dataset_fn, dataset_fn), + self._worker_devices, + auto_shard=self._auto_shard_dataset) + # TODO(priyag): Delete this once all strategies use global batch size. @property - def device(self): - distribute_lib.require_tower_context(self) - return self._distribution_strategy.worker_devices[self._tower_id] + def _global_batch_size(self): + return False diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index eeac528329a239f6a8a68a72c44272566b1d83d1..36be5c83f8bafb6c934d1d7682b5227b1f71c089 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -20,269 +20,268 @@ from __future__ import print_function import sys +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import strategy_test_lib -from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager 
import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import func_graph from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util +from tensorflow.python.keras.engine import training as keras_training +from tensorflow.python.keras.layers import core as keras_core from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.training import device_util -from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import server_lib GPU_TEST = "test_gpu" in sys.argv[0] -class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], + mode=["graph", "eager"])) +class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase, + parameterized.TestCase): - def _get_distribution_strategy(self): - devices = ["/device:CPU:0", "/device:GPU:0"] - if GPU_TEST: - self.assertGreater(context.num_gpus(), 0) - if context.num_gpus() > 1: - devices = ["/device:GPU:0", "/device:GPU:1"] - print(self.id().split(".")[-1], "devices:", ", ".join(devices)) - return mirrored_strategy.MirroredStrategy(devices) + def testMinimizeLoss(self, distribution): + if context.executing_eagerly(): + 
self._test_minimize_loss_eager(distribution) + else: + self._test_minimize_loss_graph(distribution) - def testMinimizeLossEager(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_minimize_loss_eager(self._get_distribution_strategy()) + def testReplicaId(self, distribution): + self._test_replica_id(distribution) - def testMinimizeLossGraph(self): - soft_placement = not GPU_TEST - print("testMinimizeLossGraph soft_placement:", soft_placement) - self._test_minimize_loss_graph( - self._get_distribution_strategy(), soft_placement=soft_placement) - - def testMapReduce(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_map_reduce(self._get_distribution_strategy()) - - def testDeviceIndex(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_device_index(self._get_distribution_strategy()) - - def testTowerId(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_tower_id(self._get_distribution_strategy()) - - def testNumTowers(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self.assertEqual(2, self._get_distribution_strategy().num_towers) - - @test_util.run_in_graph_and_eager_modes - def testCallAndMergeExceptions(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - self._test_call_and_merge_exceptions(self._get_distribution_strategy()) - - @test_util.run_in_graph_and_eager_modes - def testRunRegroupError(self): - - def run_fn(device_id): + def testNumReplicasInSync(self, distribution): + self.assertEqual(2, distribution.num_replicas_in_sync) + + def testCallAndMergeExceptions(self, distribution): + self._test_call_and_merge_exceptions(distribution) + + def testRunRegroupError(self, distribution): + def run_fn(): + replica_id = int(self.evaluate(_replica_id())) # Generates a list with different lengths on different devices. # Will fail in _regroup() (if more than one device). 
- return list(range(device_id)) - - dist = self._get_distribution_strategy() - with dist.scope(), self.assertRaises(AssertionError): - dist.call_for_each_tower(run_fn, dist.worker_device_index) - - @test_util.run_in_graph_and_eager_modes - def testReduceToCpu(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - - def run_fn(device_id): - return device_id - - dist = self._get_distribution_strategy() - with dist.scope(): - result = dist.call_for_each_tower(run_fn, dist.worker_device_index) - reduced = dist.reduce( - variable_scope.VariableAggregation.SUM, - result, - destinations="/device:CPU:0") - unwrapped = dist.unwrap(reduced) - self.assertEqual(1, len(unwrapped)) - expected = sum(range(len(dist.worker_devices))) - self.assertEqual(expected, self.evaluate(unwrapped[0])) - - @test_util.run_in_graph_and_eager_modes - def testReduceOnlyFirstTowerUpdates(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - - def run_fn(device_id): - return constant_op.constant(3 + 5 * device_id) - - dist = self._get_distribution_strategy() - with dist.scope(): - result = dist.call_for_each_tower(run_fn, dist.worker_device_index) - reduced = dist.reduce( - variable_scope.VariableAggregation.ONLY_FIRST_TOWER, - result, - destinations="/device:CPU:0") - unwrapped = dist.unwrap(reduced) - self.assertEqual(1, len(unwrapped)) - self.assertEqual(3, self.evaluate(unwrapped[0])) - - @test_util.run_in_graph_and_eager_modes() - def testReduceToMultipleDestinations(self): - if not GPU_TEST: - self.skipTest("Not GPU test") - - devices = ["/device:GPU:0"] - if GPU_TEST: - self.assertGreater(context.num_gpus(), 0) - print(self.id().split(".")[-1], "devices:", ", ".join(devices)) - - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): - reduced = dist.reduce( - variable_scope.VariableAggregation.SUM, - 1.0, - destinations=["/device:CPU:0", "/device:GPU:0"]) - unwrapped = dist.unwrap(reduced) - self.assertEqual(2, len(unwrapped)) - self.assertEqual(1.0, 
self.evaluate(unwrapped[0])) + return list(range(replica_id)) + + with distribution.scope(), self.assertRaises(AssertionError): + distribution.extended.call_for_each_replica(run_fn) + + def testReduceToCpu(self, distribution): + with distribution.scope(): + result = distribution.extended.call_for_each_replica(_replica_id) + reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result) + expected = sum(range(distribution.num_replicas_in_sync)) + self.assertEqual(expected, self.evaluate(reduced)) + + def testMakeInputFnIterator(self, distribution): + dataset_fn = lambda: dataset_ops.Dataset.range(10) + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=2, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator(iterator, distribution.extended.worker_devices, + expected_values) + + def testGlobalStepUpdate(self, distribution): + self._test_global_step_update(distribution) + + +def one_device_combinations(): + return combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_one_cpu, + combinations.mirrored_strategy_with_one_gpu, + combinations.core_mirrored_strategy_with_one_cpu, + combinations.core_mirrored_strategy_with_one_gpu], + mode=["graph", "eager"]) + + +class MirroredOneDeviceDistributionTest( + strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + @combinations.generate(one_device_combinations()) + def testMinimizeLoss(self, distribution): + if context.executing_eagerly(): + self._test_minimize_loss_eager(distribution) + else: + self._test_minimize_loss_graph(distribution) + + @combinations.generate(one_device_combinations()) + def testReplicaId(self, distribution): + self._test_replica_id(distribution) + + @combinations.generate(one_device_combinations()) + def testCallAndMergeExceptions(self, distribution): + 
self._test_call_and_merge_exceptions(distribution) + + +class MirroredStrategyVariableCreatorStackTest( + test.TestCase, parameterized.TestCase): + + @combinations.generate(combinations.combine( + distribution=[combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph"])) + def testCreatorStacksAreThreadLocal(self, distribution): + def model_fn(): + replica_id_str = str(self.evaluate(_replica_id())) + + def thread_creator_fn(next_creator, *args, **kwargs): + return next_creator(*args, **kwargs) + ":thread_" + replica_id_str + + with variable_scope.variable_creator_scope(thread_creator_fn): + # Create a variable in this scope. + v = variable_scope.variable(1.0) + # This will pause the current thread, and execute the other thread. + ds_context.get_replica_context().merge_call(lambda _: _) + return v + def main_thread_creator(next_creator, *args, **kwargs): + # We are not using the underlying next_creator for test purposes. + del next_creator, args, kwargs + return "main_thread" + + with context.graph_mode(), \ + distribution.scope(), \ + variable_scope.variable_creator_scope(main_thread_creator): + result = distribution.extended.call_for_each_replica(model_fn) + result = distribution.unwrap(result) + expected = ("main_thread:thread_0", "main_thread:thread_1") + self.assertEqual(expected, result) + + +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class MirroredStrategyVariableCreationTest(test.TestCase): - config = config_pb2.ConfigProto() - config.allow_soft_placement = True + # TODO(priyag): Modify more tests to use this helper and check more + # properties. 
+ def _test_mv_properties(self, var, name): + self.assertIsInstance(var, values.MirroredVariable) + self.assertEqual(name, var.name) + for d in var.devices: + self.assertEqual(d, var.get(d).device) - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Enough GPUs not available for this test in eager mode.") + def testVariableInFuncGraph(self, distribution): + def model_fn(): + v = variable_scope.variable(2.0, name="bar") + ds_context.get_replica_context().merge_call(lambda _: _) + return v + + with func_graph.FuncGraph("fg").as_default(), distribution.scope(): + v1 = variable_scope.variable(1.0, name="foo") + v2 = distribution.extended.call_for_each_replica(model_fn) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSingleVariable(self): - self._skip_eager_if_gpus_less_than(1) + self._test_mv_properties(v1, "foo:0") + self._test_mv_properties(v2, "bar:0") + def testSingleVariable(self, distribution): def model_fn(): # This variable should be created only once across the threads because of - # special variable_creator functions used by `dist.call_for_each_tower`. + # special variable_creator functions used by + # `distribution.extended.call_for_each_replica`. 
v = variable_scope.variable(1.0, name="foo") - distribution_strategy_context.get_tower_context().merge_call(lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_tower(model_fn, run_concurrently=False) - self.assertIsInstance(result, values.MirroredVariable) - self.assertEquals("foo:0", result.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testUnnamedVariable(self): - self._skip_eager_if_gpus_less_than(1) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self._test_mv_properties(result, "foo:0") + def testUnnamedVariable(self, distribution): def model_fn(): v = variable_scope.variable(1.0) - distribution_strategy_context.get_tower_context().merge_call(lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_tower(model_fn, run_concurrently=False) - self.assertIsInstance(result, values.MirroredVariable) - # Default name of "Variable" will be used. 
- self.assertEquals("Variable:0", result.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testMultipleVariables(self): - self._skip_eager_if_gpus_less_than(1) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self._test_mv_properties(result, "Variable:0") + def testMultipleVariables(self, distribution): def model_fn(): vs = [] for i in range(5): vs.append(variable_scope.variable(1.0, name="foo" + str(i))) - distribution_strategy_context.get_tower_context().merge_call(lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return vs - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_tower(model_fn, run_concurrently=False) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) for i, v in enumerate(result): - self.assertIsInstance(v, values.MirroredVariable) - self.assertEquals("foo" + str(i) + ":0", v.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testMultipleVariablesWithSameCanonicalName(self): - self._skip_eager_if_gpus_less_than(1) + self._test_mv_properties(v, "foo" + str(i) + ":0") + def testMultipleVariablesWithSameCanonicalName(self, distribution): def model_fn(): vs = [] vs.append(variable_scope.variable(1.0, name="foo/bar")) vs.append(variable_scope.variable(1.0, name="foo_1/bar")) vs.append(variable_scope.variable(1.0, name="foo_1/bar_1")) vs.append(variable_scope.variable(1.0, name="foo/bar_1")) - distribution_strategy_context.get_tower_context().merge_call(lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return vs - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_tower(model_fn, run_concurrently=False) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) for v in result: 
self.assertIsInstance(v, values.MirroredVariable) - self.assertEquals(4, len(result)) - self.assertEquals("foo/bar:0", result[0].name) - self.assertEquals("foo_1/bar:0", result[1].name) - self.assertEquals("foo_1/bar_1:0", result[2].name) - self.assertEquals("foo/bar_1:0", result[3].name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testVariableWithSameCanonicalNameAcrossThreads(self): - self._skip_eager_if_gpus_less_than(1) - - def model_fn(device_id): - v = variable_scope.variable(1.0, name="foo_" + str(device_id)) - distribution_strategy_context.get_tower_context().merge_call(lambda _: _) - return v + self.assertEqual(4, len(result)) + self.assertEqual("foo/bar:0", result[0].name) + self.assertEqual("foo_1/bar:0", result[1].name) + self.assertEqual("foo_1/bar_1:0", result[2].name) + self.assertEqual("foo/bar_1:0", result[3].name) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) + def testVariableWithSameCanonicalNameAcrossThreads(self, distribution): + def model_fn(): + replica_id = self.evaluate(_replica_id()) + v = variable_scope.variable(1.0, name="foo_" + str(replica_id)) + ds_context.get_replica_context().merge_call(lambda _: _) + return v - with dist.scope(): - result = dist.call_for_each_tower( - model_fn, dist.worker_device_index, run_concurrently=False) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) self.assertIsInstance(result, values.MirroredVariable) # The resulting mirrored variable will use the name from the first device. 
- self.assertEquals("foo_0:0", result.name) + self.assertEqual("foo_0:0", result.name) - @test_util.run_in_graph_and_eager_modes(config=config) - def testWithLayers(self): - self._skip_eager_if_gpus_less_than(1) + def testWithLayers(self, distribution): def model_fn(features): with variable_scope.variable_scope("common"): layer1 = core.Dense(1) @@ -290,41 +289,40 @@ class MirroredStrategyVariableCreationTest(test.TestCase): layer2 = core.Dense(1) layer2(features) # This will pause the current thread, and execute the other thread. - distribution_strategy_context.get_tower_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) layer3 = core.Dense(1) layer3(features) return [(layer1.kernel, layer1.bias), (layer2.kernel, layer2.bias), (layer3.kernel, layer3.bias)] - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - features = dist.distribute_dataset( - lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) - ).make_one_shot_iterator().get_next() + ds = distribution.distribute_dataset( + lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) + if context.executing_eagerly(): + iterator = ds.make_one_shot_iterator() + else: + iterator = ds.make_initializable_iterator() + self.evaluate([iterator.initializer]) - with dist.scope(): - result = dist.call_for_each_tower( - model_fn, features, run_concurrently=False) + features = iterator.get_next() + + with distribution.scope(): + result = distribution.extended.call_for_each_replica( + model_fn, args=(features,)) suffixes = ["", "_1", "_2"] for (kernel, bias), suffix in zip(result, suffixes): self.assertIsInstance(kernel, values.MirroredVariable) - self.assertEquals("common/dense" + suffix + "/kernel:0", kernel.name) + self.assertEqual("common/dense" + suffix + "/kernel:0", kernel.name) self.assertIsInstance(bias, values.MirroredVariable) - self.assertEquals("common/dense" + suffix + "/bias:0", bias.name) - - 
@test_util.run_in_graph_and_eager_modes(config=config) - def testWithVariableAndVariableScope(self): - self._skip_eager_if_gpus_less_than(1) + self.assertEqual("common/dense" + suffix + "/bias:0", bias.name) + def testWithVariableAndVariableScope(self, distribution): def model_fn(): v0 = variable_scope.variable(1.0, name="var0", aggregation=None) with variable_scope.variable_scope("common"): v1 = variable_scope.variable(1.0, name="var1") # This will pause the current thread, and execute the other thread. - distribution_strategy_context.get_tower_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) v2 = variable_scope.variable( 1.0, name="var2", @@ -338,37 +336,31 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return v0, v1, v2, v3 - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + with distribution.scope(): v = variable_scope.variable(1.0, name="var-main0") - self.assertEquals("var-main0:0", v.name) + self.assertEqual("var-main0:0", v.name) - result = dist.call_for_each_tower(model_fn, run_concurrently=False) - self.assertEquals(4, len(result)) + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(4, len(result)) v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) - self.assertEquals("var0:0", v0.name) + self.assertEqual("var0:0", v0.name) self.assertIsInstance(v1, values.MirroredVariable) - self.assertEquals("common/var1:0", v1.name) - self.assertIsInstance(v2, values.TowerLocalVariable) - self.assertEquals("common/var2:0", v2.name) - self.assertEquals(variable_scope.VariableAggregation.SUM, v2.aggregation) + self.assertEqual("common/var1:0", v1.name) + self.assertIsInstance(v2, values.ReplicaLocalVariable) + self.assertEqual("common/var2:0", v2.name) + self.assertEqual(variable_scope.VariableAggregation.SUM, v2.aggregation) self.assertIsInstance(v3, values.MirroredVariable) - 
self.assertEquals("common/var3:0", v3.name) - self.assertEquals(variable_scope.VariableAggregation.MEAN, v3.aggregation) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testWithGetVariableAndVariableScope(self): - self._skip_eager_if_gpus_less_than(1) + self.assertEqual("common/var3:0", v3.name) + self.assertEqual(variable_scope.VariableAggregation.MEAN, v3.aggregation) + def testWithGetVariableAndVariableScope(self, distribution): def model_fn(): v0 = variable_scope.get_variable("var0", [1]) with variable_scope.variable_scope("common"): v1 = variable_scope.get_variable("var1", [1]) # This will pause the current thread, and execute the other thread. - distribution_strategy_context.get_tower_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) v2 = variable_scope.get_variable( "var2", [1], synchronization=variable_scope.VariableSynchronization.ON_READ, @@ -380,35 +372,30 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return v0, v1, v2, v3 - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + with distribution.scope(): with variable_scope.variable_scope("main"): v = variable_scope.get_variable("var-main0", [1]) - self.assertEquals("main/var-main0:0", v.name) + self.assertEqual("main/var-main0:0", v.name) - result = dist.call_for_each_tower(model_fn, run_concurrently=False) - self.assertEquals(4, len(result)) + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(4, len(result)) v0, v1, v2, v3 = result self.assertIsInstance(v0, values.MirroredVariable) - self.assertEquals("main/var0:0", v0.name) + self.assertEqual("main/var0:0", v0.name) self.assertIsInstance(v1, values.MirroredVariable) - self.assertEquals("main/common/var1:0", v1.name) - self.assertIsInstance(v2, values.TowerLocalVariable) - self.assertEquals("main/common/var2:0", v2.name) - 
self.assertEquals(variable_scope.VariableAggregation.SUM, - v2.aggregation) + self.assertEqual("main/common/var1:0", v1.name) + self.assertIsInstance(v2, values.ReplicaLocalVariable) + self.assertEqual("main/common/var2:0", v2.name) + self.assertEqual(variable_scope.VariableAggregation.SUM, + v2.aggregation) self.assertIsInstance(v3, values.MirroredVariable) - self.assertEquals("main/common/var3:0", v3.name) - self.assertEquals(variable_scope.VariableAggregation.MEAN, - v3.aggregation) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testOnlyFirstTowerUpdatesVariables(self): - self._skip_eager_if_gpus_less_than(1) + self.assertEqual("main/common/var3:0", v3.name) + self.assertEqual(variable_scope.VariableAggregation.MEAN, + v3.aggregation) + def testOnlyFirstReplicaUpdatesVariables(self, distribution): def create_fn(): - aggregation = variable_scope.VariableAggregation.ONLY_FIRST_TOWER + aggregation = variable_scope.VariableAggregation.ONLY_FIRST_REPLICA v0 = variable_scope.variable( 2.0, name="on_read", @@ -422,71 +409,73 @@ class MirroredStrategyVariableCreationTest(test.TestCase): return v0, v1 devices = ["/device:GPU:0", "/device:CPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): - v0, v1 = dist.call_for_each_tower(create_fn, run_concurrently=False) + with distribution.scope(): + v0, v1 = distribution.extended.call_for_each_replica(create_fn) self.evaluate(v0.initializer) self.assertEqual(2.0, self.evaluate(v0.get(devices[0]))) self.assertEqual(2.0, self.evaluate(v0.get(devices[1]))) - self.assertEqual(2.0, self.evaluate(dist.read_var(v0))) + self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0))) self.evaluate(v1.initializer) self.assertEqual(3.0, self.evaluate(v1.get(devices[0]))) self.assertEqual(3.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0, self.evaluate(dist.read_var(v1))) + self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1))) + + def 
replica_id_plus_one(): + return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32) # Update using the assign_add member function. - def update_member_fn(device_id): - update0 = v0.assign_add(5.0 * (device_id + 1)) - update1 = v1.assign_add(7.0 * (device_id + 1)) + def update_member_fn(): + update0 = v0.assign_add(5.0 * replica_id_plus_one()) + update1 = v1.assign_add(7.0 * replica_id_plus_one()) return update0, update1 - update0a, update1a = dist.call_for_each_tower( - update_member_fn, dist.worker_device_index, run_concurrently=False) + update0a, update1a = distribution.extended.call_for_each_replica( + update_member_fn) # Update "sync on read" variable. - self.evaluate(dist.group(update0a)) + self.evaluate(distribution.group(update0a)) self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0]))) # Writes are not synchronized for "sync on read" variables, # so device[1] can end up with a different value. self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1]))) # Always reads from device 0. - self.assertEqual(2.0 + 5.0, self.evaluate(dist.read_var(v0))) + self.assertEqual(2.0 + 5.0, self.evaluate( + distribution.extended.read_var(v0))) # Update "sync on write" variable. - self.evaluate(dist.group(update1a)) + self.evaluate(distribution.group(update1a)) self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0]))) # Writes are synchronized for v1, only the argument to assign_add on # device[0] is used. self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0 + 7.0, self.evaluate(dist.read_var(v1))) + self.assertEqual(3.0 + 7.0, self.evaluate( + distribution.extended.read_var(v1))) # Update using state_ops.assign_add global function. 
- def update_state_ops_fn(device_id): - update0 = state_ops.assign_add(v0, 11.0 * (device_id + 1)) - update1 = state_ops.assign_add(v1, 13.0 * (device_id + 1)) + def update_state_ops_fn(): + update0 = state_ops.assign_add(v0, 11.0 * replica_id_plus_one()) + update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one()) return update0, update1 - update0b, update1b = dist.call_for_each_tower( - update_state_ops_fn, dist.worker_device_index, run_concurrently=False) - self.evaluate(dist.group(update0b)) + update0b, update1b = distribution.extended.call_for_each_replica( + update_state_ops_fn) + self.evaluate(distribution.group(update0b)) # Update "sync on read" variable. self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0]))) self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1]))) - self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(dist.read_var(v0))) + self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate( + distribution.extended.read_var(v0))) # Update "sync on write" variable. - self.evaluate(dist.group(update1b)) + self.evaluate(distribution.group(update1b)) self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0]))) self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1]))) - self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(dist.read_var(v1))) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testNoneSynchronizationWithGetVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate( + distribution.extended.read_var(v1))) + + def testNoneSynchronizationWithGetVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "`NONE` variable synchronization mode is not " "supported with `Mirrored` distribution strategy. 
Please change " @@ -495,12 +484,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): "v", [1], synchronization=variable_scope.VariableSynchronization.NONE) - @test_util.run_in_graph_and_eager_modes(config=config) - def testNoneSynchronizationWithVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testNoneSynchronizationWithVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "`NONE` variable synchronization mode is not " "supported with `Mirrored` distribution strategy. Please change " @@ -510,23 +495,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase): name="v", synchronization=variable_scope.VariableSynchronization.NONE) - @test_util.run_in_graph_and_eager_modes(config=config) - def testInvalidSynchronizationWithVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testInvalidSynchronizationWithVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "Invalid variable synchronization mode: Invalid for " "variable: v"): variable_scope.variable(1.0, name="v", synchronization="Invalid") - @test_util.run_in_graph_and_eager_modes(config=config) - def testInvalidAggregationWithGetVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testInvalidAggregationWithGetVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "Invalid variable aggregation mode: invalid for " "variable: v"): @@ -535,12 +512,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): synchronization=variable_scope.VariableSynchronization.ON_WRITE, 
aggregation="invalid") - @test_util.run_in_graph_and_eager_modes(config=config) - def testInvalidAggregationWithVariable(self): - self._skip_eager_if_gpus_less_than(1) - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + def testInvalidAggregationWithVariable(self, distribution): + with distribution.scope(): with self.assertRaisesRegexp( ValueError, "Invalid variable aggregation mode: invalid for " "variable: v"): @@ -550,53 +523,28 @@ class MirroredStrategyVariableCreationTest(test.TestCase): synchronization=variable_scope.VariableSynchronization.ON_WRITE, aggregation="invalid") - @test_util.run_in_graph_and_eager_modes(config=config) - def testThreeDevices(self): - self._skip_eager_if_gpus_less_than(2) - - def model_fn(): - v = variable_scope.variable(1.0, name="foo") - distribution_strategy_context.get_tower_context().merge_call(lambda _: _) - return v - - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]) - - with dist.scope(): - result = dist.call_for_each_tower(model_fn, run_concurrently=False) - self.assertIsInstance(result, values.MirroredVariable) - self.assertEquals("foo:0", result.name) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testNonMatchingVariableCreation(self): - self._skip_eager_if_gpus_less_than(1) - + def testNonMatchingVariableCreation(self, distribution): def model_fn(name): v = variable_scope.variable(1.0, name=name) - distribution_strategy_context.get_tower_context().merge_call(lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): + with distribution.scope(): names = values.DistributedValues({ "/device:CPU:0": "foo", "/device:GPU:0": "bar" }) with self.assertRaises(RuntimeError): - _ = dist.call_for_each_tower(model_fn, names, run_concurrently=False) - - 
@test_util.run_in_graph_and_eager_modes(config=config) - def testTowerLocalVariable(self): - self._skip_eager_if_gpus_less_than(1) + _ = distribution.extended.call_for_each_replica(model_fn, args=(names,)) + def testReplicaLocalVariable(self, distribution): all_v_sum = {} all_v_mean = {} components_sum = {} components_mean = {} - def model_fn(device_id): + def model_fn(): + replica_id = self.evaluate(_replica_id()) v_sum = variable_scope.variable( 1.0, synchronization=variable_scope.VariableSynchronization.ON_READ, @@ -605,29 +553,25 @@ class MirroredStrategyVariableCreationTest(test.TestCase): 4.0, synchronization=variable_scope.VariableSynchronization.ON_READ, aggregation=variable_scope.VariableAggregation.MEAN) - self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) - self.assertTrue(isinstance(v_mean, values.TowerLocalVariable)) - updates = [v_sum.assign_add(2.0 + device_id), - v_mean.assign(6.0 * device_id)] - all_v_sum[device_id] = v_sum - all_v_mean[device_id] = v_mean + self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) + self.assertTrue(isinstance(v_mean, values.ReplicaLocalVariable)) + updates = [v_sum.assign_add(2.0 + replica_id), + v_mean.assign(6.0 * replica_id)] + all_v_sum[replica_id] = v_sum + all_v_mean[replica_id] = v_mean c_sum = v_sum.get() c_mean = v_mean.get() - components_sum[device_id] = c_sum - components_mean[device_id] = c_mean + components_sum[replica_id] = c_sum + components_mean[replica_id] = c_mean self.assertIsNot(v_sum, c_sum) self.assertIsNot(v_mean, c_mean) return updates, v_sum, v_mean, c_sum, c_mean - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - # Create "sum" and "mean" versions of TowerLocalVariables. + with distribution.scope(): + # Create "sum" and "mean" versions of ReplicaLocalVariables. 
ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = ( - dist.call_for_each_tower( - model_fn, dist.worker_device_index, run_concurrently=False)) - # Should see the same wrapping instance in all towers. + distribution.extended.call_for_each_replica(model_fn)) + # Should see the same wrapping instance in all replicas. self.assertIs(all_v_sum[0], ret_v_sum) self.assertIs(all_v_mean[0], ret_v_mean) self.assertIs(all_v_sum[0], all_v_sum[1]) @@ -641,10 +585,10 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # Apply updates self.evaluate(variables.global_variables_initializer()) - self.evaluate([y for x in ret_ops for y in dist.unwrap(x)]) + self.evaluate([y for x in ret_ops for y in distribution.unwrap(x)]) expected_sum = 0.0 expected_mean = 0.0 - for i, d in enumerate(dist.worker_devices): + for i, d in enumerate(distribution.extended.worker_devices): # Should see different values on different devices. v_sum_value = self.evaluate(ret_v_sum.get(d).read_value()) v_mean_value = self.evaluate(ret_v_mean.get(d).read_value()) @@ -654,221 +598,235 @@ class MirroredStrategyVariableCreationTest(test.TestCase): expected = i * 6.0 self.assertEqual(expected, v_mean_value) expected_mean += expected - expected_mean /= len(dist.worker_devices) + expected_mean /= len(distribution.extended.worker_devices) # Without get(device), should return the value you get by - # applying the reduction across all towers (whether you use + # applying the reduction across all replicas (whether you use # read_var(), get(), or nothing). 
- self.assertEqual(expected_sum, self.evaluate(dist.read_var(ret_v_sum))) - self.assertEqual(expected_mean, self.evaluate(dist.read_var(ret_v_mean))) + self.assertEqual(expected_sum, self.evaluate( + distribution.extended.read_var(ret_v_sum))) + self.assertEqual(expected_mean, self.evaluate( + distribution.extended.read_var(ret_v_mean))) self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) + # TODO(priyag): Update this test to work in eager mode as well. + def testDynamicRnnVariables(self, distribution): + def model_fn(): + inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) + cell_fw = rnn_cell_impl.LSTMCell(300) + cell_bw = rnn_cell_impl.LSTMCell(300) + (outputs, _) = rnn.bidirectional_dynamic_rnn( + cell_fw, + cell_bw, + inputs, + dtype=dtypes.float32) + return outputs + + with context.graph_mode(), distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + # Two variables are created by the RNN layer. + self.assertEqual(2, len(result)) + for v in result: + self.assertIsInstance(v, values.DistributedValues) + _, v1 = distribution.unwrap(v) + self.assertStartsWith(v1._op.name, "replica_1/") + + def testReplicaLocalVariableUpdate(self, distribution): + def model_fn(): + v_sum = variable_scope.variable( + 1.0, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) + return v_sum + + def update(var, value): + return var.assign(value) + + with distribution.scope(): + ret_v_sum = distribution.extended.call_for_each_replica(model_fn) + + # Initialize variables. 
+ self.evaluate(variables.global_variables_initializer()) + # Assert that the aggregated value of the replica local vars is the sum + # of the individual values before running the update ops. + self.assertEqual(1.0, self.evaluate(ret_v_sum.get( + distribution.extended.worker_devices[0]).read_value())) + self.assertEqual(2.0, self.evaluate(ret_v_sum)) + + # Apply updates. + update_ops = distribution.extended.update( + ret_v_sum, update, args=(5.0,), group=False) + self.evaluate(update_ops) + # Assert that the aggregated value of the replica local vars is the sum + # of the individual values after running the update ops. + self.assertEqual(5.0, self.evaluate(ret_v_sum.get( + distribution.extended.worker_devices[0]).read_value())) + self.assertEqual(10.0, self.evaluate(ret_v_sum)) + + +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph"])) +class MirroredStrategyNameScopeTest(test.TestCase): # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not # testing this in eager mode. 
- def testNameScope(self): + def testNameScope(self, distribution): def model_fn(): with ops.name_scope("foo"): a = constant_op.constant(1.0, name="a") - distribution_strategy_context.get_tower_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) b = constant_op.constant(1.0, name="b") return a, b - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): + with context.graph_mode(), distribution.scope(): with ops.name_scope("main"): - result = dist.call_for_each_tower(model_fn, run_concurrently=False) - self.assertEquals(2, len(result)) + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(2, len(result)) for v, name in zip(result, ["a", "b"]): self.assertIsInstance(v, values.DistributedValues) - v0, v1 = dist.unwrap(v) - self.assertEquals("main/foo/" + name + ":0", v0.name) - self.assertEquals("main/tower_1/foo/" + name + ":0", v1.name) + v0, v1 = distribution.unwrap(v) + self.assertEqual("main/foo/" + name + ":0", v0.name) + self.assertEqual("main/replica_1/foo/" + name + ":0", v1.name) - def testWithDefaultName(self): + def testWithDefaultName(self, distribution): def model_fn(): with ops.name_scope(None, "foo"): a = constant_op.constant(1.0, name="a") - distribution_strategy_context.get_tower_context().merge_call( - lambda _: _) + ds_context.get_replica_context().merge_call(lambda _: _) b = constant_op.constant(2.0, name="b") return a, b - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): - result = dist.call_for_each_tower(model_fn, run_concurrently=False) - self.assertEquals(2, len(result)) + with context.graph_mode(), distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self.assertEqual(2, len(result)) for v, name in zip(result, ["a", "b"]): self.assertIsInstance(v, values.DistributedValues) - v0, v1 = 
dist.unwrap(v) - self.assertEquals("foo/" + name + ":0", v0.name) - self.assertEquals("tower_1/foo/" + name + ":0", v1.name) + v0, v1 = distribution.unwrap(v) + self.assertEqual("foo/" + name + ":0", v0.name) + self.assertEqual("replica_1/foo/" + name + ":0", v1.name) # variable_scope.variable() respects name scopes when creating # variables. On the other hand variable_scope.get_variable() ignores name # scopes when creating variables. We test both methods of creating variables # to make sure that we have the same variable names in both cases. - def testNameScopeWithVariable(self): - def in_cross_tower(_): + def testNameScopeWithVariable(self, distribution): + def in_cross_replica(_): c = variable_scope.variable(1.0, name="c") return c def model_fn(): b = variable_scope.variable(1.0, name="b") with ops.name_scope("foo"): - c = distribution_strategy_context.get_tower_context().merge_call( - in_cross_tower) + c = ds_context.get_replica_context().merge_call(in_cross_replica) return b, c - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): + with context.graph_mode(), distribution.scope(): with ops.name_scope("main"): a = variable_scope.variable(1.0, name="a") - result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result = distribution.extended.call_for_each_replica(model_fn) result_b = result[0] result_c = result[1] self.assertIsInstance(result_b, values.DistributedValues) self.assertIsInstance(result_c, values.DistributedValues) - a0, a1 = dist.unwrap(a) - b0, b1 = dist.unwrap(result_b) - c0, c1 = dist.unwrap(result_c) - self.assertEquals("main/a:0", a0.name) - self.assertEquals("main/a/replica_1:0", a1.name) - self.assertEquals("main/b:0", b0.name) - self.assertEquals("main/b/replica_1:0", b1.name) - self.assertEquals("main/foo/c:0", c0.name) - self.assertEquals("main/foo/c/replica_1:0", c1.name) - - def testNameScopeWithGetVariable(self): - def in_cross_tower(_): + a0, a1 = 
distribution.unwrap(a) + b0, b1 = distribution.unwrap(result_b) + c0, c1 = distribution.unwrap(result_c) + self.assertEqual("main/a:0", a0.name) + self.assertEqual("main/a/replica_1:0", a1.name) + self.assertEqual("main/b:0", b0.name) + self.assertEqual("main/b/replica_1:0", b1.name) + self.assertEqual("main/foo/c:0", c0.name) + self.assertEqual("main/foo/c/replica_1:0", c1.name) + + def testNameScopeWithGetVariable(self, distribution): + def in_cross_replica(_): c = variable_scope.get_variable("c", [1]) return c def model_fn(): b = variable_scope.get_variable("b", [1]) with ops.name_scope("foo"): - c = distribution_strategy_context.get_tower_context().merge_call( - in_cross_tower) + c = ds_context.get_replica_context().merge_call(in_cross_replica) return b, c - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): + with context.graph_mode(), distribution.scope(): with ops.name_scope("main"): a = variable_scope.get_variable("a", [1]) - result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result = distribution.extended.call_for_each_replica(model_fn) result_b = result[0] result_c = result[1] self.assertIsInstance(result_b, values.DistributedValues) self.assertIsInstance(result_c, values.DistributedValues) - a0, a1 = dist.unwrap(a) - b0, b1 = dist.unwrap(result_b) - c0, c1 = dist.unwrap(result_c) - self.assertEquals("a:0", a0.name) - self.assertEquals("a/replica_1:0", a1.name) - self.assertEquals("b:0", b0.name) - self.assertEquals("b/replica_1:0", b1.name) - self.assertEquals("c:0", c0.name) - self.assertEquals("c/replica_1:0", c1.name) - - def testDynamicRnnVariables(self): + a0, a1 = distribution.unwrap(a) + b0, b1 = distribution.unwrap(result_b) + c0, c1 = distribution.unwrap(result_c) + self.assertEqual("a:0", a0.name) + self.assertEqual("a/replica_1:0", a1.name) + self.assertEqual("b:0", b0.name) + self.assertEqual("b/replica_1:0", b1.name) + self.assertEqual("c:0", 
c0.name) + self.assertEqual("c/replica_1:0", c1.name) + + +@combinations.generate( + combinations.combine( + distribution=[ + combinations.NamedDistribution( + "Mirrored3Devices", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), + required_gpus=2), + combinations.NamedDistribution( + "CoreMirrored3Devices", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.CoreMirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]), + required_gpus=2) + ], + mode=["graph", "eager"])) +class MirroredThreeDeviceDistributionTest( + strategy_test_lib.DistributionTestBase, + parameterized.TestCase): + + def testThreeDevices(self, distribution): def model_fn(): - inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) - cell_fw = rnn_cell_impl.LSTMCell(300) - cell_bw = rnn_cell_impl.LSTMCell(300) - (outputs, _) = rnn.bidirectional_dynamic_rnn( - cell_fw, - cell_bw, - inputs, - dtype=dtypes.float32) - return outputs - - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with context.graph_mode(), dist.scope(): - result = dist.call_for_each_tower(model_fn, run_concurrently=False) - # Two variables are created by the RNN layer. 
- self.assertEquals(2, len(result)) - for v in result: - self.assertIsInstance(v, values.DistributedValues) - _, v1 = dist.unwrap(v) - self.assertStartsWith(v1.name, "tower_1/") - - @test_util.run_in_graph_and_eager_modes(config=config) - def testTowerLocalVariableUpdate(self): - with context.graph_mode(): - - def model_fn(): - v_sum = variable_scope.variable( - 1.0, - synchronization=variable_scope.VariableSynchronization.ON_READ, - aggregation=variable_scope.VariableAggregation.SUM) - self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) - return v_sum - - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:GPU:1"]) - - def update(var, value): - return var.assign(value) - - with dist.scope(): - ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False) - update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False) - - # Initialize variables. - self.evaluate(variables.global_variables_initializer()) - # Assert that the aggregated value of the tower local vars is the sum of - # the individual values before running the update ops. - self.assertEquals(1.0, self.evaluate( - ret_v_sum.get(dist._devices[0]).read_value())) - self.assertEquals(2.0, self.evaluate(ret_v_sum)) + v = variable_scope.variable(1.0, name="foo") + ds_context.get_replica_context().merge_call(lambda _: _) + return v - # Apply updates. - self.evaluate(update_ops) - # Assert that the aggregated value of the tower local vars is the sum of - # the individual values after running the update ops. 
- self.assertEquals(5.0, self.evaluate( - ret_v_sum.get(dist._devices[0]).read_value())) - self.assertEquals(10.0, self.evaluate(ret_v_sum)) + with distribution.scope(): + result = distribution.extended.call_for_each_replica(model_fn) + self.assertIsInstance(result, values.MirroredVariable) + self.assertEqual("foo:0", result.name) +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) class MirroredVariableUpdateTest(test.TestCase): # The following tests check assign, assign_add and assign_sub on Mirrored - # variables in tower and cross tower context. - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Enough GPUs not available for this test in eager mode.") + # variables in replica and cross replica context. - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarTowerContextWithoutAggregationType(self): + def testAssignMirroredVarReplicaContextWithoutAggregationType(self, + distribution): # Test that we always have an aggregation type set on the mirrored variable - # if we assign to it in tower mode. - self._skip_eager_if_gpus_less_than(1) + # if we assign to it in replica mode. 
def var_fn(): v = variable_scope.variable(1.0, name="foo") return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) @@ -877,24 +835,20 @@ class MirroredVariableUpdateTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "You must specify an aggregation method to update a " - "MirroredVariable in Tower Context."): - self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn))) + "MirroredVariable in Replica Context."): + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarTowerContextWithSum(self): - # Test that we don't reduce a non-per-device value with the "sum" + def testAssignMirroredVarReplicaContextWithSum(self, distribution): + # Test that we don't reduce a non-per-replica value with the "sum" # aggregation type. 
- self._skip_eager_if_gpus_less_than(1) def var_fn(): v = variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM) return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) @@ -903,225 +857,184 @@ class MirroredVariableUpdateTest(test.TestCase): with self.assertRaisesRegexp( ValueError, "A non-DistributedValues value 5.0 cannot be reduced " - "with the given aggregation VariableAggregation.SUM."): - self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn))) + "with the given reduce op ReduceOp.SUM."): + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarCrossTowerContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignMirroredVarCrossDeviceContext(self, distribution): def var_fn(): return variable_scope.variable(1.0, name="foo") - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) mirrored_var_result = self.evaluate(mirrored_var.assign(6.0)) - self.assertEquals(6.0, mirrored_var_result) + self.assertEqual(6.0, mirrored_var_result) - @test_util.run_in_graph_and_eager_modes(config=config) - def 
testAssignMirroredVarTowerContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignMirroredVarReplicaContext(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): value = math_ops.cast( - distribution_strategy_context.get_tower_context().tower_id, + ds_context.get_replica_context().replica_id_in_sync_group, mirrored_var.dtype) return mirrored_var.assign(value) - self.evaluate(dist.unwrap(dist.call_for_each_tower( - model_fn, run_concurrently=False))) - self.assertEquals(0.5, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(0.5, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignMirroredVarTowerContextWithSingleValue(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignMirroredVarReplicaContextWithSingleValue(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) 
self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): return mirrored_var.assign(5.0) - self.evaluate(dist.unwrap(dist.call_for_each_tower( - model_fn, run_concurrently=False))) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(5.0, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignAddMirroredVarCrossTowerContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignAddMirroredVarCrossDeviceContext(self, distribution): def var_fn(): return variable_scope.variable(1.0, name="foo") - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) # read_value == True mirrored_var_result = self.evaluate( mirrored_var.assign_add(6.0, read_value=True)) - self.assertEquals(7.0, mirrored_var_result) - self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) - self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEqual(7.0, mirrored_var_result) + self.assertEqual(7.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEqual(7.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) # read_value == False self.evaluate(mirrored_var.assign_add(2.0, read_value=False)) - self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) - self.assertEquals(9.0, 
self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEqual(9.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEqual(9.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignAddMirroredVarTowerContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignAddMirroredVarReplicaContext(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): value = math_ops.cast( - distribution_strategy_context.get_tower_context().tower_id, + ds_context.get_replica_context().replica_id_in_sync_group, mirrored_var.dtype) return mirrored_var.assign_add(value) - self.evaluate(dist.unwrap(dist.call_for_each_tower( - model_fn, run_concurrently=False))) - self.assertEquals(1.5, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(1.5, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignAddMirroredVarTowerContextWithSingleValue(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignAddMirroredVarReplicaContextWithSingleValue(self, distribution): def var_fn(): return variable_scope.variable( 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with 
dist.scope(): - mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(1.0, self.evaluate(mirrored_var)) + self.assertEqual(1.0, self.evaluate(mirrored_var)) def model_fn(): return mirrored_var.assign_add(5.0) - self.evaluate(dist.unwrap(dist.call_for_each_tower( - model_fn, run_concurrently=False))) - self.assertEquals(6.0, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(6.0, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignSubMirroredVarCrossTowerContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignSubMirroredVarCrossDeviceContext(self, distribution): def var_fn(): return variable_scope.variable(5.0, name="foo") - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.assertEqual(5.0, self.evaluate(mirrored_var)) mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0)) - self.assertEquals(3.0, mirrored_var_result) - self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) - self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEqual(3.0, mirrored_var_result) + self.assertEqual(3.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEqual(3.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) - 
@test_util.run_in_graph_and_eager_modes(config=config) - def testAssignSubMirroredVarTowerContext(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignSubMirroredVarReplicaContext(self, distribution): def var_fn(): return variable_scope.variable( 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.assertEqual(5.0, self.evaluate(mirrored_var)) def model_fn(): value = math_ops.cast( - distribution_strategy_context.get_tower_context().tower_id, + ds_context.get_replica_context().replica_id_in_sync_group, mirrored_var.dtype) return mirrored_var.assign_sub(value) - self.evaluate(dist.unwrap(dist.call_for_each_tower( - model_fn, run_concurrently=False))) - self.assertEquals(4.5, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(4.5, self.evaluate(mirrored_var)) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignSubMirroredVarTowerContextWithSingleValue(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignSubMirroredVarReplicaContextWithSingleValue(self, distribution): def var_fn(): return variable_scope.variable( 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN) - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) 
self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) - self.assertEquals(5.0, self.evaluate(mirrored_var)) + self.assertEqual(5.0, self.evaluate(mirrored_var)) def model_fn(): return mirrored_var.assign_sub(1.0) - self.evaluate(dist.unwrap(dist.call_for_each_tower( - model_fn, run_concurrently=False))) - self.assertEquals(4.0, self.evaluate(mirrored_var)) + self.evaluate(distribution.unwrap( + distribution.extended.call_for_each_replica(model_fn))) + self.assertEqual(4.0, self.evaluate(mirrored_var)) -class MirroredAndTowerLocalVariableInitializerTest(test.TestCase): - config = config_pb2.ConfigProto() - config.allow_soft_placement = True +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) +class MirroredAndReplicaLocalVariableInitializerTest(test.TestCase): - def testAssignMirroredVarInitializer(self): + def testAssignMirroredVarInitializer(self, distribution): # This test is not eager compatible since in eager variables are initialized # upon construction instead of once the initialization op is run. 
with context.graph_mode(): @@ -1129,17 +1042,14 @@ class MirroredAndTowerLocalVariableInitializerTest(test.TestCase): v = variable_scope.variable(1.0, name="foo") return v - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - mirrored_var = dist.call_for_each_tower(var_fn) + with distribution.scope(): + mirrored_var = distribution.extended.call_for_each_replica(var_fn) self.assertIsInstance(mirrored_var, values.MirroredVariable) self.assertFalse(self.evaluate(mirrored_var.is_initialized())) self.evaluate(mirrored_var.initializer) self.assertTrue(self.evaluate(mirrored_var.is_initialized())) - def testAssignTowerLocalVarInitializer(self): + def testAssignReplicaLocalVarInitializer(self, distribution): # This test is not eager compatible since in eager variables are initialized # upon construction instead of once the initialization op is run. with context.graph_mode(): @@ -1148,31 +1058,27 @@ class MirroredAndTowerLocalVariableInitializerTest(test.TestCase): 1.0, synchronization=variable_scope.VariableSynchronization.ON_READ, aggregation=variable_scope.VariableAggregation.SUM) - self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) + self.assertTrue(isinstance(v_sum, values.ReplicaLocalVariable)) return v_sum - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - tower_local_var = dist.call_for_each_tower(model_fn) - self.assertTrue(isinstance(tower_local_var, values.TowerLocalVariable)) - self.assertFalse(self.evaluate(tower_local_var.is_initialized())) - self.evaluate(tower_local_var.initializer) - self.assertTrue(self.evaluate(tower_local_var.is_initialized())) - + with distribution.scope(): + replica_local_var = distribution.extended.call_for_each_replica( + model_fn) + self.assertTrue(isinstance(replica_local_var, + values.ReplicaLocalVariable)) + self.assertFalse(self.evaluate(replica_local_var.is_initialized())) + 
self.evaluate(replica_local_var.initializer) + self.assertTrue(self.evaluate(replica_local_var.is_initialized())) -class TowerLocalVariableAssignTest(test.TestCase): - config = config_pb2.ConfigProto() - config.allow_soft_placement = True - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Not enough GPUs available for this test in eager mode.") +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) +class ReplicaLocalVariableAssignTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignTowerLocalVarSumAggregation(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignReplicaLocalVarSumAggregation(self, distribution): def model_fn(): v_sum = variable_scope.variable( 1.0, @@ -1180,30 +1086,27 @@ class TowerLocalVariableAssignTest(test.TestCase): aggregation=variable_scope.VariableAggregation.SUM) return v_sum - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - tower_local_var = dist.call_for_each_tower(model_fn, - run_concurrently=False) - self.assertTrue(isinstance(tower_local_var, values.TowerLocalVariable)) + with distribution.scope(): + replica_local_var = distribution.extended.call_for_each_replica(model_fn) + self.assertTrue(isinstance(replica_local_var, + values.ReplicaLocalVariable)) self.evaluate(variables.global_variables_initializer()) - # Each tower has a value of 1.0 assigned to it in tower context. + # Each replica has a value of 1.0 assigned to it in replica context. # When we read the value using `read_var` we should see the SUM of each of - # values on each of the towers. 
- self.assertEqual(2.0, self.evaluate(dist.read_var(tower_local_var))) - # Assigning 6.0 in cross tower context will assign a value of - # 6.0/num_towers to each tower. - tlv_ops = tower_local_var.assign(6.0) + # values on each of the replicas. + self.assertEqual(2.0, self.evaluate( + distribution.read_var(replica_local_var))) + # Assigning 6.0 in cross replica context will assign a value of + # 6.0/num_replicas to each replica. + tlv_ops = replica_local_var.assign(6.0) self.evaluate(tlv_ops) - # On reading the tower local var we should get the assigned value back. - # The value on all the towers are added before being returned by + # On reading the replica local var we should get the assigned value back. + # The value on all the replicas are added before being returned by # `read_var`. - self.assertEqual(6.0, self.evaluate(dist.read_var(tower_local_var))) + self.assertEqual(6.0, self.evaluate( + distribution.read_var(replica_local_var))) - @test_util.run_in_graph_and_eager_modes(config=config) - def testAssignTowerLocalVarMeanAggregation(self): - self._skip_eager_if_gpus_less_than(1) + def testAssignReplicaLocalVarMeanAggregation(self, distribution): def model_fn(): v_sum = variable_scope.variable( 1.0, @@ -1211,23 +1114,22 @@ class TowerLocalVariableAssignTest(test.TestCase): aggregation=variable_scope.VariableAggregation.MEAN) return v_sum - dist = mirrored_strategy.MirroredStrategy( - ["/device:GPU:0", "/device:CPU:0"]) - - with dist.scope(): - tower_local_var = dist.call_for_each_tower(model_fn, - run_concurrently=False) - self.assertTrue(isinstance(tower_local_var, values.TowerLocalVariable)) + with distribution.scope(): + replica_local_var = distribution.extended.call_for_each_replica(model_fn) + self.assertTrue(isinstance(replica_local_var, + values.ReplicaLocalVariable)) self.evaluate(variables.global_variables_initializer()) - # Each tower has a value of 1.0 assigned to it in tower context. 
+ # Each replica has a value of 1.0 assigned to it in replica context. # When we read the value using `read_var` we should see the MEAN of values - # on all towers which is the value assigned in tower context. - self.assertEqual(1.0, self.evaluate(dist.read_var(tower_local_var))) - tlv_ops = tower_local_var.assign(6.0) + # on all replicas which is the value assigned in replica context. + self.assertEqual(1.0, self.evaluate( + distribution.read_var(replica_local_var))) + tlv_ops = replica_local_var.assign(6.0) self.evaluate(tlv_ops) - # On reading the tower local var we should get the MEAN of all values + # On reading the replica local var we should get the MEAN of all values # which is equal to the value assigned. - self.assertEqual(6.0, self.evaluate(dist.read_var(tower_local_var))) + self.assertEqual(6.0, self.evaluate( + distribution.read_var(replica_local_var))) class MockModel(object): @@ -1245,25 +1147,41 @@ class MockModel(object): return x -class MirroredStrategyDefunTest(test.TestCase): +class MiniModel(keras_training.Model): + """Minimal model for mnist. + + Useful for testing and debugging on slow TPU simulators. 
+ """ + + def __init__(self): + super(MiniModel, self).__init__(name="") + self.fc = keras_core.Dense(1, name="fc", kernel_initializer="ones", + bias_initializer="ones") - def _skip_eager_if_gpus_less_than(self, num_gpus): - if context.num_gpus() < num_gpus and context.executing_eagerly(): - self.skipTest("Not enough GPUs available for this test in eager mode.") + def call(self, inputs, training=True): + inputs = array_ops.ones([1, 10]) + return self.fc(inputs) - def _call_and_check(self, model_fn, inputs, expected_result, defuns, - two_variables=False): + +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) +class MirroredStrategyDefunTest(test.TestCase): + + def _call_and_check(self, distribution, model_fn, inputs, expected_result, + defuns, two_variables=False): cpu_dev = device_util.canonicalize("CPU:0") gpu_dev = device_util.canonicalize("GPU:0") devices = [cpu_dev, gpu_dev] - dist = mirrored_strategy.MirroredStrategy(devices) - with dist.scope(): + with distribution.scope(): mock_model = MockModel(two_variables) self.evaluate(variables.global_variables_initializer()) - result = dist.call_for_each_tower(model_fn, mock_model, *inputs, - run_concurrently=False) + result = distribution.extended.call_for_each_replica( + model_fn, args=[mock_model] + inputs) for device in devices: device_result = values.select_device(device, result) device_expected_result = values.select_device(device, expected_result) @@ -1275,18 +1193,15 @@ class MirroredStrategyDefunTest(test.TestCase): # call_for_each has one trace per device. To check that the expected set # of variables was accessed on each trace, we first retrieve each # device-specific graph function. 
- per_device_graph_functions = dist.call_for_each_tower( - defun.get_concrete_function, - mock_model, *inputs, run_concurrently=False) + per_replica_graph_functions = ( + distribution.extended.call_for_each_replica( + defun.get_concrete_function, args=[mock_model] + inputs)) for device in devices: - graph_function = per_device_graph_functions.get(device=device) + graph_function = per_replica_graph_functions.get(device=device) self.assertEqual(set(mock_model.variables), set(graph_function.graph.variables)) - @test_util.run_in_graph_and_eager_modes() - def testVariableInDefun(self): - self._skip_eager_if_gpus_less_than(1) - + def testVariableInDefun(self, distribution): @function.defun def times_two(mock_model): return mock_model() @@ -1294,12 +1209,9 @@ class MirroredStrategyDefunTest(test.TestCase): def model_fn(mock_model): return times_two(mock_model) - self._call_and_check(model_fn, [], 2.5, [times_two]) - - @test_util.run_in_graph_and_eager_modes() - def testVariableInNestedDefun(self): - self._skip_eager_if_gpus_less_than(1) + self._call_and_check(distribution, model_fn, [], 2.5, [times_two]) + def testVariableInNestedDefun(self, distribution): @function.defun def times_two(mock_model): return mock_model() @@ -1311,12 +1223,10 @@ class MirroredStrategyDefunTest(test.TestCase): def model_fn(mock_model): return two_x_plus_one(mock_model) - self._call_and_check(model_fn, [], 3.5, [times_two, two_x_plus_one]) - - @test_util.run_in_graph_and_eager_modes() - def testTwoVariablesInNestedDefun(self): - self._skip_eager_if_gpus_less_than(1) + self._call_and_check(distribution, model_fn, [], 3.5, + [times_two, two_x_plus_one]) + def testTwoVariablesInNestedDefun(self, distribution): @function.defun def fn1(mock_model): return mock_model() @@ -1328,12 +1238,10 @@ class MirroredStrategyDefunTest(test.TestCase): def model_fn(mock_model): return fn2(mock_model) - self._call_and_check(model_fn, [], 5.5, [fn1, fn2], two_variables=True) - - 
@test_util.run_in_graph_and_eager_modes() - def testGradientTapeOverNestedDefuns(self): - self._skip_eager_if_gpus_less_than(1) + self._call_and_check(distribution, model_fn, [], 5.5, [fn1, fn2], + two_variables=True) + def testGradientTapeOverNestedDefuns(self, distribution): @function.defun def fn1(mock_model): return mock_model() @@ -1349,38 +1257,122 @@ class MirroredStrategyDefunTest(test.TestCase): [v.get() for v in mock_model.variables]) return grads - self._call_and_check(model_fn, [], [2.0, 1.0], [fn1, fn2], + self._call_and_check(distribution, model_fn, [], [2.0, 1.0], [fn1, fn2], two_variables=True) - @test_util.run_in_graph_and_eager_modes() - def testPassPerDevice(self): - self._skip_eager_if_gpus_less_than(1) - + def testPassPerReplica(self, distribution): @function.defun def fn1(mock_model, factor): return mock_model(factor) - factors = values.PerDevice({"CPU:0": 5.0, "GPU:0": 3.0}) - expected_result = values.PerDevice({"CPU:0": 5.0 * 1.25, - "GPU:0": 3.0 * 1.25}) - self._call_and_check(fn1, [factors], expected_result, [fn1]) + factors = values.PerReplica({"CPU:0": 5.0, "GPU:0": 3.0}) + expected_result = values.PerReplica({"CPU:0": 5.0 * 1.25, + "GPU:0": 3.0 * 1.25}) + self._call_and_check(distribution, fn1, [factors], expected_result, [fn1]) + + def testTrain(self, distribution): + with distribution.scope(): + mock_model = MiniModel() + mock_model.call = function.defun(mock_model.call) + + def loss_fn(ctx): + del ctx + return mock_model(array_ops.ones([1, 10])) + + gradients_fn = backprop.implicit_grad(loss_fn) + gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) + grads_and_vars = distribution.extended.call_for_each_replica( + gradients_fn, args=(None,)) + + optimizer = gradient_descent.GradientDescentOptimizer(0.25) + update_ops = optimizer._distributed_apply(distribution, grads_and_vars) # pylint: disable=protected-access + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + 
self.evaluate(update_ops) + updated_var_values = self.evaluate(mock_model.variables) + # All variables start at 1.0 and get two updates of 0.25. + self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0]) + self.assertAllEqual([0.5], updated_var_values[1]) + + +@combinations.generate( + combinations.combine( + distribution=[ + combinations.NamedDistribution( + "Mirrored", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker= + context.num_gpus()), + required_gpus=1), + combinations.NamedDistribution( + "CoreMirrored", + # pylint: disable=g-long-lambda + lambda: mirrored_strategy.CoreMirroredStrategy( + mirrored_strategy.all_local_devices()), + required_gpus=1) + ], + mode=["graph"])) class MultiWorkerMirroredStrategyTest( multi_worker_test_base.MultiWorkerTestBase, strategy_test_lib.DistributionTestBase): - def _get_distribution_strategy(self): + def _configure_distribution_strategy(self, distribution): cluster_spec = server_lib.ClusterSpec({ "worker": ["/job:worker/task:0", "/job:worker/task:1"] }) - strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus()) - strategy.configure(cluster_spec=cluster_spec) - return strategy - - def testMinimizeLossGraph(self): - self._test_minimize_loss_graph(self._get_distribution_strategy(), - learning_rate=0.05) + distribution.configure(cluster_spec=cluster_spec) + + def test_num_replicas_in_sync(self, distribution): + self._configure_distribution_strategy(distribution) + # We calculate the total number of gpus across the workers(2) specified in + # the cluster spec. 
+ self.assertEqual(context.num_gpus() * 2, distribution.num_replicas_in_sync) + + def testMinimizeLossGraph(self, distribution): + self._configure_distribution_strategy(distribution) + self._test_minimize_loss_graph(distribution, learning_rate=0.05) + + def testDeviceScope(self, distribution): + """Test the device scope of multi-worker MirroredStrategy.""" + self._configure_distribution_strategy(distribution) + with distribution.scope(): + a = constant_op.constant(1.) + with ops.device("/cpu:0"): + b = constant_op.constant(1.) + self.assertEqual(a.device, "/job:worker/task:0") + self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0") + + def testMakeInputFnIterator(self, distribution): + self._configure_distribution_strategy(distribution) + dataset_fn = lambda: dataset_ops.Dataset.range(100) + num_gpus = context.num_gpus() + num_workers = 2 + + expected_values = [[i+j for j in range(num_gpus)] * num_workers + for i in range(0, 100, num_gpus)] + + with context.graph_mode(), self.cached_session() as sess: + # `expected_input_pipeline_id` is None because the input_fn will be called + # multiple times, each with a different input_pipeline_id. 
+ input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_workers*num_gpus, + expected_num_input_pipelines=num_workers, + expected_input_pipeline_id=None) + iterator = distribution.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator( + iterator, distribution.extended.worker_devices, expected_values, sess) + + def testUpdateConfigProto(self, distribution): + distribution.configure(cluster_spec={"worker": ["fake1", "fake2"]}) + + config_proto = config_pb2.ConfigProto() + new_config = distribution.update_config_proto(config_proto) + + # Verify isolate_session_state + self.assertTrue(new_config.isolate_session_state) class MultiWorkerMirroredStrategyTestWithChief( @@ -1400,6 +1392,19 @@ class MultiWorkerMirroredStrategyTestWithChief( strategy.configure(cluster_spec=self._cluster_spec) self._test_minimize_loss_graph(strategy, learning_rate=0.05) + def testMinimizeLossGraphCoreMirroredStrategy(self): + strategy = mirrored_strategy.CoreMirroredStrategy( + mirrored_strategy.all_local_devices()) + strategy.configure(cluster_spec=self._cluster_spec) + self._test_minimize_loss_graph(strategy, learning_rate=0.05) + + +def _replica_id(): + replica_id = ds_context.get_replica_context().replica_id_in_sync_group + if not isinstance(replica_id, ops.Tensor): + replica_id = constant_op.constant(replica_id) + return replica_id + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py deleted file mode 100644 index 969e1269560e52736d05e6b14ce320d9bd4fcac0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for class MirroredStrategy.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import strategy_test_lib -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import variable_scope -from tensorflow.python.training import distribution_strategy_context - - -class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): - - def _get_distribution_strategy(self): - return mirrored_strategy.MirroredStrategy(["/device:CPU:0"]) - - def testMinimizeLossEager(self): - self._test_minimize_loss_eager(self._get_distribution_strategy()) - - def testMinimizeLossGraph(self): - self._test_minimize_loss_graph(self._get_distribution_strategy()) - - def testMapReduce(self): - self._test_map_reduce(self._get_distribution_strategy()) - - def testDeviceIndex(self): - self._test_device_index(self._get_distribution_strategy()) - - def testTowerId(self): - self._test_tower_id(self._get_distribution_strategy()) - - @test_util.run_in_graph_and_eager_modes - def testCallAndMergeExceptions(self): - self._test_call_and_merge_exceptions(self._get_distribution_strategy()) - - -class 
VariableCreatorStackTest(test.TestCase): - - def testCreatorStacksAreThreadLocal(self): - devices = ["/device:CPU:0", "/device:GPU:0"] - dist = mirrored_strategy.MirroredStrategy(devices) - - def model_fn(device_id): - assert isinstance(device_id, int) - - def thread_creator_fn(next_creator, *args, **kwargs): - return next_creator(*args, **kwargs) + ":thread_" + str(device_id) - - with variable_scope.variable_creator_scope(thread_creator_fn): - # Create a variable in this scope. - v = variable_scope.variable(1.0) - - # This will pause the current thread, and execute the other thread. - distribution_strategy_context.get_tower_context().merge_call( - lambda _: _) - return v - - def main_thread_creator(next_creator, *args, **kwargs): - # We are not using the underlying next_creator for test purposes. - del next_creator, args, kwargs - return "main_thread" - - with context.graph_mode(), \ - dist.scope(), \ - variable_scope.variable_creator_scope(main_thread_creator): - result = dist.call_for_each_tower(model_fn, dist.worker_device_index) - result = dist.unwrap(result) - expected = ["main_thread:thread_0", "main_thread:thread_1"] - self.assertEquals(expected, result) - - -class MultiWorkerMirroredStrategyTest(test.TestCase): - - def testDeviceScope(self): - """Test the device scope of multi-worker MirroredStrategy.""" - with context.graph_mode(): - strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus()) - strategy.configure( - cluster_spec={"worker": ["/job:worker/task:0", "/job:worker/task:1"]}) - with strategy.scope(): - a = constant_op.constant(1.) - with ops.device("/cpu:0"): - b = constant_op.constant(1.) 
- self.assertEqual(a.device, "/job:worker/task:0") - self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0") - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py index 7644acedc99361d7287a91832d76bc68cbc6ac0a..17b7ab74f63f42e1ee14a82d3bffdd1df9b25857 100644 --- a/tensorflow/contrib/distribute/python/monitor.py +++ b/tensorflow/contrib/distribute/python/monitor.py @@ -51,6 +51,7 @@ class Monitor(object): else: if session is None: raise ValueError("Should provide a `session` in Graph mode.") + session.run(step_callable._iterator.initializer) # pylint: disable=protected-access self._run_step = session.make_callable(step_callable()) session.run(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8f13e9153ea7a951dd722c4549882c97e79b57fe --- /dev/null +++ b/tensorflow/contrib/distribute/python/moving_averages_test.py @@ -0,0 +1,165 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for training.moving_averages when using a DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.training import moving_averages + + +all_combinations = combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph"]) + + +class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase): + + @combinations.generate(all_combinations) + def testReplicaModeWithoutZeroDebias(self, distribution): + replica_id = [0] + + def replica_fn(): + var = variables.Variable([10.0, 11.0]) + val = constant_op.constant([1.0 + replica_id[0], 2.0 - replica_id[0]]) + replica_id[0] += 1 + decay = 0.25 + assign = moving_averages.assign_moving_average( + var, val, decay, zero_debias=False) + return var, assign + + with distribution.scope(), self.cached_session() as sess: + var, assign = distribution.call_for_each_replica(replica_fn) + variables.global_variables_initializer().run() + self.assertAllClose([10.0, 11.0], var.eval()) + sess.run(distribution.unwrap(assign)) + # Mean of val across calls to replica_fn(). 
+ average_val = [1.0 + 0.5 * (replica_id[0] - 1), + 2.0 - 0.5 * (replica_id[0] - 1)] + val_weight = 1.0 - 0.25 + self.assertAllClose( + [10.0 * 0.25 + average_val[0] * val_weight, + 11.0 * 0.25 + average_val[1] * val_weight], + var.eval()) + + @combinations.generate(all_combinations) + def testReplicaMode(self, distribution): + replica_id = [0] + + def replica_fn(): + var = variables.Variable([0.0, 0.0]) + val = constant_op.constant([1.0 + replica_id[0], 2.0 - replica_id[0]]) + replica_id[0] += 1 + decay = 0.25 + assign = moving_averages.assign_moving_average(var, val, decay) + return var, assign.op + + with distribution.scope(), self.cached_session() as sess: + var, assign_op = distribution.call_for_each_replica(replica_fn) + variables.global_variables_initializer().run() + self.assertAllClose([0.0, 0.0], var.eval()) + sess.run(distribution.unwrap(assign_op)) + # Mean of val across calls to replica_fn(). + average_val = [1.0 + 0.5 * (replica_id[0] - 1), + 2.0 - 0.5 * (replica_id[0] - 1)] + self.assertAllClose(average_val, var.eval()) + + @combinations.generate(all_combinations) + def testCrossDeviceWithoutZeroDebias(self, distribution): + with distribution.scope(), self.cached_session() as sess: + var = variables.Variable([10.0, 11.0]) + val = constant_op.constant([1.0, 2.0]) + decay = 0.25 + # NOTE(josh11b): We currently generate an error if val is a PerReplica + # value. + assign = moving_averages.assign_moving_average( + var, val, decay, zero_debias=False) + + variables.global_variables_initializer().run() + self.assertAllClose([10.0, 11.0], var.eval()) + sess.run(assign) + average_val = [1.0, 2.0] + val_weight = 1.0 - 0.25 + self.assertAllClose( + [10.0 * 0.25 + average_val[0] * val_weight, + 11.0 * 0.25 + average_val[1] * val_weight], + var.eval()) + # Also try assign.op. 
+ sess.run(assign.op) + orig_weight = 0.25 * 0.25 + val_weight = 1.0 - orig_weight + self.assertAllClose( + [10.0 * orig_weight + average_val[0] * val_weight, + 11.0 * orig_weight + average_val[1] * val_weight], + var.eval()) + + @combinations.generate(all_combinations) + def testCrossDevice(self, distribution): + with distribution.scope(), self.cached_session() as sess: + var = variables.Variable([0.0, 0.0]) + val = array_ops.placeholder(dtypes.float32) + decay = 0.25 + # NOTE(josh11b): We currently generate an error if val is a PerReplica + # value. + assign = moving_averages.assign_moving_average(var, val, decay) + + variables.global_variables_initializer().run() + self.assertAllClose([0.0, 0.0], var.eval()) + sess.run(assign, feed_dict={val: [1.0, 2.0]}) + self.assertAllClose([1.0, 2.0], var.eval()) + + # Also try assign.op. + sess.run(assign.op, feed_dict={val: [10.0, 0.0]}) + self.assertAllClose( + [(1.0 * 0.25 + 10.0) / (1.0 * 0.25 + 1.0), + (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)], + var.eval()) + + @combinations.generate(all_combinations) + def testAssignVariable(self, distribution): + + def replica_fn(): + var = variables.Variable([10.0, 11.0]) + # Here we expect to check the case when input value are variable. + val = variables.Variable([1., 2.]) + decay = 0.25 + assign = moving_averages.assign_moving_average( + var, val, decay, zero_debias=False) + return var, assign + + with distribution.scope(), self.cached_session() as sess: + var, assign = distribution.call_for_each_replica(replica_fn) + variables.global_variables_initializer().run() + self.assertAllClose([10.0, 11.0], var.eval()) + sess.run(distribution.unwrap(assign)) + self.assertAllClose( + [10 * 0.25 + 1. * (1 - 0.25), 11 * 0.25 + 2. 
* (1 - 0.25)], + var.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 9f92ba7dde5fc2798201cef2238bcc4b20b698a8..147c9b83f866fd364ea23cf7988692a7b5f61b9c 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -18,8 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import contextlib import copy +import json +import os import threading import numpy as np @@ -39,7 +42,6 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib - ASSIGNED_PORTS = set() lock = threading.Lock() @@ -207,12 +209,10 @@ class MultiWorkerTestBase(test.TestCase): self._lock = threading.Lock() @contextlib.contextmanager - def test_session(self, graph=None, config=None, target=None): + def session(self, graph=None, config=None, target=None): """Create a test session with master target set to the testing cluster. - This overrides the base class' method, removes arguments that are not needed - by the multi-node case and creates a test session that connects to the local - testing cluster. + Creates a test session that connects to the local testing cluster. Args: graph: Optional graph to use during the returned session. @@ -224,9 +224,44 @@ class MultiWorkerTestBase(test.TestCase): A Session object that should be used as a context manager to surround the graph building and execution code in a test case. 
""" - if self.id().endswith('.test_session'): - self.skipTest('Not a test.') + config = self._create_config(config) + + if target is None: + target = self._default_target + with session.Session(graph=graph, config=config, target=target) as sess: + yield sess + + @contextlib.contextmanager + # TODO(b/117573461): Overwrite self.evaluate() to use this function. + def cached_session(self, graph=None, config=None, target=None): + """Create a test session with master target set to the testing cluster. + + Creates a test session that connects to the local testing cluster. + The session is only created once per test and then reused. + + Args: + graph: Optional graph to use during the returned session. + config: An optional config_pb2.ConfigProto to use to configure the + session. + target: the target of session to connect to. + + Yields: + A Session object that should be used as a context manager to surround + the graph building and execution code in a test case. Note that the + session will live until the end of the test. 
+ """ + config = self._create_config(config) + if target is None: + target = self._default_target + if getattr(self._thread_local, 'cached_session', None) is None: + self._thread_local.cached_session = session.Session( + graph=None, config=config, target=target) + sess = self._thread_local.cached_session + with sess.graph.as_default(), sess.as_default(): + yield sess + + def _create_config(self, config): if config is None: config = config_pb2.ConfigProto(allow_soft_placement=True) else: @@ -237,18 +272,7 @@ class MultiWorkerTestBase(test.TestCase): config.graph_options.rewrite_options.constant_folding = ( rewriter_config_pb2.RewriterConfig.OFF) - if target is None: - target = self._default_target - if graph is None: - if getattr(self._thread_local, 'cached_session', None) is None: - self._thread_local.cached_session = session.Session( - graph=None, config=config, target=target) - sess = self._thread_local.cached_session - with sess.graph.as_default(), sess.as_default(): - yield sess - else: - with session.Session(graph=graph, config=config, target=target) as sess: - yield sess + return config def _run_client(self, client_fn, task_type, task_id, num_gpus, *args, **kwargs): @@ -281,3 +305,101 @@ class MultiWorkerTestBase(test.TestCase): for t in threads: t.join() self.assertEqual(self._result, len(threads)) + + +class MockOsEnv(collections.Mapping): + """A class that allows per-thread TF_CONFIG.""" + + def __init__(self, *args): + self._dict = dict() + self._thread_local = threading.local() + super(MockOsEnv, self).__init__(*args) + + def get(self, key, default=None): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + if key == 'TF_CONFIG': + return dict.get(self._thread_local.dict, key, default) + else: + return dict.get(self._dict, key, default) + + def __getitem__(self, key): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + if key == 'TF_CONFIG': + return dict.__getitem__(self._thread_local.dict, 
key) + else: + return dict.__getitem__(self._dict, key) + + def __setitem__(self, key, val): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + if key == 'TF_CONFIG': + return dict.__setitem__(self._thread_local.dict, key, val) + else: + return dict.__setitem__(self._dict, key, val) + + def __iter__(self): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + for x in self._thread_local.dict.items(): + yield x + for x in self._dict.items(): + yield x + + def __len__(self): + if not hasattr(self._thread_local, 'dict'): + self._thread_local.dict = dict() + return self._thread_local.dict.__len__() + self._dict.__len__() + + +class IndependentWorkerTestBase(test.TestCase): + """Testing infra for independent workers.""" + + def setUp(self): + self._mock_os_env = MockOsEnv() + self._mock_context = test.mock.patch.object(os, 'environ', + self._mock_os_env) + super(IndependentWorkerTestBase, self).setUp() + self._mock_context.__enter__() + + def tearDown(self): + self._mock_context.__exit__(None, None, None) + super(IndependentWorkerTestBase, self).tearDown() + + def _task_thread(self, task_fn, tf_config, *args, **kwargs): + os.environ['TF_CONFIG'] = json.dumps(tf_config) + task_fn(*args, **kwargs) + + def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id, + *args, **kwargs): + if task_type: + tf_config = { + 'cluster': cluster_spec, + 'task': { + 'type': task_type, + 'index': task_id + } + } + else: + tf_config = { + 'cluster': cluster_spec, + } + t = threading.Thread( + target=self._task_thread, + args=(task_fn, tf_config) + args, + kwargs=kwargs) + t.start() + return t + + def run_multiple_tasks_in_threads(self, task_fn, cluster_spec, *args, + **kwargs): + # The task_fn should create std_server by itself. 
+ threads = {} + for task_type in cluster_spec.keys(): + threads[task_type] = [] + for task_id in range(len(cluster_spec[task_type])): + t = self._run_task_in_thread(task_fn, cluster_spec, task_type, task_id, + *args, **kwargs) + threads[task_type].append(t) + return threads diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index f5259190485e701c190beb49220caff743f8fdcb..fdbfba4e04358451a46b23ef250dc7c534c855a0 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -20,14 +20,14 @@ from __future__ import print_function import six -from tensorflow.contrib.distribute.python import values +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import values from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util import nest @@ -40,10 +40,16 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): # doing something that won't work with other DistributionStrategy # implementations? 
- def __init__(self, device, prefetch_on_device=None): - super(OneDeviceStrategy, self).__init__() + def __init__(self, device): + super(OneDeviceStrategy, self).__init__(OneDeviceExtended(self, device)) + + +class OneDeviceExtended(distribute_lib.DistributionStrategyExtended): + """Implementation of OneDeviceStrategy.""" + + def __init__(self, container_strategy, device): + super(OneDeviceExtended, self).__init__(container_strategy) self._device = device - self._prefetch_on_device = prefetch_on_device self._default_device = device def _create_variable(self, next_creator, *args, **kwargs): @@ -54,25 +60,40 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): if isinstance(colocate_with, six.string_types): with ops.device(colocate_with): return next_creator(*args, **kwargs) - if (isinstance(colocate_with, list) and len(colocate_with) == 1 and + if (isinstance(colocate_with, (list, tuple)) and len(colocate_with) == 1 and isinstance(colocate_with[0], six.string_types)): with ops.device(colocate_with[0]): return next_creator(*args, **kwargs) with ops.colocate_with(colocate_with): return next_creator(*args, **kwargs) - def distribute_dataset(self, dataset_fn): - return values.PerDeviceDataset( - self._call_dataset_fn(dataset_fn), [self._device], - self._prefetch_on_device) - - def _broadcast(self, tensor, destinations): + def _make_dataset_iterator(self, dataset): + """Make iterator from dataset without splitting the batch.""" + worker = device_util.canonicalize("/device:CPU:0") + worker_device_pairs = [(worker, [self._device])] + return values.DatasetIterator(dataset, worker_device_pairs) + + def _distribute_dataset(self, dataset_fn): + return values.PerReplicaDataset( + self._call_dataset_fn(dataset_fn), [self._device]) + + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + worker = device_util.canonicalize("/device:CPU:0") + worker_device_pairs = [(worker, [self._device])] + return 
values.InputFunctionIterator( + input_fn, worker_device_pairs, + [distribute_lib.InputContext()]) + + def _broadcast_to(self, tensor, destinations): del destinations return tensor # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. - def _run_steps_on_dataset(self, fn, iterator, iterations, - initial_loop_values=None): + def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, + initial_loop_values=None): if initial_loop_values is None: initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) @@ -84,7 +105,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): fn_inputs = iterator.get_next() if not isinstance(fn_inputs, tuple): fn_inputs = (fn_inputs,) - fn_result = fn(ctx, *fn_inputs) + fn_result = fn(ctx, fn_inputs) flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) with ops.control_dependencies([fn_result]): return [i + 1] + flat_last_step_outputs @@ -117,86 +138,82 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access return ctx - def _call_for_each_tower(self, fn, *args, **kwargs): - # We don't run `fn` in multiple threads in OneDeviceStrategy. 
- kwargs.pop("run_concurrently", None) - with ops.device(self._device), _OneDeviceTowerContext(self): + def _call_for_each_replica(self, fn, args, kwargs): + strategy = self._container_strategy() + with ops.device(self._device), _OneDeviceReplicaContext(strategy): return fn(*args, **kwargs) - def map(self, map_over, fn, *args, **kwargs): - with ops.device(self._device): - return values.MapOutput([fn(m, *args, **kwargs) for m in map_over]) - - def _reduce(self, aggregation, value, destinations): - del destinations - if not isinstance(value, values.MapOutput): - return value - l = value.get() - assert l - with ops.device(self._device): - if aggregation == vs.VariableAggregation.SUM: - return math_ops.add_n(l) - elif aggregation == vs.VariableAggregation.MEAN: - return math_ops.add_n(l) / len(l) - else: - assert False + def _reduce_to(self, reduce_op, value, destinations): + del reduce_op, destinations + return value - def _update(self, var, options, fn, *args, **kwargs): + def _update(self, var, fn, args, kwargs, group): # The implementations of _update() and _update_non_slot() are identical # except _update() passes `var` as the first argument to `fn()`. - return self._update_non_slot(var, options, fn, var, *args, **kwargs) + return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): + def _update_non_slot(self, colocate_with, fn, args, kwargs, group): del colocate_with - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. 
with ops.device(self._device), distribute_lib.UpdateContext(self._device): result = fn(*args, **kwargs) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) - def read_var(self, tower_local_var): - """Read the aggregate value of a tower-local variable.""" - return array_ops.identity(tower_local_var) + def read_var(self, replica_local_var): + """Read the aggregate value of a replica-local variable.""" + return array_ops.identity(replica_local_var) def _unwrap(self, value): - return [value] + return (value,) def value_container(self, value): return value @property - def is_single_tower(self): - return True - - @property - def num_towers(self): + def _num_replicas_in_sync(self): return 1 @property def worker_devices(self): - return [self._device] + return (self._device,) @property def parameter_devices(self): - return [self._device] + return (self._device,) def non_slot_devices(self, var_list): del var_list - return [self._device] + return (self._device,) - def _worker_device_index(self): - return 0 + @property + def experimental_should_init(self): + return True + + @property + def should_checkpoint(self): + return True + + @property + def should_save_summary(self): + return True + + # TODO(priyag): Delete this once all strategies use global batch size. 
+ @property + def _global_batch_size(self): + return True -class _OneDeviceTowerContext(distribute_lib.TowerContext): +class _OneDeviceReplicaContext(distribute_lib.ReplicaContext): + """ReplicaContext for OneDeviceStrategy.""" def __init__(self, distribution_strategy): - distribute_lib.TowerContext.__init__( - self, distribution_strategy, tower_id=0) + distribute_lib.ReplicaContext.__init__( + self, + distribution_strategy, + replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)) @property - def device(self): - return self._distribution_strategy.worker_devices[0] + def devices(self): + return self._distribution_strategy.extended.worker_devices diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 4fdc0f72e6745b7ef25c591157955f214e0b2c79..d46cd6f529e363f76bfa2b22339add63530cfde8 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test from tensorflow.python.framework import test_util @@ -35,19 +36,27 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def testMinimizeLossGraph(self): self._test_minimize_loss_graph(self._get_distribution_strategy()) - def testMapReduce(self): - self._test_map_reduce(self._get_distribution_strategy()) - - def testDeviceIndex(self): - self._test_device_index(self._get_distribution_strategy()) - - def testTowerId(self): - self._test_tower_id(self._get_distribution_strategy()) + def testReplicaId(self): + self._test_replica_id(self._get_distribution_strategy()) @test_util.run_in_graph_and_eager_modes def testCallAndMergeExceptions(self): 
self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + @test_util.run_in_graph_and_eager_modes + def testMakeInputFnIterator(self): + d = one_device_strategy.OneDeviceStrategy("/device:CPU:0") + dataset_fn = lambda: dataset_ops.Dataset.range(10) + expected_values = [[i] for i in range(10)] + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=1, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) + iterator = d.make_input_fn_iterator(input_fn) + self._test_input_fn_iterator( + iterator, d.extended.worker_devices, expected_values) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index 6e9ba37a198fc8038c086d2672251adfac30fdcf..fa4705af7cb592119f56686d1f693a156f7b4b13 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -42,16 +42,20 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - iterator = distribution.distribute_dataset( - dataset_fn).make_one_shot_iterator() + ds = distribution.distribute_dataset(dataset_fn) + if context.executing_eagerly(): + iterator = ds.make_one_shot_iterator() + else: + iterator = ds.make_initializable_iterator() def run_step(): return control_flow_ops.group(distribution.unwrap( - distribution.call_for_each_tower( - model_fn, iterator.get_next(), run_concurrently=layer.built))) + distribution.call_for_each_replica( + model_fn, args=(iterator.get_next(),)))) if not context.executing_eagerly(): with self.cached_session() as sess: + sess.run(iterator.initializer) run_step = sess.make_callable(run_step()) self.evaluate(variables.global_variables_initializer()) diff --git 
a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 6ddd91507bf86e8b0cf710a2340fd61abcdebe71..2c7766f95fbcb7b68a53ad0052f21485c763a1db 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -18,10 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +import copy + from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import values +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops @@ -30,8 +34,6 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import device_setter -from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util import nest _LOCAL_CPU = "/device:CPU:0" @@ -61,16 +63,16 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): for a particular worker. Note that each graph and worker is independent. This means that while each worker will synchronously compute a single gradient update across all GPUs, updates between workers proceed asynchronously. 
- Operations that occur only on the first tower (such as incrementing the global - step), will occur on the first tower *of every worker*. + Operations that occur only on the first replica (such as incrementing the + global step), will occur on the first replica *of every worker*. - It is expected to call `call_for_each_tower(fn, *args, **kwargs)` for any - operations which potentially can be replicated across towers (i.e. multiple + It is expected to call `call_for_each_replica(fn, ...)` for any + operations which potentially can be replicated across replicas (i.e. multiple GPUs) even if there is only CPU or one GPU. When defining the `fn`, extra caution needs to be taken: 1) Always use `tf.get_variable` instead of `tf.Variable` which is not able - to refer to the same variable on different towers. + to refer to the same variable on different replicas. 2) It is generally not recommended to open a device scope under the strategy's scope. A device scope (i.e. calling `tf.device`) will be merged with or @@ -94,13 +96,21 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): ValueError: if `cluster_spec` is given but `task_type` or `task_id` is not. """ - super(ParameterServerStrategy, self).__init__() + super(ParameterServerStrategy, self).__init__( + ParameterServerExtended(self, num_gpus_per_worker)) + + +class ParameterServerExtended(distribute_lib.DistributionStrategyExtended): + """Implementation of ParameterServerStrategy.""" + + def __init__(self, container_strategy, num_gpus_per_worker): + super(ParameterServerExtended, self).__init__(container_strategy) self._num_gpus_per_worker = num_gpus_per_worker self._initialize_local(num_gpus_per_worker) # We typically don't need to do all-reduce in this strategy. 
- self._cross_tower_ops = ( - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + self._cross_device_ops = ( + cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps( reduce_to_device=_LOCAL_CPU)) def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec, @@ -108,10 +118,10 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): """Initialize devices for multiple workers. It creates variable devices and compute devices. Variables and operations - will be assigned to them respectively. We have one compute device per tower. - The variable device is a device function or device string. The default - variable device assigns variables to parameter servers in a round-robin - fashion. + will be assigned to them respectively. We have one compute device per + replica. The variable device is a device function or device string. The + default variable device assigns variables to parameter servers in a + round-robin fashion. Args: num_gpus_per_worker: number of local GPUs or GPUs per worker. @@ -132,17 +142,17 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id) # Define compute devices which is a list of device strings and one for each - # tower. When there are GPUs, replicate operations on these GPUs. Otherwise, - # place operations on CPU. + # replica. When there are GPUs, replicate operations on these GPUs. + # Otherwise, place operations on CPU. 
if num_gpus_per_worker > 0: - self._compute_devices = [ + self._compute_devices = tuple( "%s/device:GPU:%d" % (self._worker_device, i) for i in range(num_gpus_per_worker) - ] + ) else: - self._compute_devices = [self._worker_device] + self._compute_devices = (self._worker_device,) - self._compute_devices = list( + self._compute_devices = tuple( map(device_util.resolve, self._compute_devices)) self._canonical_compute_device_set = set(self._compute_devices) @@ -166,8 +176,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): # The `_parameter_devices` is needed for the `parameter_devices` property # and is a list of all variable devices. Here parameter devices are all # tasks of the "ps" job. - self._parameter_devices = map("/job:ps/task:{}".format, - range(num_ps_replicas)) + self._parameter_devices = tuple(map("/job:ps/task:{}".format, + range(num_ps_replicas))) # Add a default device so that ops without specified devices will not end up # on other workers. @@ -189,28 +199,29 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): def _initialize_local(self, num_gpus_per_worker): """Initialize internal devices for local training.""" + self._worker_device = device_util.canonicalize("/device:CPU:0") # Define compute devices which is a list of device strings and one for each - # tower. When there are GPUs, replicate operations on these GPUs. Otherwise, - # place operations on CPU. + # replica. When there are GPUs, replicate operations on these GPUs. + # Otherwise, place operations on CPU. 
if num_gpus_per_worker > 0: - self._compute_devices = list( + self._compute_devices = tuple( map("/device:GPU:{}".format, range(num_gpus_per_worker))) else: - self._compute_devices = [_LOCAL_CPU] + self._compute_devices = (_LOCAL_CPU,) - self._compute_devices = list( + self._compute_devices = tuple( map(device_util.resolve, self._compute_devices)) self._canonical_compute_device_set = set(self._compute_devices) # If there is only one GPU, put everything on that GPU. Otherwise, place # variables on CPU. if num_gpus_per_worker == 1: - assert len(list(self._compute_devices)) == 1 + assert len(self._compute_devices) == 1 self._variable_device = _LOCAL_GPU_0 - self._parameter_devices = [_LOCAL_GPU_0] + self._parameter_devices = (_LOCAL_GPU_0,) else: self._variable_device = _LOCAL_CPU - self._parameter_devices = [_LOCAL_CPU] + self._parameter_devices = (_LOCAL_CPU,) self._is_chief = True self._cluster_spec = None @@ -221,31 +232,68 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): "ParameterServerStrategy with compute_devices = %r, " "variable_device = %r", self._compute_devices, self._variable_device) - def distribute_dataset(self, dataset_fn): + def _distribute_dataset(self, dataset_fn): """Distributes the dataset to each local GPU.""" - return values.PerDeviceDataset( + return values.PerReplicaDataset( self._call_dataset_fn(dataset_fn), self._compute_devices, True) - def _broadcast(self, tensor, destinations): - if not cross_tower_ops_lib.check_destinations(destinations): + def _make_dataset_iterator(self, dataset): + worker_device_pairs = [(self._worker_device, self._compute_devices)] + return values.DatasetIterator(dataset, worker_device_pairs, + self._num_replicas_in_sync) + + def _make_input_fn_iterator( + self, + input_fn, + replication_mode=distribute_lib.InputReplicationMode.PER_WORKER): + """Distributes the dataset to each local GPU.""" + if self._cluster_spec: + input_pipeline_id = multi_worker_util.id_in_cluster( + self._cluster_spec, 
self._task_type, self._task_id) + num_input_pipelines = multi_worker_util.worker_count( + self._cluster_spec, self._task_type) + else: + input_pipeline_id = 0 + num_input_pipelines = 1 + input_context = distribute_lib.InputContext( + num_input_pipelines=num_input_pipelines, + input_pipeline_id=input_pipeline_id, + num_replicas_in_sync=self._num_replicas_in_sync) + worker_device_pairs = [(self._worker_device, self._compute_devices)] + return values.InputFunctionIterator( + input_fn, worker_device_pairs, [input_context]) + + def _broadcast_to(self, tensor, destinations): + # This is both a fast path for Python constants, and a way to delay + # converting Python values to a tensor until we know what type it + # should be converted to. Otherwise we have trouble with: + # global_step.assign_add(1) + # since the `1` gets broadcast as an int32 but global_step is int64. + if isinstance(tensor, (float, int)): + return tensor + if not cross_device_ops_lib.check_destinations(destinations): destinations = self._compute_devices - return self._cross_tower_ops.broadcast(tensor, destinations) + return self._cross_device_ops.broadcast(tensor, destinations) + + def _allow_variable_partition(self): + return not context.executing_eagerly() # TODO(yuefengz): not all ops in device_setter.STANDARD_PS_OPS will go through # this creator, such as "MutableHashTable". 
def _create_variable(self, next_creator, *args, **kwargs): - if self.num_towers > 1: + if self._num_replicas_in_sync > 1: aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) if aggregation not in ( vs.VariableAggregation.NONE, vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN, - vs.VariableAggregation.ONLY_FIRST_TOWER + vs.VariableAggregation.ONLY_FIRST_REPLICA ): raise ValueError("Invalid variable aggregation mode: " + aggregation + " for variable: " + kwargs["name"]) def var_creator(*args, **kwargs): + """Create an AggregatingVariable and fix up collections.""" # Record what collections this variable should be added to. collections = kwargs.pop("collections", None) if collections is None: @@ -287,39 +335,37 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): with ops.device(self._variable_device): return var_creator(*args, **kwargs) - def _call_for_each_tower(self, fn, *args, **kwargs): + def _call_for_each_replica(self, fn, args, kwargs): # pylint: disable=protected-access - return mirrored_strategy._call_for_each_tower(self, fn, *args, **kwargs) + return mirrored_strategy._call_for_each_replica( + self._container_strategy(), fn, args, kwargs) def _verify_destinations_not_different_worker(self, destinations): + if not self._cluster_spec: + return if destinations is None: return - for d in cross_tower_ops_lib.get_devices_from(destinations): + for d in cross_device_ops_lib.get_devices_from(destinations): d_spec = tf_device.DeviceSpec.from_string(d) if d_spec.job == self._task_type and d_spec.task != self._task_id: raise ValueError( "Cannot reduce to another worker: %r, current worker is %r" % (d, self._worker_device)) - def _reduce(self, aggregation, value, destinations): + def _reduce_to(self, reduce_op, value, destinations): self._verify_destinations_not_different_worker(destinations) if not isinstance(value, values.DistributedValues): # pylint: disable=protected-access return 
mirrored_strategy._reduce_non_distributed_value( - self, aggregation, value, destinations) - if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER: - return self.broadcast(value.get(self._compute_devices[0]), destinations) - return self._cross_tower_ops.reduce( - aggregation, value, destinations=destinations) - - def _batch_reduce(self, aggregation, value_destination_pairs): - if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER: - return [self.broadcast(v.get(self._compute_devices[0]), d) - for v, d in value_destination_pairs] + self, reduce_op, value, destinations) + return self._cross_device_ops.reduce( + reduce_op, value, destinations=destinations) + + def _batch_reduce_to(self, reduce_op, value_destination_pairs): for _, destinations in value_destination_pairs: self._verify_destinations_not_different_worker(destinations) - return self._cross_tower_ops.batch_reduce(aggregation, - value_destination_pairs) + return self._cross_device_ops.batch_reduce(reduce_op, + value_destination_pairs) def _select_single_value(self, structured): """Select any single values in `structured`.""" @@ -333,9 +379,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): "You cannot update variable with a Mirrored object with multiple " "components %r when using ParameterServerStrategy. You must " "specify a single value or a Mirrored with a single value." % x) - elif isinstance(x, values.PerDevice): + elif isinstance(x, values.PerReplica): raise ValueError( - "You cannot update variable with a PerDevice object %r when using " + "You cannot update variable with a PerReplica object %r when using " "ParameterServerStrategy. 
You must specify a single value or a " "Mirrored with a single value" % x) else: @@ -343,30 +389,26 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): return nest.map_structure(_select_fn, structured) - def _update(self, var, options, fn, *args, **kwargs): + def _update(self, var, fn, args, kwargs, group): if isinstance(var, values.AggregatingVariable): var = var.get() if not isinstance(var, resource_variable_ops.ResourceVariable): raise ValueError( "You can not update `var` %r. It must be a Variable." % var) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): result = fn(var, *self._select_single_value(args), **self._select_single_value(kwargs)) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) # TODO(yuefengz): does it need to call _select_single_value? - def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs): - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. + def _update_non_slot(self, colocate_with, fn, args, kwargs, group): with ops.device( colocate_with.device), distribute_lib.UpdateContext(colocate_with): result = fn(*args, **kwargs) - if should_group: + if group: return result else: return nest.map_structure(self._unwrap, result) @@ -375,22 +417,28 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): if isinstance(val, values.DistributedValues): # Return in a deterministic order. 
if set(val.devices) == self._canonical_compute_device_set: - return [val.get(device=d) for d in self._compute_devices] - return [val.get(device=d) for d in sorted(val.devices)] - return [val] + return tuple(val.get(device=d) for d in self._compute_devices) + return tuple(val.get(device=d) for d in sorted(val.devices)) + return (val,) def value_container(self, val): - return values.value_container(val) + if (hasattr(val, "_aggregating_container") and + not isinstance(val, values.AggregatingVariable)): + wrapper = val._aggregating_container() # pylint: disable=protected-access + if wrapper is not None: + return wrapper + return val def read_var(self, var): - # No need to distinguish between normal variables and tower-local variables. + # No need to distinguish between normal variables and replica-local + # variables. return array_ops.identity(var) - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): + def _configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): """Configures the strategy class. The strategy object will be re-initialized if `cluster_spec` is given but @@ -421,44 +469,50 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): self._initialize_multi_worker(self._num_gpus_per_worker, self._cluster_spec, task_type, task_id) - if not session_config or not self._cluster_spec: - return + if session_config: + session_config.CopyFrom(self._update_config_proto(session_config)) + + def _update_config_proto(self, config_proto): + updated_config = copy.deepcopy(config_proto) + if not self._cluster_spec: + updated_config.isolate_session_state = True + return updated_config - session_config.isolate_session_state = False + updated_config.isolate_session_state = False - assert self._cluster_spec assert self._task_type assert self._task_id is not None # The device filters prevent communication between workers. 
if self._task_type not in ["chief", "worker"]: - return - del session_config.device_filters[:] - session_config.device_filters.extend( + return updated_config + del updated_config.device_filters[:] + updated_config.device_filters.extend( ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"]) + return updated_config @property - def num_towers(self): + def _num_replicas_in_sync(self): return len(self._compute_devices) @property def worker_devices(self): - # Make a copy to prevent users from accidentally mutating our copy. - return list(self._compute_devices) + return self._compute_devices @property def parameter_devices(self): - return list(self._parameter_devices) + return self._parameter_devices def non_slot_devices(self, var_list): return min(var_list, key=lambda x: x.name) @property - def between_graph(self): + def experimental_between_graph(self): + # TODO(yuefengz): Should this return False in the local case? return True @property - def should_init(self): + def experimental_should_init(self): return self._is_chief @property @@ -468,3 +522,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): @property def should_save_summary(self): return self._is_chief + + # TODO(priyag): Delete this once all strategies use global batch size. 
+ @property + def _global_batch_size(self): + return False diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 353d11a5831904abd43828f1d9d4abfc61aede60..83d7473666a65e438a1c0119d2a12bf54e53c8fc 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -25,22 +25,29 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import parameter_server_strategy -from tensorflow.contrib.distribute.python import values +from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values +from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test -from 
tensorflow.python.training import device_util -from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import training_util CHIEF = run_config.TaskType.CHIEF @@ -48,6 +55,13 @@ WORKER = run_config.TaskType.WORKER PS = run_config.TaskType.PS +def _get_replica_id_integer(): + replica_id = ds_context.get_replica_context().replica_id_in_sync_group + if isinstance(replica_id, ops.Tensor): + replica_id = tensor_util.constant_value(replica_id) + return replica_id + + class ParameterServerStrategyTestBase( multi_worker_test_base.MultiWorkerTestBase): @@ -80,12 +94,11 @@ class ParameterServerStrategyTestBase( worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id) d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) with ops.Graph().as_default(), \ - self.test_session(target=self._default_target, - config=sess_config) as sess, \ + self.cached_session(target=self._default_target, + config=sess_config) as sess, \ d.scope(): - # Define a variable outside the call_for_each_tower scope. This is not - # recommended. + # Define a variable outside the call_for_each_replica scope. 
n = variable_scope.get_variable('n', initializer=10.0) self.assertEqual(n.device, '/job:ps/task:0') @@ -93,9 +106,8 @@ class ParameterServerStrategyTestBase( if num_gpus == 0: last_part_device = 'device:CPU:0' else: - last_part_device = ( - 'device:GPU:%d' % - distribution_strategy_context.get_tower_context().tower_id) + replica_id = _get_replica_id_integer() + last_part_device = ('device:GPU:%d' % replica_id) a = constant_op.constant(1.0) b = constant_op.constant(2.0) @@ -165,7 +177,7 @@ class ParameterServerStrategyTestBase( self.assertIn('/job:ps/', h.device) return y_add, z_add, f - y, z, f = d.call_for_each_tower(model_fn) + y, z, f = d.call_for_each_replica(model_fn) self.assertNotEqual(y, None) self.assertNotEqual(z, None) self.assertNotEqual(f, None) @@ -177,39 +189,108 @@ class ParameterServerStrategyTestBase( self.assertEqual(z_val, 43.0) self.assertEqual(f_val, 46.0) + def _test_device_assignment_distributed_enable_partitioner( + self, task_type, task_id, num_gpus): + d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) + num_shards = len(d.parameter_devices) + partitioner = partitioned_variables.fixed_size_partitioner(num_shards) + with ops.Graph().as_default(), \ + self.cached_session(target=self._default_target, + config=sess_config) as sess, \ + d.scope(): + + n = variable_scope.get_variable( + 'n', + initializer=constant_op.constant([10.0, 20.0]), + aggregation=variable_scope.VariableAggregation.SUM, + partitioner=partitioner) + + for part_id, var in enumerate(n): + self.assertEqual(var.device, '/job:ps/task:%d' % part_id) + + def model_fn(): + a = constant_op.constant([3.0, 5.0]) + # The device scope is ignored for variables but not for normal ops. 
+ with ops.device('/job:worker/task:0'): + x = variable_scope.get_variable( + 'x', + initializer=constant_op.constant([10.0, 20.0]), + aggregation=variable_scope.VariableAggregation.SUM, + partitioner=partitioner) + x_add = x.assign_add(a, name='x_add') + # The variable x is on the task 1 since the device_function has been + # called once before the model_fn. + for part_id, var in enumerate(x): + self.assertEqual(var.device, '/job:ps/task:%d' % part_id) + self.assertEqual(var.device, x_add[part_id].device) + + # The colocate_vars_with can override the distribution's device. + with d.colocate_vars_with(x_add[0]): + y = variable_scope.get_variable( + 'y', + initializer=constant_op.constant([20.0, 10.0]), + aggregation=variable_scope.VariableAggregation.SUM, + partitioner=partitioner) + y_add = y.assign_add( + [array_ops.identity(x_add[0]), + array_ops.identity(x_add[1])]) + + for part_id, var in enumerate(y): + self.assertEqual(var.device, '/job:ps/task:0') + self.assertEqual(y_add[part_id].device, var.device) + self.assertEqual(var.device, x_add[0].device) + + return x_add, y_add + + x, y = d.call_for_each_replica(model_fn) + + if context.num_gpus() >= 1: + variables.global_variables_initializer().run() + x_val, y_val = sess.run([x, y]) + if num_gpus < 1: + self.assertEqual(x_val, [13.0, 25.0]) + self.assertEqual(y_val, [33.0, 35.0]) + else: + x_expect = [10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus] + y_expect = [ + 20.0 + x_expect[0] * num_gpus, 10.0 + x_expect[1] * num_gpus + ] + self.assertEqual(x_val, x_expect) + self.assertEqual(y_val, y_expect) + def _test_device_assignment_local(self, d, compute_device='CPU', variable_device='CPU', num_gpus=0): with ops.Graph().as_default(), \ - self.test_session(target=self._default_target, - config=self._sess_config) as sess, \ + self.cached_session(target=self._default_target, + config=self._sess_config) as sess, \ d.scope(): def model_fn(): if 'CPU' in compute_device: - tower_compute_device = '/device:CPU:0' + 
replica_compute_device = '/device:CPU:0' else: - tower_compute_device = ( - '/device:GPU:%d' % - distribution_strategy_context.get_tower_context().tower_id) - tower_compute_device = device_util.canonicalize(tower_compute_device) + replica_id = _get_replica_id_integer() + replica_compute_device = ('/device:GPU:%d' % replica_id) + replica_compute_device = device_util.canonicalize( + replica_compute_device) if 'CPU' in variable_device: - tower_variable_device = '/device:CPU:0' + replica_variable_device = '/device:CPU:0' else: - tower_variable_device = ( - '/device:GPU:%d' % - distribution_strategy_context.get_tower_context().tower_id) - tower_variable_device = device_util.canonicalize(tower_variable_device) + replica_id = _get_replica_id_integer() + replica_variable_device = ('/device:GPU:%d' % replica_id) + replica_variable_device = device_util.canonicalize( + replica_variable_device) a = constant_op.constant(1.0) b = constant_op.constant(2.0) c = a + b - self.assertEqual(a.device, tower_compute_device) - self.assertEqual(b.device, tower_compute_device) - self.assertEqual(c.device, tower_compute_device) + self.assertEqual(a.device, replica_compute_device) + self.assertEqual(b.device, replica_compute_device) + self.assertEqual(c.device, replica_compute_device) # The device scope is ignored for variables but not for normal ops. with ops.device('/device:GPU:2'): @@ -219,7 +300,7 @@ class ParameterServerStrategyTestBase( x_add = x.assign_add(c) e = a + c self.assertEqual( - device_util.canonicalize(x.device), tower_variable_device) + device_util.canonicalize(x.device), replica_variable_device) self.assertEqual(x_add.device, x.device) self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2')) @@ -232,7 +313,7 @@ class ParameterServerStrategyTestBase( # non-distributed values. 
y_add = y.assign_add(array_ops.identity(x_add)) self.assertEqual( - device_util.canonicalize(y.device), tower_variable_device) + device_util.canonicalize(y.device), replica_variable_device) self.assertEqual(y_add.device, y.device) self.assertEqual(y.device, x.device) @@ -240,7 +321,7 @@ class ParameterServerStrategyTestBase( 'z', initializer=10.0, aggregation=variable_scope.VariableAggregation.SUM) self.assertEqual( - device_util.canonicalize(z.device), tower_variable_device) + device_util.canonicalize(z.device), replica_variable_device) with ops.control_dependencies([y_add]): # We add an identity here to avoid complaints about summing @@ -248,7 +329,7 @@ class ParameterServerStrategyTestBase( z_add = z.assign_add(array_ops.identity(y)) with ops.control_dependencies([z_add]): f = z + c - self.assertEqual(f.device, tower_compute_device) + self.assertEqual(f.device, replica_compute_device) # The device scope would merge with the default worker device. with ops.device('/CPU:1'): @@ -261,11 +342,13 @@ class ParameterServerStrategyTestBase( u = variable_scope.get_variable('u', initializer=30.0) h = f + 1.0 self.assertEqual( - device_util.canonicalize(u.device), tower_variable_device) - self.assertEqual(device_util.canonicalize(x.device), h.device) + device_util.canonicalize(u.device), replica_variable_device) + self.assertEqual( + device_util.canonicalize(x.device), + device_util.canonicalize(h.device)) return y_add, z_add, f - y, z, f = d.call_for_each_tower(model_fn) + y, z, f = d.call_for_each_replica(model_fn) self.assertNotEqual(y, None) self.assertNotEqual(z, None) self.assertNotEqual(f, None) @@ -280,15 +363,15 @@ class ParameterServerStrategyTestBase( def _test_simple_increment(self, task_type, task_id, num_gpus): d, master_target, sess_config = self._get_test_objects( task_type, task_id, num_gpus) - if hasattr(d, '_cluster_spec') and d._cluster_spec: - num_workers = len(d._cluster_spec.as_dict().get(WORKER)) - if 'chief' in d._cluster_spec.as_dict(): + if 
d.extended._cluster_spec: + num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER)) + if 'chief' in d.extended._cluster_spec.as_dict(): num_workers += 1 else: num_workers = 1 with ops.Graph().as_default(), \ - self.test_session(target=master_target, - config=sess_config) as sess, \ + self.cached_session(target=master_target, + config=sess_config) as sess, \ d.scope(): def model_fn(): @@ -300,7 +383,7 @@ class ParameterServerStrategyTestBase( aggregation=variable_scope.VariableAggregation.SUM) z = variable_scope.get_variable( 'z', initializer=30.0, - aggregation=variable_scope.VariableAggregation.ONLY_FIRST_TOWER) + aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA) # We explicitly make a constant tensor here to avoid complaints about # summing non-distributed values. @@ -312,10 +395,10 @@ class ParameterServerStrategyTestBase( train_op = control_flow_ops.group(x_add, y_add, z_add) return x, y, z, train_op - x, y, z, train_op = d.call_for_each_tower(model_fn) + x, y, z, train_op = d.call_for_each_replica(model_fn) train_op = d.group(train_op) - if context.num_gpus() < d._num_gpus_per_worker: + if context.num_gpus() < d.extended._num_gpus_per_worker: return True if task_id == 0: @@ -340,24 +423,29 @@ class ParameterServerStrategyTestBase( self._finish_condition.release() x_val, y_val, z_val = sess.run([x, y, z]) - self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_towers) - self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_towers) + self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_replicas_in_sync) + self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_replicas_in_sync) self.assertEqual(z_val, 30.0 + 1.0 * num_workers) - return (x_val == 10.0 + 1.0 * num_workers * d.num_towers and - y_val == 20.0 + 1.0 * num_workers * d.num_towers and + return (x_val == 10.0 + 1.0 * num_workers * d.num_replicas_in_sync and + y_val == 20.0 + 1.0 * num_workers * d.num_replicas_in_sync and z_val == 30.0 + 1.0 * num_workers) def 
_test_minimize_loss_graph(self, task_type, task_id, num_gpus): d, master_target, sess_config = self._get_test_objects( task_type, task_id, num_gpus) - assert hasattr(d, '_cluster_spec') and d._cluster_spec - num_workers = len(d._cluster_spec.as_dict().get(WORKER)) - if CHIEF in d._cluster_spec.as_dict(): - num_workers += 1 + if task_type: + # Multi-worker + assert hasattr(d.extended, '_cluster_spec') and d.extended._cluster_spec + num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER)) + if CHIEF in d.extended._cluster_spec.as_dict(): + num_workers += 1 + else: + # local + num_workers = 1 with ops.Graph().as_default(), \ - self.test_session(target=master_target, - config=sess_config) as sess, \ + self.cached_session(target=master_target, + config=sess_config) as sess, \ d.scope(): l = core.Dense(1, use_bias=False) @@ -384,7 +472,7 @@ class ParameterServerStrategyTestBase( def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_tower(grad_fn, one) + g_v = d.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. before_list = [] after_list = [] @@ -393,8 +481,8 @@ class ParameterServerStrategyTestBase( before_list.append(fetched) with ops.control_dependencies([fetched]): # TODO(yuefengz): support non-Mirrored variable as destinations. 
- g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.extended.reduce_to( + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies( d.update(v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -402,10 +490,12 @@ class ParameterServerStrategyTestBase( before_out, after_out = step() - if context.num_gpus() < d._num_gpus_per_worker: + if context.num_gpus() < d.extended._num_gpus_per_worker: return True - if multi_worker_util.is_chief(d._cluster_spec, task_type, task_id): + if (not task_type or + multi_worker_util.is_chief( + d.extended._cluster_spec, task_type, task_id)): variables.global_variables_initializer().run() # Workers waiting for chief worker's initializing variables. @@ -428,8 +518,40 @@ class ParameterServerStrategyTestBase( self.assertLess(error_after, error_before) return error_after < error_before + def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn, + expected_values): + distribution, master_target, config = self._get_test_objects( + task_type, task_id, num_gpus) + devices = distribution.extended.worker_devices + + with ops.Graph().as_default(), \ + self.cached_session(config=config, + target=master_target) as sess: + iterator = distribution.make_input_fn_iterator(input_fn) + sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + sess.run([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. 
+ sess.run(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = sess.run( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + class ParameterServerStrategyTest(ParameterServerStrategyTestBase, + strategy_test_lib.DistributionTestBase, parameterized.TestCase): @classmethod @@ -438,6 +560,13 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, num_workers=3, num_ps=2) cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0] + def test_num_replicas_in_sync(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + # All the devices on a given worker are in sync which in this case is the + # number of gpus on each worker. + self.assertEqual(2, distribution.num_replicas_in_sync) + def testDeviceAssignmentLocalCPU(self): distribution = parameter_server_strategy.ParameterServerStrategy( num_gpus_per_worker=0) @@ -461,6 +590,12 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, def testDeviceAssignmentDistributed(self, num_gpus): self._test_device_assignment_distributed('worker', 1, num_gpus) + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus): + self._test_device_assignment_distributed_enable_partitioner( + 'worker', 1, num_gpus) + def testSimpleBetweenGraph(self): self._run_between_graph_clients(self._test_simple_increment, self._cluster_spec, context.num_gpus()) @@ -472,10 +607,82 @@ class ParameterServerStrategyTest(ParameterServerStrategyTestBase, @combinations.generate( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraph(self, num_gpus): + def testMinimizeLossGraphDistributed(self, num_gpus): self._run_between_graph_clients(self._test_minimize_loss_graph, self._cluster_spec, num_gpus) + @combinations.generate( + 
combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraphLocal(self, num_gpus): + self._test_minimize_loss_graph(None, None, num_gpus) + + # TODO(priyag): Refactor this and other multi worker tests. + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) + def testMakeInputFnIteratorDistributed(self, num_gpus): + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(100) + expected_values = [[i+j for j in range(num_gpus)] + for i in range(0, 100, num_gpus)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_gpus, + expected_num_input_pipelines=3, + expected_input_pipeline_id=1) # because task_id = 1 + self._test_input_fn_iterator('worker', 1, num_gpus, + input_fn, expected_values) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[1, 2], required_gpus=1)) + def testMakeInputFnIteratorLocal(self, num_gpus): + if context.num_gpus() < num_gpus: + self.skipTest('Not enough GPUs') + dataset_fn = lambda: dataset_ops.Dataset.range(100) + expected_values = [[i+j for j in range(num_gpus)] + for i in range(0, 100, num_gpus)] + + input_fn = self._input_fn_to_test_input_context( + dataset_fn, + expected_num_replicas_in_sync=num_gpus, + expected_num_input_pipelines=1, + expected_input_pipeline_id=0) # only one worker and pipeline for local. 
+ self._test_input_fn_iterator(None, None, num_gpus, + input_fn, expected_values) + + def testGlobalStepUpdate(self): + strategy = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=context.num_gpus()) + self._test_global_step_update(strategy) + + def testUpdateConfigProtoMultiWorker(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + distribution.configure( + cluster_spec=self._cluster_spec, task_type='worker', task_id=1) + + config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) + + new_config = distribution.update_config_proto(config_proto) + + # Verify device filters. + self.assertEqual(['/job:worker/task:1', '/job:ps'], + new_config.device_filters) + + # Verify isolate_session_state + self.assertFalse(new_config.isolate_session_state) + + def testUpdateConfigProtoLocal(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + + config_proto = config_pb2.ConfigProto() + new_config = distribution.update_config_proto(config_proto) + + # Verify isolate_session_state + self.assertTrue(new_config.isolate_session_state) + class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, parameterized.TestCase): @@ -509,6 +716,19 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, self.assertIs(values.AggregatingVariable, type(created_step)) self.assertIs(values.AggregatingVariable, type(get_step)) + def testValueContainer(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + with ops.Graph().as_default(), distribution.scope(): + def f(): + with backprop.GradientTape() as tape: + v = variable_scope.get_variable('v', initializer=10.0) + _ = v * v + v, = tape.watched_variables() + w = distribution.extended.value_container(v) + self.assertIs(values.AggregatingVariable, type(w)) + distribution.extended.call_for_each_replica(f) + if __name__ == '__main__': 
test.main() diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py deleted file mode 100644 index d48aa9c89bc894a6afc4aab8b60fabc52a06b198..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Extension of prefetching_ops to support more than one device.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import warnings - -from tensorflow.python.data.experimental.ops import prefetching_ops -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.data.util import nest as data_nest -from tensorflow.python.data.util import sparse -from tensorflow.python.eager import context -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops -from tensorflow.python.util import nest - - -# pylint: disable=protected-access -class _PrefetchToDeviceIterator(object): - """A replacement for `tf.data.Iterator` that prefetches to another 
device. - - Args: - input_dataset: The input dataset. - one_shot: If true, we make a one shot iterator that's already initialized. - devices: Devices on which to prefetch. - buffer_size: Size of the prefetching buffer. - shared_name: (Optional.) If non-empty, the returned iterator will be shared - under the given name across multiple sessions that share the same devices - (e.g. when using a remote server). Only used if one_shot is False. - - Returns: - An Iterator type object. - """ - - def __init__(self, - input_dataset, - one_shot, - devices, - buffer_size, - shared_name=None): - self._input_dataset = input_dataset - self._get_next_call_count = 0 - self._one_shot = one_shot - if shared_name is None: - shared_name = "" - self._devices = devices - - if self._one_shot: - self._input_iterator = input_dataset.make_one_shot_iterator() - else: - self._input_iterator = iterator_ops.Iterator.from_structure( - self._input_dataset.output_types, self._input_dataset.output_shapes, - shared_name, self._input_dataset.output_classes) - input_iterator_handle = self._input_iterator.string_handle() - - @function.Defun(dtypes.string) - def _prefetch_fn(handle): - """Prefetches one element from `input_iterator`.""" - remote_iterator = iterator_ops.Iterator.from_string_handle( - handle, self._input_iterator.output_types, - self._input_iterator.output_shapes, - self._input_iterator.output_classes) - ret = remote_iterator.get_next() - return nest.flatten(sparse.serialize_sparse_tensors(ret)) - - target_device = ged_ops.experimental_iterator_get_device( - self._input_iterator._iterator_resource) - self._buffering_resources = [] - for device in nest.flatten(self._devices): - with ops.device(device): - buffer_resource_handle = prefetching_ops.function_buffering_resource( - f=_prefetch_fn, - output_types=data_nest.flatten( - sparse.as_dense_types(self._input_dataset.output_types, - self._input_dataset.output_classes)), - target_device=target_device, - string_arg=input_iterator_handle, - 
buffer_size=buffer_size, - shared_name=shared_name) - self._buffering_resources.append(buffer_resource_handle) - - if not self._one_shot: - reset_ops = [] - for buffer_resource in self._buffering_resources: - reset_ops.append( - ged_ops.experimental_function_buffering_resource_reset( - buffer_resource)) - with ops.control_dependencies(reset_ops): - self._initializer = self._input_iterator.make_initializer( - self._input_dataset) - - def get_next(self, name=None): - """See `tf.data.Iterator.get_next`.""" - self._get_next_call_count += 1 - if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: - warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) - - flat_result = [] - # TODO(priyag): This will fail if the input size (typically number of - # batches) is not divisible by number of devices. - # How do we handle that more gracefully / let the user know? - for buffer_resource in self._buffering_resources: - flat_ret = ged_ops.experimental_function_buffering_resource_get_next( - buffer_resource, - output_types=data_nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - name=name) - - ret = sparse.deserialize_sparse_tensors( - data_nest.pack_sequence_as(self.output_types, flat_ret), - self.output_types, self.output_shapes, self.output_classes) - - for tensor, shape in zip( - data_nest.flatten(ret), data_nest.flatten(self.output_shapes)): - if isinstance(tensor, ops.Tensor): - tensor.set_shape(shape) - flat_result.append(ret) - - return nest.pack_sequence_as(self._devices, flat_result) - - @property - def initializer(self): - if self._one_shot: - raise NotImplementedError("Can't initialize a one_shot_iterator") - return self._initializer - - @property - def output_classes(self): - return self._input_dataset.output_classes - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types - - -# pylint: 
enable=protected-access - - -class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset): - """A `Dataset` whose iterator prefetches elements to other device(s).""" - - def __init__(self, input_dataset, devices, buffer_size): - super(_PrefetchToDeviceDataset, self).__init__(input_dataset) - self._input_dataset = input_dataset - self._devices = devices - self._buffer_size = buffer_size if buffer_size is not None else 1 - - def make_one_shot_iterator(self): - return _PrefetchToDeviceIterator( - self._input_dataset, - one_shot=True, - devices=self._devices, - buffer_size=self._buffer_size) - - def make_initializable_iterator(self, shared_name=None): - if context.executing_eagerly(): - raise RuntimeError( - "make_initializable_iterator is not supported when eager " - "execution is enabled.") - - return _PrefetchToDeviceIterator( - self._input_dataset, - one_shot=False, - devices=self._devices, - buffer_size=self._buffer_size, - shared_name=shared_name) - - def _as_variant_tensor(self): - # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset - # transformation methods is called. - # TODO(mrry): Investigate support for chaining further transformations after - # the prefetch, including GPU support. - raise NotImplementedError("`prefetch_to_devices()` must be the last " - "transformation in a dataset pipeline.") - - # TODO(priyag): Fix the output types, shapes and classes to match the result - # of get_next (which has the additional nesting layer of devices now). - @property - def output_types(self): - return self._input_dataset.output_types - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_classes(self): - return self._input_dataset.output_classes - - -def prefetch_to_devices(devices, buffer_size=None): - """A transformation that prefetches dataset values to the given `devices`. 
- - NOTE: Although the transformation creates a `tf.data.Dataset`, the - transformation must be the final `Dataset` in the input pipeline. - - Args: - devices: A nested structure of devices on which to prefetch the data. It can - be a single device name, or a tuple or list of device names. - buffer_size: (Optional.) The number of elements to buffer on each device. - Defaults to an automatically chosen value. - - Returns: - A `Dataset` transformation function, which can be passed to - `tf.data.Dataset.apply`. - """ - - def _apply_fn(dataset): - return _PrefetchToDeviceDataset(dataset, devices, buffer_size) - - return _apply_fn diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py deleted file mode 100644 index 16799104e8112f4391152c0cf2a15af81f8c2c9d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for prefetching_ops_v2.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distribute.python import prefetching_ops_v2 -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util -from tensorflow.python.platform import test - - -class PrefetchingOpsV2Test(test.TestCase): - - def testPrefetchToOneDevice(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops_v2.prefetch_to_devices("/gpu:0")) - - iterator = device_dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with self.cached_session() as sess: - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testPrefetchToTwoDevicesInAList(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) - - iterator = device_dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - output = [] - # TODO(rohanj): Modify test to go till the end of the dataset when we - # switch to MultiDeviceIterator. 
- with self.cached_session() as sess: - for _ in range(4): - result = sess.run(next_element) - self.assertEqual(2, len(result)) - output.extend(result) - self.assertEquals(set(range(8)), set(output)) - - def testPrefetchToTwoDevicesWithReinit(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) - - iterator = device_dataset.make_initializable_iterator() - next_element = iterator.get_next() - - # TODO(rohanj): Modify test to go till the end of the dataset when we - # switch to MultiDeviceIterator. - with self.cached_session() as sess: - sess.run(iterator.initializer) - for _ in range(4): - sess.run(next_element) - sess.run(iterator.initializer) - for _ in range(4): - sess.run(next_element) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py index 09b351ffa4165656e2fc9666ab4b7725ef061f50..be724fb59a7efa18c43c4cb98649ced806f7bcb4 100644 --- a/tensorflow/contrib/distribute/python/single_loss_example.py +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -90,7 +90,7 @@ def batchnorm_example(optimizer_fn, batch_per_epoch=1, momentum=0.9, renorm=False, - update_ops_in_tower_mode=False): + update_ops_in_replica_mode=False): """Example of non-distribution-aware legacy code with batch normalization.""" def dataset_fn(): @@ -113,7 +113,7 @@ def batchnorm_example(optimizer_fn, y = batchnorm(x, training=True) with ops.control_dependencies( ops.get_collection(ops.GraphKeys.UPDATE_OPS) - if update_ops_in_tower_mode else []): + if update_ops_in_replica_mode else []): loss = math_ops.reduce_mean( math_ops.reduce_sum(layer(y)) - constant_op.constant(1.)) # `x` and `y` will be fetched by the gradient computation, but not `loss`. 
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index 1b5a4f64e5bb1ffabfe1b87c150f713c755bb682..c928b6d9f1f21508edd753f94c38ab2723cc0a9f 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.eager import backprop +from tensorflow.python.eager import context from tensorflow.python.training import optimizer as optimizer_lib @@ -50,7 +51,11 @@ class StandardInputStep(Step): def __init__(self, dataset_fn, distribution): super(StandardInputStep, self).__init__(distribution) self._distributed_input = distribution.distribute_dataset(dataset_fn) - self._iterator = self._distributed_input.make_one_shot_iterator() + if context.executing_eagerly(): + self._iterator = self._distributed_input.make_one_shot_iterator() + else: + # TODO(priyag): Expose initializer via some initializer property. 
+ self._iterator = self._distributed_input.make_initializable_iterator() class StandardSingleLossStep(StandardInputStep): @@ -85,25 +90,21 @@ class StandardSingleLossStep(StandardInputStep): super(StandardSingleLossStep, self).__init__(dataset_fn, distribution) self._loss_fn = loss_fn self._optimizer = optimizer - self._is_run_concurrently = False self._iterations_per_step = iterations_per_step def __call__(self): with self._distribution.scope(): - def step_fn(ctx, *inputs): + def step_fn(ctx, inputs): """Function to run one iteration with one input.""" gradients_fn = backprop.implicit_grad(self._loss_fn) gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) - grads_and_vars = self.distribution.call_for_each_tower( - gradients_fn, - ctx, *inputs, - run_concurrently=self._is_run_concurrently) + grads_and_vars = self.distribution.call_for_each_replica( + gradients_fn, args=(ctx,) + inputs) # If threads use layers, then we need to run the first step # sequentially, so that layers.build() is not executed in parallel. # Otherwise, multiple sets of mirrored variables are going to be # created. 
- self._is_run_concurrently = True return self._optimizer._distributed_apply( # pylint: disable=protected-access self.distribution, grads_and_vars) diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py index f1ada49fa378358f112fb75a4bcdbe9a8a09cd13..1ff9b9ceec13351b098d47ed3ff62f689a625a31 100644 --- a/tensorflow/contrib/distribute/python/step_fn_test.py +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -50,6 +50,7 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): run_step = single_loss_step else: with self.cached_session() as sess: + sess.run(single_loss_step._iterator.initializer) run_step = sess.make_callable(single_loss_step()) self.evaluate(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index fd280f5754b34170cdd6b948236138d0e77dd8bc..d441b5af5f6aa41efde2c75d09d9589516c54992 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -19,16 +19,21 @@ from __future__ import division from __future__ import print_function from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope from 
tensorflow.python.ops import variables -from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import optimizer @@ -36,45 +41,43 @@ class _TestException(Exception): pass -# May be the argument to either distribution.call_for_each_tower() or -# get_tower_context().merge_call() +# May be the argument to either distribution.call_for_each_replica() or +# get_replica_context().merge_call() def _raise_exception_fn(_=None): raise _TestException() -# Must be the argument to a distribution.call_for_each_tower() call, calls a -# get_tower_context().merge_call() that raises an exception. +# Must be the argument to a distribution.call_for_each_replica() call, calls a +# get_replica_context().merge_call() that raises an exception. def _merge_raises_fn(): - distribution_strategy_context.get_tower_context().merge_call( - _raise_exception_fn) + ds_context.get_replica_context().merge_call(_raise_exception_fn) -# Must be the argument to a get_tower_context().merge_call() call, calls -# dist.call_for_each_tower() with a function that raises an exception. +# Must be the argument to a get_replica_context().merge_call() call, calls +# dist.call_for_each_replica() with a function that raises an exception. def _call_raises_fn(dist): - dist.call_for_each_tower(_raise_exception_fn) + dist.call_for_each_replica(_raise_exception_fn) -# Must be the argument to a distribution.call_for_each_tower() call, -# calls a get_tower_context().merge_call() that calls a -# call_for_each_tower() that raises an exception. +# Must be the argument to a distribution.call_for_each_replica() call, +# calls a get_replica_context().merge_call() that calls a +# call_for_each_replica() that raises an exception. 
def _merge_call_raises_fn(): - distribution_strategy_context.get_tower_context().merge_call(_call_raises_fn) + ds_context.get_replica_context().merge_call(_call_raises_fn) -# Must be the argument to a get_tower_context().merge_call() call, calls -# dist.call_for_each_tower() with a function that calls a -# get_tower_context().merge_call() that raises an exception. +# Must be the argument to a get_replica_context().merge_call() call, calls +# dist.call_for_each_replica() with a function that calls a +# get_replica_context().merge_call() that raises an exception. def _call_merge_raises_fn(dist): - dist.call_for_each_tower(_merge_raises_fn) + dist.call_for_each_replica(_merge_raises_fn) -# Must be the argument to a distribution.call_for_each_tower() call, calls a -# get_tower_context().merge_call() that calls a call_for_each_tower() that -# calls a get_tower_context().merge_call() that raises an exception. +# Must be the argument to a distribution.call_for_each_replica() call, calls a +# get_replica_context().merge_call() that calls a call_for_each_replica() that +# calls a get_replica_context().merge_call() that raises an exception. def _merge_call_merge_raises_fn(): - distribution_strategy_context.get_tower_context().merge_call( - _call_merge_raises_fn) + ds_context.get_replica_context().merge_call(_call_merge_raises_fn) class DistributionTestBase(test.TestCase): @@ -103,7 +106,7 @@ class DistributionTestBase(test.TestCase): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_tower(grad_fn, one, run_concurrently=l.built) + g_v = d.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. 
before_list = [] @@ -113,8 +116,8 @@ class DistributionTestBase(test.TestCase): before_list.append(fetched) # control_dependencies irrelevant but harmless in eager execution with ops.control_dependencies([fetched]): - g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.extended.reduce_to( + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies(d.update( v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -138,7 +141,7 @@ class DistributionTestBase(test.TestCase): config.gpu_options.per_process_gpu_memory_fraction = 0.3 with context.graph_mode(), \ ops.Graph().as_default(), \ - self.test_session(config=config) as sess, \ + self.cached_session(config=config) as sess, \ d.scope(): l = core.Dense(1, use_bias=False) @@ -159,7 +162,7 @@ class DistributionTestBase(test.TestCase): def step(): """Perform one optimization step.""" # Run forward & backward to get gradients, variables list. - g_v = d.call_for_each_tower(grad_fn, one) + g_v = d.call_for_each_replica(grad_fn, args=(one,)) # Update the variables using the gradients and the update() function. 
before_list = [] @@ -168,8 +171,8 @@ class DistributionTestBase(test.TestCase): fetched = d.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): - g = d.reduce( - variable_scope.VariableAggregation.SUM, g, destinations=v) + g = d.extended.reduce_to( + reduce_util.ReduceOp.SUM, g, destinations=v) with ops.control_dependencies(d.update( v, update, g, grouped=False)): after_list.append(d.read_var(v)) @@ -188,47 +191,103 @@ class DistributionTestBase(test.TestCase): # Error should go down self.assertLess(error_after, error_before) - def _test_map_reduce(self, d, in_graph=None): + def _test_replica_id(self, d): with d.scope(): - map_in = [constant_op.constant(i) for i in range(10)] - map_out = d.map(map_in, lambda x, y: x * y, 2) - observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out, - "/device:CPU:0") - expected = 90 # 2 * (0 + 1 + ... + 9) - self.assertEqual(expected, observed.numpy()) - - def _test_device_index(self, d): - with d.scope(): - expected_devices = [False] * len(d.worker_devices) - - def mark_devices_fn(device_id): - self.assertLess(device_id, len(d.worker_devices)) - self.assertFalse(expected_devices[device_id]) - expected_devices[device_id] = True - - d.call_for_each_tower(mark_devices_fn, d.worker_device_index) - self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) - - def _test_tower_id(self, d): - with d.scope(): - expected_devices = [False] * len(d.worker_devices) + expected_devices = [False] * len(d.extended.worker_devices) def mark_devices_fn(): - tower_id = distribution_strategy_context.get_tower_context().tower_id - self.assertLess(tower_id, len(d.worker_devices)) - self.assertFalse(expected_devices[tower_id]) - expected_devices[tower_id] = True + replica_id = self.evaluate( + ds_context.get_replica_context().replica_id_in_sync_group) + self.assertLess(replica_id, len(d.extended.worker_devices)) + self.assertFalse(expected_devices[replica_id]) + expected_devices[replica_id] = True - 
d.call_for_each_tower(mark_devices_fn) - self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) + d.call_for_each_replica(mark_devices_fn) + self.assertAllEqual(expected_devices, + [True] * len(d.extended.worker_devices)) def _test_call_and_merge_exceptions(self, dist): with dist.scope(): with self.assertRaises(_TestException): - dist.call_for_each_tower(_raise_exception_fn) + dist.call_for_each_replica(_raise_exception_fn) with self.assertRaises(_TestException): - dist.call_for_each_tower(_merge_raises_fn) + dist.call_for_each_replica(_merge_raises_fn) with self.assertRaises(_TestException): - dist.call_for_each_tower(_merge_call_raises_fn) + dist.call_for_each_replica(_merge_call_raises_fn) with self.assertRaises(_TestException): - dist.call_for_each_tower(_merge_call_merge_raises_fn) + dist.call_for_each_replica(_merge_call_merge_raises_fn) + + def _input_fn_to_test_input_context(self, + dataset_fn, + expected_num_replicas_in_sync, + expected_num_input_pipelines, + expected_input_pipeline_id): + # Use a list of one element as counter so that it can be captured by the + # `_input_fn`. This counter is incremented by 1 each time an input_fn is + # called. We use this counter to check whether the `input_pipeline_id` + # matches the counter in the in-graph replication. 
+ worker_id_counter = [0] + + def _input_fn(input_context): + """Input fn for testing.""" + self.assertIsNotNone(input_context) + self.assertEqual(expected_num_replicas_in_sync, + input_context.num_replicas_in_sync) + self.assertEqual(expected_num_input_pipelines, + input_context.num_input_pipelines) + if expected_input_pipeline_id is not None: + self.assertEqual(expected_input_pipeline_id, + input_context.input_pipeline_id) + else: + self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id) + worker_id_counter[0] += 1 + + return dataset_fn() + + return _input_fn + + def _test_input_fn_iterator(self, iterator, devices, expected_values, + sess=None): + evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) + evaluate(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + evaluate([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. 
+ evaluate(iterator.initialize()) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) + + def _test_global_step_update(self, strategy): + with strategy.scope(): + global_step = variable_scope.get_variable( + "global_step", + shape=[], + dtype=dtypes.int64, + initializer=init_ops.zeros_initializer(), + trainable=False, + aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA) + self.evaluate(variables.global_variables_initializer()) + + def model_fn(): + train_op = global_step.assign_add(1) + value = global_step.read_value() + return train_op, value + + train_ops, value = strategy.call_for_each_replica(model_fn) + self.evaluate(strategy.group(train_ops)) + global_step_tensors = strategy.unwrap(value) + global_step_values = self.evaluate(global_step_tensors) + self.assertEqual((1,) * len(global_step_tensors), global_step_values) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 1d9e299b38409b874610765e54fa0052fafd5f4b..b6f5b492017fc7dfd329e69ad9ca418ae682bc4b 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -21,30 +21,34 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib -from tensorflow.contrib.distribute.python import one_device_strategy -from tensorflow.contrib.distribute.python import values +import copy +import functools + from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib from tensorflow.contrib.tpu.python.tpu import training_loop +from 
tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session as session_lib +from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest -_TPU_INITIALIZE_SYSTEM_COLLECTION = "TPU_STRATEGY_INITIALIZE" - - def get_tpu_system_metadata(tpu_cluster_resolver): """Retrieves TPU system metadata given a TPUClusterResolver.""" master = tpu_cluster_resolver.master() @@ -75,13 +79,13 @@ def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args, # synchronization settings? # Get aggregation value - # TODO(jhseu): Support aggregation in a tower context. + # TODO(jhseu): Support aggregation in a replica context. 
aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE) if aggregation not in [ vs.VariableAggregation.NONE, vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN, - vs.VariableAggregation.ONLY_FIRST_TOWER, + vs.VariableAggregation.ONLY_FIRST_REPLICA, ]: raise ValueError("Invalid variable aggregation mode: {} for variable: {}" .format(aggregation, kwargs["name"])) @@ -112,9 +116,8 @@ def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args, return result -# TODO(jhseu): Stop inheriting from OneDeviceStrategy. -class TPUStrategy(one_device_strategy.OneDeviceStrategy): - """Experimental TPU distribution strategy implementation.""" +class TPUStrategy(distribute_lib.DistributionStrategy): + """TPU distribution strategy implementation.""" def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None): """Initializes the TPUStrategy object. @@ -130,10 +133,24 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): num_cores: Number of cores to use on the TPU. If None specified, then auto-detect the cores and topology of the TPU system. """ - # TODO(sourabhbajaj): OneDeviceStrategy should be initialized with the - # master node fetched from the cluster resolver. - super(TPUStrategy, self).__init__("/device:CPU:0") + super(TPUStrategy, self).__init__(TPUExtended( + self, tpu_cluster_resolver, steps_per_run, num_cores)) + + @property + def steps_per_run(self): + """DEPRECATED: use .extended.steps_per_run instead.""" + return self._extended.steps_per_run + + +class TPUExtended(distribute_lib.DistributionStrategyExtended): + """Implementation of TPUStrategy.""" + # Track what TPU devices have been initialized. 
+ _initialized_devices = [] + + def __init__(self, container_strategy, tpu_cluster_resolver, steps_per_run, + num_cores=None): + super(TPUExtended, self).__init__(container_strategy) self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) # TODO(sourabhbajaj): Change this from num_cores to metadata_override @@ -143,19 +160,45 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): # parallelism. device_map = {d.name: i for i, d in enumerate(self._tpu_metadata.devices) if "device:TPU:" in d.name} - self._device_index = values.PerDevice(device_map) - self._tpu_devices = sorted(device_map.keys()) - # Only create variables for the number of towers we're running. - self._tpu_devices = self._tpu_devices[:self.num_towers] + self._device_index = values.PerReplica(device_map) + self._host_device = self.get_host_cpu_device(0) + self._tpu_devices = tuple(sorted(device_map.keys())) + # Only create variables for the number of replicas we're running. + self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync] # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. self.steps_per_run = steps_per_run - self._require_static_shapes = True - def _get_enqueue_op_per_host(self, host_id, iterator, input_shapes, - iterations): + # Initialize the TPU devices. + self._initialize_tpu() + + def _initialize_tpu(self): + """Initialize the TPU devices in a separate session and graph. + + We keep track of all the TPU devices that we're initialized as we should + only be running TPU initialize once for the entire process. + """ + master = self._tpu_cluster_resolver.master() + # Verify TPU has not already been initialized in this process. + if master in TPUExtended._initialized_devices: + logging.info("TPU master %s has already been initialized." 
% master) + return + + logging.info("Initializing the TPU system.") + session_config = config_pb2.ConfigProto(allow_soft_placement=True) + self._configure(session_config) + with ops.Graph().as_default(): + with session_lib.Session(config=session_config, target=master) as sess: + sess.run([tpu.initialize_system()]) + logging.info("Finized initializing TPU system.") + + # Update Strategy state to make sure we can track device initialization. + TPUExtended._initialized_devices.append(master) + + def _get_enqueue_op_per_host(self, host_id, multi_worker_iterator, + input_shapes, iterations): """Create an enqueue op for a single host identified using host_id. The while_loop op returned will run `iterations` times and in each run @@ -163,7 +206,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): Args: host_id: integer, id of the host to run the enqueue ops on. - iterator: `tf.data` iterator to read the input data. + multi_worker_iterator: MultiWorkerDataIterator to read the input data. input_shapes: shape of inputs to be enqueue on the queue. This is same as the value of `nest.flatten(iterator.output_shapes)`. iterations: integer, number of iterations to be run; determines the @@ -174,6 +217,10 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): on the infeed queue from the host with id `host_id` for each device shard. """ host = self.get_host_cpu_device(host_id) + # TODO(sourabhbajaj): Possibly make changes to MultiWorkerDataset + # to work with TPU Prefetch so clean up this code. + iterator = ( + multi_worker_iterator.get_iterator(self.get_host(host_id))._iterator) # pylint: disable=protected-access def _infeed_enqueue_ops_fn(): """Enqueue ops for one iteration.""" @@ -182,7 +229,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): enqueue_ops = [] with ops.device(host): - for _ in range(self.num_towers_per_host): + for _ in range(self.num_replicas_per_host): # Use control dependencies to ensure a deterministic ordering. 
with ops.control_dependencies(control_deps): inputs = nest.flatten(iterator.get_next()) @@ -211,44 +258,59 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): return enqueue_op_per_host - def distribute_dataset(self, dataset_fn): - # TODO(priyag): Perhaps distribute across cores here. - return self._call_dataset_fn(dataset_fn) + def _make_dataset_iterator(self, dataset): + """Make iterators for each of the TPU hosts.""" + + worker_devices = [ + (self.get_host(hid), [self.get_host_cpu_device(hid)]) + for hid in range(self.num_hosts) + ] + return values.DatasetIterator(dataset, worker_devices, + self._num_replicas_in_sync) + + def _distribute_dataset(self, dataset_fn): + worker_devices = [ + (self.get_host(hid), [self.get_host_cpu_device(hid)]) + for hid in range(self.num_hosts) + ] + return values.MultiWorkerDataset( + functools.partial(self._call_dataset_fn, dataset_fn), worker_devices) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have # a mechanism to infer the outputs of `fn`. Pending b/110550782. - def _run_steps_on_dataset(self, fn, iterator, iterations, - initial_loop_values=None): - - shapes = nest.flatten(iterator.output_shapes) - if any([not s.is_fully_defined() for s in shapes]): + def _experimental_run_steps_on_iterator( + self, fn, multi_worker_iterator, iterations, initial_loop_values=None): + output_shapes = multi_worker_iterator.output_shapes + shapes = nest.flatten(output_shapes) + if any(not s.is_fully_defined() for s in shapes): raise ValueError( - 'TPU currently requires fully defined shapes. Either use ' - 'set_shape() on the input tensors or use ' - 'dataset.batch(..., drop_remainder=True).') - types = nest.flatten(iterator.output_types) + "TPU currently requires fully defined shapes. 
Either use " + "set_shape() on the input tensors or use " + "dataset.batch(..., drop_remainder=True).") + types = nest.flatten(multi_worker_iterator.output_types) enqueue_ops = [ - self._get_enqueue_op_per_host(host_id, iterator, shapes, iterations) + self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes, + iterations) for host_id in range(self.num_hosts)] def dequeue_fn(): dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes) - return nest.pack_sequence_as(iterator.output_shapes, dequeued) + return nest.pack_sequence_as(output_shapes, dequeued) # Wrap `fn` for repeat. if initial_loop_values is None: initial_loop_values = {} initial_loop_values = nest.flatten(initial_loop_values) ctx = values.MultiStepContext() - def run_fn(*args, **kwargs): + + def run_fn(): """Single step on the TPU device.""" - del args, kwargs fn_inputs = dequeue_fn() if not isinstance(fn_inputs, tuple): fn_inputs = (fn_inputs,) - fn_result = fn(ctx, *fn_inputs) + fn_result = fn(ctx, fn_inputs) flat_last_step_outputs = nest.flatten(ctx.last_step_outputs) if flat_last_step_outputs: with ops.control_dependencies([fn_result]): @@ -256,11 +318,6 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): else: return fn_result - # TODO(sourabhbajaj): The input to while loop should be based on the output - # type of the step_fn - def iterate_on_tpu(): - return training_loop.repeat(iterations, run_fn, initial_loop_values) - # We capture the control_flow_context at this point, before we run `fn` # inside a while_loop and TPU replicate context. 
This is useful in cases # where we might need to exit these contexts and get back to the outer @@ -270,74 +327,98 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): self._outer_control_flow_context = ( ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access - replicate_inputs = [[]] * self.num_towers - replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs) + def rewrite_fn(*args): + """The rewritten step fn running on TPU.""" + del args + replicate_inputs = [[]] * self._num_replicas_in_sync + replicate_outputs = tpu.replicate(run_fn, replicate_inputs) + + # If run_fn has tensor outputs, tpu.replicate returns a list of list. We + # will flatten it in this case. If run_fn has no tensor outputs, + # tpu.replicate returns a list of no_ops, we will keep the output as it + # is. + if isinstance(replicate_outputs[0], list): + replicate_outputs = nest.flatten(replicate_outputs) + + return replicate_outputs + + # TODO(sourabhbajaj): The input to while loop should be based on the output + # type of the step_fn + assert isinstance(initial_loop_values, list) + initial_loop_values = initial_loop_values * self._num_replicas_in_sync + + # Put the while loop op on host 0. + with ops.device(self.get_host_cpu_device(0)): + replicate_outputs = training_loop.repeat(iterations, rewrite_fn, + initial_loop_values) + del self._outer_control_flow_context ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops) - # Filter out any ops from the outputs, typically this would be the case - # when there were no tensor outputs. 
- last_step_tensor_outputs = [x for x in replicate_outputs - if not isinstance(x, ops.Operation)] - - # Outputs are currently of the structure (grouped by device) - # [[output0_device0, output1_device0, output2_device0], - # [output0_device1, output1_device1, output2_device1]] - # Convert this to the following structure instead: (grouped by output) - # [[output0_device0, output0_device1], - # [output1_device0, output1_device1], - # [output2_device0, output2_device1]] - last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)] + if isinstance(replicate_outputs, list): + # Filter out any ops from the outputs, typically this would be the case + # when there were no tensor outputs. + last_step_tensor_outputs = [ + x for x in replicate_outputs if not isinstance(x, ops.Operation) + ] + + # Outputs are currently of the structure (flattened) + # [output0_device0, output1_device0, output2_device0, + # output0_device1, output1_device1, output2_device1, + # ...] + # Convert this to the following structure instead: (grouped by output) + # [[output0_device0, output0_device1], + # [output1_device0, output1_device1], + # [output2_device0, output2_device1]] + output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync + last_step_tensor_outputs = [ + last_step_tensor_outputs[i::output_num] for i in range(output_num) + ] + else: + # no tensors returned. + last_step_tensor_outputs = [] # Convert replicate_outputs to the original dict structure of # last_step_outputs. 
last_step_tensor_outputs_dict = nest.pack_sequence_as( ctx.last_step_outputs, last_step_tensor_outputs) - for (name, aggregation) in ctx._last_step_outputs_aggregations.items(): # pylint: disable=protected-access + for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access output = last_step_tensor_outputs_dict[name] - # For outputs that have already been aggregated, take the first value + # For outputs that have already been reduced, take the first value # from the list as each value should be the same. Else return the full # list of values. - # TODO(josh11b): If aggregation is NONE, we should return a PerDevice value. - if aggregation is not variables_lib.VariableAggregation.NONE: + # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica + # value. + if reduce_op is not None: # TODO(priyag): Should this return the element or a list with 1 element last_step_tensor_outputs_dict[name] = output[0] ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access return ctx - def _call_for_each_tower(self, fn, *args, **kwargs): - # TODO(jhseu): Consider making it so call_for_each_tower implies that we're - # in a tpu.rewrite(), and update TPUMirroredVariable accordingly. - kwargs.pop('run_concurrently', None) - with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access + def _call_for_each_replica(self, fn, args, kwargs): + # TODO(jhseu): Consider making it so call_for_each_replica implies that + # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly. + with _TPUReplicaContext(self._container_strategy()): return fn(*args, **kwargs) - def initialize(self): + def _initialize(self): if context.executing_eagerly(): # TODO(priyag): Add appopriate call here when eager is supported for TPUs. 
- raise NotImplementedError('Eager mode not supported in TPUStrategy.') + raise NotImplementedError("Eager mode not supported in TPUStrategy.") else: - # TODO(jhseu): We need this hack because DistributionStrategies must be - # pickleable for copy.deepcopy(). Remove when initialize_system goes away. - graph = ops.get_default_graph() - tpu_init = graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION) - if tpu_init: - return tpu_init - graph.add_to_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION, - tpu.initialize_system()) - return graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION) - - def finalize(self): + return [] + + def _finalize(self): if context.executing_eagerly(): # TODO(priyag): Add appopriate call here when eager is supported for TPUs. - raise NotImplementedError('Eager mode not supported in TPUStrategy.') + raise NotImplementedError("Eager mode not supported in TPUStrategy.") else: - return [tpu.shutdown_system()] + return [] def _get_devices_from(self, colocate_with=None): - # TODO(jhseu): Change this when we support model parallelism. + # TODO(jhseu): Change this when we support model parallelism. return self._tpu_devices def _create_variable(self, next_creator, *args, **kwargs): @@ -352,7 +433,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): if i > 0: # Give replicas meaningful distinct names: var0name = index[devices[0]].name.split(":")[0] - # We append a / to variable names created on towers with id > 0 to + # We append a / to variable names created on replicas with id > 0 to # ensure that we ignore the name scope and instead use the given # name as the absolute name of the variable. 
kwargs["name"] = "%s/replica_%d/" % (var0name, i) @@ -374,12 +455,12 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args, **kwargs) - def _reduce(self, aggregation, value, destinations): + def _reduce_to(self, reduce_op, value, destinations): if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access - if aggregation == vs.VariableAggregation.MEAN: + if reduce_op == reduce_util.ReduceOp.MEAN: # TODO(jhseu): Revisit once we support model-parallelism. - value *= (1. / self.num_towers) - elif aggregation != vs.VariableAggregation.SUM: + value *= (1. / self._num_replicas_in_sync) + elif reduce_op != reduce_util.ReduceOp.SUM: raise NotImplementedError( "Currently only support sum & mean in TPUStrategy.") return tpu_ops.cross_replica_sum(value) @@ -387,27 +468,22 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): # Validate that the destination is same as the host device # Note we don't do this when in replicate context as the reduction is # performed on the TPU device itself. - devices = cross_tower_ops_lib.get_devices_from(destinations) + devices = cross_device_ops_lib.get_devices_from(destinations) if len(devices) == 1: assert device_util.canonicalize(devices[0]) == device_util.canonicalize( - self.get_host_cpu_device(0)) + self._host_device) else: - raise ValueError('Multiple devices are not supported for TPUStrategy') + raise ValueError("Multiple devices are not supported for TPUStrategy") - if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER: - return value[0] output = math_ops.add_n(value) - if aggregation == vs.VariableAggregation.MEAN: + if reduce_op == reduce_util.ReduceOp.MEAN: return output * (1. 
/ len(value)) return output - def _update(self, var, options, fn, *args, **kwargs): + def _update(self, var, fn, args, kwargs, group): assert isinstance(var, values.TPUMirroredVariable) - should_group = options.pop("grouped") - assert not options # Validate that we are processing all of the options. - if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access - if should_group: + if group: return fn(var, *args, **kwargs) else: return [fn(var, *args, **kwargs)] @@ -422,9 +498,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): updates[d] = fn(v, *values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) - return values.update_regroup(self, updates, should_group) - - # TODO(josh11b): Need to implement _update_non_slot()! + return values.update_regroup(self, updates, group) def read_var(self, var): assert isinstance(var, values.TPUMirroredVariable) @@ -433,33 +507,39 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): def _unwrap(self, val): if isinstance(val, values.DistributedValues): # Return in a deterministic order. - return [val.get(device=d) for d in sorted(val.devices)] + return tuple(val.get(device=d) for d in sorted(val.devices)) elif isinstance(val, list): # TODO(josh11b): We need to remove this case; per device values should - # be represented using a PerDevice wrapper instead of a list with + # be represented using a PerReplica wrapper instead of a list with # one entry per device. 
- return val - return [val] + return tuple(val) + return (val,) + def value_container(self, value): + return value - @property - def num_towers(self): - return self._num_cores_override or self._tpu_metadata.num_cores + def _broadcast_to(self, tensor, destinations): + del destinations + return tensor @property def num_hosts(self): return self._tpu_metadata.num_hosts @property - def num_towers_per_host(self): + def num_replicas_per_host(self): return self._tpu_metadata.num_of_cores_per_host @property - def between_graph(self): + def _num_replicas_in_sync(self): + return self._num_cores_override or self._tpu_metadata.num_cores + + @property + def experimental_between_graph(self): return False @property - def should_init(self): + def experimental_should_init(self): return True @property @@ -478,20 +558,65 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): def parameter_devices(self): return self._tpu_devices + def non_slot_devices(self, var_list): + return self._host_device + + def _update_non_slot(self, colocate_with, fn, args, kwargs, group): + del colocate_with + with ops.device(self._host_device), distribute_lib.UpdateContext( + self._host_device): + result = fn(*args, **kwargs) + if group: + return result + else: + return nest.map_structure(self._unwrap, result) + + def get_host(self, host_id): + if self._tpu_cluster_resolver.get_master() in ("", "local"): + return "/replica:0/task:0" + job_name = self._tpu_cluster_resolver.get_job_name() or "tpu_worker" + return "/job:%s/task:%d" % (job_name, host_id) + def get_host_cpu_device(self, host_id): - if self._tpu_cluster_resolver.get_master() in ('', 'local'): - return '/replica:0/task:0/device:CPU:0' - job_name = self._tpu_cluster_resolver.get_job_name() or 'tpu_worker' - return '/job:%s/task:%d/device:CPU:0' % (job_name, host_id) - - def configure(self, - session_config=None, - cluster_spec=None, - task_type=None, - task_id=None): + return self.get_host(host_id) + "/device:CPU:0" + + def _configure(self, + 
session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): del cluster_spec, task_type, task_id if session_config: - session_config.isolate_session_state = True - cluster_spec = self._tpu_cluster_resolver.cluster_spec() - if cluster_spec: - session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + session_config.CopyFrom(self._update_config_proto(session_config)) + + def _update_config_proto(self, config_proto): + updated_config = copy.deepcopy(config_proto) + updated_config.isolate_session_state = True + cluster_spec = self._tpu_cluster_resolver.cluster_spec() + if cluster_spec: + updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + return updated_config + + # TODO(priyag): Delete this once all strategies use global batch size. + @property + def _global_batch_size(self): + return True + + +class _TPUReplicaContext(distribute_lib.ReplicaContext): + """Replication Context class for TPU Strategy.""" + + # TODO(sourabhbajaj): Call for each tower should be updating this. + def __init__(self, distribution_strategy): + distribute_lib.ReplicaContext.__init__( + self, + distribution_strategy, + # TODO(b/118385803): properly initialize replica_id, instead of always 0 + replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)) + + @property + def devices(self): + distribute_lib.require_replica_context(self) + ds = self._distribution_strategy + replica_id = tensor_util.constant_value(self._replica_id_in_sync_group) + return (ds.extended.worker_devices[replica_id],) diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py deleted file mode 100644 index 18ceba42c2a57917de1de315973cd111d9a022cf..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distribute/python/values.py +++ /dev/null @@ -1,1614 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Various classes representing distributed values. - -See go/tf-distribution-strategy. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import contextlib -import weakref -import six - -from tensorflow.contrib.distribute.python import input_ops -from tensorflow.contrib.distribute.python import prefetching_ops_v2 -from tensorflow.python.eager import context -from tensorflow.python.eager import tape -from tensorflow.python.framework import device as tf_device -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gen_resource_variable_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.training import device_util -from tensorflow.python.training import distribute as distribute_lib -from tensorflow.python.training import distribution_strategy_context -from tensorflow.python.training import saver -from tensorflow.python.training.checkpointable import base as checkpointable -from tensorflow.python.util import nest - - -# pylint: 
disable=line-too-long -# TODO(josh11b): Should device values be strings or DeviceSpec objects? -# Not sure DeviceSpec objects are usable as a dict key. -class DistributedValues(object): - """Holds a map from device to values. Either PerDevice or Mirrored.""" - - def __init__(self, index): - self._index = {device_util.canonicalize(key): value - for key, value in six.iteritems(index)} - - def get(self, device=None): - """Returns the value for the current device or raises a ValueError.""" - if device is None: - tower_context = distribution_strategy_context.get_tower_context() - if tower_context: - device = tower_context.device - else: - device = distribute_lib.get_update_device() - if device is None: - return self._get_cross_tower() - device = device_util.canonicalize(device) - try: - return self._index[device] - except KeyError as e: - six.raise_from( - ValueError("Device %s not found in %s (current device %s)" % - (device, self._index.keys(), device_util.current())), e) - - def on_device(self, device): - device = device_util.canonicalize(device) - return device in self._index - - @property - def devices(self): - return list(self._index.keys()) - - @property - def is_tensor_like(self): - for v in self._index.values(): - if not tensor_util.is_tensor(v): - return False - return True - - def __str__(self): - return "%s:%s" % (self.__class__.__name__, self._index) - - def __repr__(self): - return "%s(%r)" % (self.__class__.__name__, self._index) - - # TODO(josh11b): Possibly make an accessor for _index for use by - # DistributionStrategy implementations. 
- - -class DistributedDelegate(DistributedValues): - """A map from device to values; acts as the same type as the values.""" - - def __init__(self, index): - super(DistributedDelegate, self).__init__(index) - - def __getattr__(self, name): - return getattr(self.get(), name) - - # pylint: disable=multiple-statements - def __add__(self, o): return self.get() + o - def __radd__(self, o): return o + self.get() - def __sub__(self, o): return self.get() - o - def __rsub__(self, o): return o - self.get() - def __mul__(self, o): return self.get() * o - def __rmul__(self, o): return o * self.get() - def __truediv__(self, o): return self.get() / o - def __rtruediv__(self, o): return o / self.get() - def __floordiv__(self, o): return self.get() // o - def __rfloordiv__(self, o): return o // self.get() - def __mod__(self, o): return self.get() % o - def __rmod__(self, o): return o % self.get() - def __lt__(self, o): return self.get() < o - def __le__(self, o): return self.get() <= o - def __gt__(self, o): return self.get() > o - def __ge__(self, o): return self.get() >= o - def __and__(self, o): return self.get() & o - def __rand__(self, o): return o & self.get() - def __or__(self, o): return self.get() | o - def __ror__(self, o): return o | self.get() - def __xor__(self, o): return self.get() ^ o - def __rxor__(self, o): return o ^ self.get() - def __getitem__(self, o): return self.get()[o] - def __pow__(self, o, modulo=None): return pow(self.get(), o, modulo) - def __rpow__(self, o): return pow(o, self.get()) - def __invert__(self): return ~self.get() - def __neg__(self): return -self.get() - def __abs__(self): return abs(self.get()) - - def __div__(self, o): - try: - return self.get().__div__(o) - except AttributeError: - # See https://docs.python.org/3/library/constants.html#NotImplemented - return NotImplemented - - def __rdiv__(self, o): - try: - return self.get().__rdiv__(o) - except AttributeError: - # See https://docs.python.org/3/library/constants.html#NotImplemented 
- return NotImplemented - - def __matmul__(self, o): - try: - return self.get().__matmul__(o) - except AttributeError: - # See https://docs.python.org/3/library/constants.html#NotImplemented - return NotImplemented - - def __rmatmul__(self, o): - try: - return self.get().__rmatmul__(o) - except AttributeError: - # See https://docs.python.org/3/library/constants.html#NotImplemented - return NotImplemented - - # TODO(josh11b): Even more operator overloads. - - -class PerDevice(DistributedValues): - """Holds a map from device to unsynchronized values.""" - pass - - -# Note that unlike PerDevice, Mirrored values inherit from -# DistributedDelegate and so can be used directly in cross-tower mode. -class Mirrored(DistributedDelegate): - """Holds a map from device to values which are kept in sync.""" - - def _get_cross_tower(self): - device = device_util.canonicalize(device_util.current()) - if device in self._index: - return self._index[device] - return list(self._index.values())[0] - - def _as_graph_element(self): - obj = self.get() - # pylint: disable=protected-access - conv_fn = getattr(obj, "_as_graph_element", None) - if conv_fn and callable(conv_fn): - return conv_fn() - return obj - - -def _assign_on_device(device, variable, tensor): - with ops.device(device): - return variable.assign(array_ops.identity(tensor)) - - -DistributedVarOp = collections.namedtuple( - "DistributedVarOp", ["name", "graph", "type"]) - - -class DistributedVariable(DistributedDelegate): - """Holds a map from device to variables.""" - # TODO(josh11b): Support changing the set of variables if e.g. if new - # devices are joining or a device is to leave. - - def __init__(self, index): - # Child class must set self._primary_var before calling - # super(...).__init__(index). - self._common_name = self._primary_var.name.split(":")[0] - # Use a weakref to make it easy to map from the contained values - # to the container without introducing a reference cycle. 
- for v in six.itervalues(index): - v._distributed_container = weakref.ref(self) # pylint: disable=protected-access - # tf.keras keeps track of variables initialized using this attribute. When - # tf.keras gets the default session, it initializes all uninitialized vars. - # We need to make _keras_initialized a member of DistributedVariable because - # without this it will use `__getattr__` which will delegate to a component - # variable. - self._keras_initialized = False - # Typically, a `DistributedVariable`'s initializer is composed of the - # initializers of the components variables. However, in some cases, such as - # when restoring from a checkpoint, we may set the _initializer_op - # property on the entire `DistributedVariable`. - self._initializer_op = None - super(DistributedVariable, self).__init__(index) - - def is_initialized(self, name=None): - """Identifies if all the component variables are initialized. - - Args: - name: Name of the final `logical_and` op. - - Returns: - The op that evaluates to True or False depending on if all the - component variables are initialized. - """ - # We have to cast the self._index.values() to a `list` because when we - # use `model_to_estimator` to run tf.keras models, self._index.values() is - # of type `dict_values` and not `list`. - values_list = list(self._index.values()) - result = values_list[0].is_initialized() - # We iterate through the list of values except the last one to allow us to - # name the final `logical_and` op the same name that is passed by the user - # to the `is_initialized` op. For distributed variables, the - # `is_initialized` op is a `logical_and` op. 
- for v in values_list[1:-1]: - result = math_ops.logical_and(result, v.is_initialized()) - result = math_ops.logical_and(result, values_list[-1].is_initialized(), - name=name) - return result - - @property - def initializer(self): - if self._initializer_op: - init_op = self._initializer_op - else: - # return grouped ops of all the var initializations of component values of - # the mirrored variable - init_op = control_flow_ops.group( - [v.initializer for v in self._index.values()]) - return init_op - - @property - def graph(self): - return self._primary_var.graph - - @property - def _shared_name(self): - return self._common_name - - @property - def _unique_id(self): - return self._primary_var._unique_id # pylint: disable=protected-access - - @property - def name(self): - return self._primary_var.name - - @property - def dtype(self): - return self._primary_var.dtype - - @property - def shape(self): - return self._primary_var.shape - - def get_shape(self): - return self._primary_var.get_shape() - - def to_proto(self, export_scope=None): - return self._primary_var.to_proto(export_scope=export_scope) - - @property - def op(self): - # We want cross-tower code that does some var.op.X calls - # to work (even if the current device isn't in self.devices), but - # other uses of var.op in a cross-tower context to fail. 
- if distribution_strategy_context.get_cross_tower_context(): - return DistributedVarOp(self._primary_var.op.name, - self._primary_var.op.graph, - self._primary_var.op.type) - return self.get().op - - @property - def _in_graph_mode(self): - return self._primary_var._in_graph_mode # pylint: disable=protected-access - - def read_value(self): - return distribution_strategy_context.get_distribution_strategy().read_var( - self) - - def _should_act_as_resource_variable(self): - """Pass resource_variable_ops.is_resource_variable check.""" - pass - - -ops.register_dense_tensor_like_type(DistributedVariable) - - -class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable): - """Class for defining how to restore a MirroredVariable.""" - - def __init__(self, mirrored_variable, primary_variable, name): - self._mirrored_variable = mirrored_variable - super(_MirroredSaveable, self).__init__(primary_variable, "", name) - - def restore(self, restored_tensors, restored_shapes): - """Restore the same value into all variables.""" - tensor, = restored_tensors - return control_flow_ops.group([ - _assign_on_device(d, v, tensor) - for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access - - -class MirroredVariable(DistributedVariable, Mirrored, - checkpointable.CheckpointableBase): - """Holds a map from device to variables whose values are kept in sync.""" - - def __init__(self, index, primary_var, aggregation): - self._primary_var = primary_var - self._aggregation = aggregation - super(MirroredVariable, self).__init__(index) - - # The arguments to update() are automatically unwrapped so the update() - # function would normally see regular variables, not MirroredVariables. - # However, the update function can still operate on wrapped MirroredVariables - # through object members, captured arguments, etc. 
This is more likely in an - # update_non_slot() function (like OptimizerV2._finish), which can - # update several non-slot variables in one call. - def _assign_func(self, *args, **kwargs): - f = kwargs.pop("f") - if distribution_strategy_context.get_cross_tower_context(): - update_device = distribute_lib.get_update_device() - if update_device is not None: - # We are calling an assign function on the mirrored variable in an - # update context. - v = self.get(device=update_device) - return f(v, *args, **kwargs) - - # We are calling assign on the mirrored variable in cross tower context, - # use update to update the variable. - strategy = distribution_strategy_context.get_distribution_strategy() - return strategy.update(self, f, *args, **kwargs) - else: - _assert_tower_context() - # We are calling an assign function on the mirrored variable in tower - # context. - # We reduce the value we want to assign/add/sub. More details about how we - # handle the different use cases can be found in the _reduce method. - # We call the function on each of the mirrored variables with the reduced - # value. 
- if self._aggregation == vs.VariableAggregation.NONE: - raise ValueError("You must specify an aggregation method to update a " - "MirroredVariable in Tower Context.") - - def merge_fn(strategy, value, *other_args, **other_kwargs): - return strategy.update( - self, f, - strategy.reduce( - aggregation=self._aggregation, value=value, destinations=self), - *other_args, **other_kwargs) - - return distribution_strategy_context.get_tower_context().merge_call( - merge_fn, *args, **kwargs) - - def assign_sub(self, *args, **kwargs): - assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) - return self._assign_func(f=assign_sub_fn, *args, **kwargs) - - def assign_add(self, *args, **kwargs): - assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw) - return self._assign_func(f=assign_add_fn, *args, **kwargs) - - def assign(self, *args, **kwargs): - assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) - return self._assign_func(f=assign_fn, *args, **kwargs) - - @property - def aggregation(self): - return self._aggregation - - def _get_cross_tower(self): - device = device_util.canonicalize(device_util.current()) - if device in self._index: - return array_ops.identity(self._index[device]) - return array_ops.identity(self._primary_var) - - def _as_graph_element(self): - # pylint: disable=protected-access - if distribution_strategy_context.get_cross_tower_context(): - return self._primary_var._as_graph_element() - return self.get()._as_graph_element() - - def _gather_saveables_for_checkpoint(self): - """Overrides CheckpointableBase method. - - This allows both name-based and object-based save and restore of - MirroredVariables. - - Returns: - A dictionary mapping attribute names to `SaveableObject` factories. 
- """ - def _saveable_factory(name=self._common_name): - return _MirroredSaveable(self, self._primary_var, name) - return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} - - -# Register a conversion function which reads the value of the variable, -# allowing instances of the class to be used as tensors. -def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False): - # Try to avoid assignments to and other mutations of MirroredVariable - # state except through a DistributionStrategy.update() call. - assert not as_ref - return ops.internal_convert_to_tensor( - var.get(), dtype=dtype, name=name, as_ref=as_ref) - - -ops.register_tensor_conversion_function(MirroredVariable, - _tensor_conversion_mirrored) - - -def _enclosing_tpu_context(): - # pylint: disable=protected-access - tpu_context = ops.get_default_graph()._get_control_flow_context() - # pylint: enable=protected-access - while tpu_context is not None and not isinstance( - tpu_context, control_flow_ops.XLAControlFlowContext): - tpu_context = tpu_context.outer_context - return tpu_context - - -# TODO(jhseu): Deduplicate code. We copy code because we don't want to -# inherit from DistributedDelegate. DistributedDelegate will not work in a -# tpu.replicate() because it assumes that you're in a device context where you -# can operate on a single version of the variable, but a tpu.replicate() -# operates on all variables and is replicated during a rewrite pass. -class TPUMirroredVariable(checkpointable.CheckpointableBase): - """Holds a map from device to TPU variables whose values are kept in sync.""" - - def __init__(self, index, primary_var, aggregation): - # Use a weakref to make it easy to map from the contained values - # to the container without introducing a reference cycle. 
- for v in six.itervalues(index): - v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access - self._index = {device_util.canonicalize(key): value - for key, value in six.iteritems(index)} - self._primary_var = primary_var - self._common_name = self._primary_var.name.split(":")[0] - self._aggregation = aggregation - # Needed for GradientTape - self._trainable = self._primary_var.trainable - - def _get(self, device=None): - """Returns the value for the current device or raises a ValueError.""" - if device is None: - tower_context = distribution_strategy_context.get_tower_context() - if tower_context: - device = tower_context.device - else: - device = distribute_lib.get_update_device() - if device is None: - return self._get_cross_tower() - device = device_util.canonicalize(device) - try: - return self._index[device] - except KeyError as e: - six.raise_from( - ValueError("Device %s not found in %s (current device %s)" % - (device, self._index.keys(), device_util.current())), e) - - # pylint: disable=multiple-statements - def __add__(self, o): return self.read_value() + o - def __radd__(self, o): return o + self.read_value() - def __sub__(self, o): return self.read_value() - o - def __rsub__(self, o): return o - self.read_value() - def __mul__(self, o): return self.read_value() * o - def __rmul__(self, o): return o * self.read_value() - def __truediv__(self, o): return self.read_value() / o - def __rtruediv__(self, o): return o / self.read_value() - def __floordiv__(self, o): return self.read_value() // o - def __rfloordiv__(self, o): return o // self.read_value() - def __mod__(self, o): return self.read_value() % o - def __rmod__(self, o): return o % self.read_value() - def __lt__(self, o): return self.read_value() < o - def __le__(self, o): return self.read_value() <= o - def __gt__(self, o): return self.read_value() > o - def __ge__(self, o): return self.read_value() >= o - def __and__(self, o): return self.read_value() & o - def __rand__(self, 
o): return o & self.read_value() - def __or__(self, o): return self.read_value() | o - def __ror__(self, o): return o | self.read_value() - def __xor__(self, o): return self.read_value() ^ o - def __rxor__(self, o): return o ^ self.read_value() - def __getitem__(self, o): return self.read_value()[o] - def __pow__(self, o, modulo=None): return pow(self.read_value(), o, modulo) - def __rpow__(self, o): return pow(o, self.read_value()) - def __invert__(self): return ~self.read_value() - def __neg__(self): return -self.read_value() - def __abs__(self): return abs(self.read_value()) - - def __div__(self, o): - try: - return self.read_value().__div__(o) - except AttributeError: - # See https://docs.python.org/3/library/constants.html#NotImplemented - return NotImplemented - - def __rdiv__(self, o): - try: - return self.read_value().__rdiv__(o) - except AttributeError: - # See https://docs.python.org/3/library/constants.html#NotImplemented - return NotImplemented - - def __matmul__(self, o): - try: - return self.read_value().__matmul__(o) - except AttributeError: - # See https://docs.python.org/3/library/constants.html#NotImplemented - return NotImplemented - - def __rmatmul__(self, o): - try: - return self.read_value().__rmatmul__(o) - except AttributeError: - # See https://docs.python.org/3/library/constants.html#NotImplemented - return NotImplemented - - @property - def handle(self): - # If we're in a tpu.rewrite(), return the replicated handle. 
- tpu_context = _enclosing_tpu_context() - if tpu_context is not None: - return tpu_context.get_replicated_var_handle( - self._common_name, nest.flatten(self._index)) - - device = distribute_lib.get_update_device() - if device is None: - return self._primary_var.handle - device = device_util.canonicalize(device) - try: - return self._index[device].handle - except KeyError as e: - six.raise_from( - ValueError("Device %s not found in %s (current device %s)" % - (device, self._index.keys(), device_util.current())), e) - - # The arguments to update() are automatically unwrapped so the update() - # function would normally see regular variables, not MirroredVariables. - # However, the update function can still operate on wrapped MirroredVariables - # through object members, captured arguments, etc. This is more likely in an - # update_non_slot() function (like OptimizerV2._finish), which can - # update several non-slot variables in one call. - def _assign_func(self, *args, **kwargs): - if distribution_strategy_context.get_distribution_strategy().__class__.__name__ != "TPUStrategy": - raise ValueError("You may only assign to a TPUMirroredVariable within a " - "TPUStrategy.") - f = kwargs.pop("f") - if distribution_strategy_context.get_cross_tower_context(): - if _enclosing_tpu_context() is not None: - return distribution_strategy_context.get_distribution_strategy().update( - self, f, *args, **kwargs) - - update_device = distribute_lib.get_update_device() - # We are calling update on the mirrored variable in cross tower context. - if update_device is not None: - # We are calling an assign function on the mirrored variable in cross - # tower context. - v = self._get(device=update_device) - return f(v, *args, **kwargs) - - return distribution_strategy_context.get_distribution_strategy().update( - self, f, *args, **kwargs) - else: - _assert_tower_context() - # We are calling an assign function on the mirrored variable in tower - # context. 
- # We reduce the value we want to assign/add/sub. More details about how we - # handle the different use cases can be found in the _reduce method. - # We call the function on each of the mirrored variables with the reduced - # value. - if self._aggregation == vs.VariableAggregation.NONE: - raise ValueError("You must specify an aggregation method to update a " - "TPUMirroredVariable in Tower Context.") - - def merge_fn(strategy, value, *other_args, **other_kwargs): - return strategy.update( - self, f, - strategy.reduce( - aggregation=self._aggregation, value=value, destinations=self), - *other_args, **other_kwargs) - - return distribution_strategy_context.get_tower_context().merge_call( - merge_fn, *args, **kwargs) - - @contextlib.contextmanager - def _handle_graph(self, handle): - # Note: might have an eager tensor but not be executing eagerly when - # building functions. - if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor) - or ops.has_default_graph()): - yield - else: - with handle.graph.as_default(): - yield - - @property - def trainable(self): - return self._trainable - - def _read_variable_op(self, parent_op=None): - if self.trainable: - tape.variable_accessed(self) - if parent_op is not None: - with ops.control_dependencies([parent_op]): - return gen_resource_variable_ops.read_variable_op( - self.handle, self.dtype) - - return gen_resource_variable_ops.read_variable_op( - self.handle, self.dtype) - - def read_value(self): - return self._read_variable_op() - - def assign_sub(self, *args, **kwargs): - def assign_sub_fn(var, delta, **kw): - name = kw.pop("name", None) - read_value = kw.pop("read_value", True) - with self._handle_graph(var.handle): - op = gen_resource_variable_ops.assign_sub_variable_op( - var.handle, ops.convert_to_tensor(delta, dtype=self.dtype), - name=name) - if read_value: - return self._read_variable_op(parent_op=op) - return op - - return self._assign_func(f=assign_sub_fn, *args, **kwargs) - - def assign_add(self, 
*args, **kwargs): - def assign_add_fn(var, delta, **kw): - name = kw.pop("name", None) - read_value = kw.pop("read_value", True) - with self._handle_graph(var.handle): - op = gen_resource_variable_ops.assign_add_variable_op( - var.handle, ops.convert_to_tensor(delta, dtype=self.dtype), - name=name) - if read_value: - return self._read_variable_op(parent_op=op) - return op - - return self._assign_func(f=assign_add_fn, *args, **kwargs) - - def assign(self, *args, **kwargs): - def assign_fn(var, value, **kw): - name = kw.pop("name", None) - read_value = kw.pop("read_value", True) - with self._handle_graph(var.handle): - op = gen_resource_variable_ops.assign_variable_op( - var.handle, ops.convert_to_tensor(value, dtype=self.dtype), - name=name) - if read_value: - return self._read_variable_op(parent_op=op) - return op - - return self._assign_func(f=assign_fn, *args, **kwargs) - - @property - def aggregation(self): - return self._aggregation - - @property - def constraint(self): - return None - - @property - def initializer(self): - return control_flow_ops.group( - [v.initializer for v in nest.flatten(self._index)]) - - @property - def graph(self): - return self._primary_var.graph - - @property - def _shared_name(self): - return self._common_name - - @property - def _unique_id(self): - return self._primary_var._unique_id # pylint: disable=protected-access - - @property - def name(self): - return self._primary_var.name - - @property - def dtype(self): - return self._primary_var.dtype - - @property - def shape(self): - return self._primary_var.shape - - def get_shape(self): - return self._primary_var.get_shape() - - def to_proto(self, export_scope=None): - return self._primary_var.to_proto(export_scope=export_scope) - - def _get_cross_tower(self): - device = device_util.canonicalize(device_util.current()) - if device in self._index: - return self._index[device] - return self._primary_var - - def _as_graph_element(self): - # pylint: disable=protected-access - if 
distribution_strategy_context.get_cross_tower_context(): - return self._primary_var._as_graph_element() - return self._read_variable_op() - - def _gather_saveables_for_checkpoint(self): - """Overrides CheckpointableBase method. - - This allows both name-based and object-based save and restore of - MirroredVariables. - - Returns: - A dictionary mapping attribute names to `SaveableObject` factories. - """ - def _saveable_factory(name=self._common_name): - return _MirroredSaveable(self, self._primary_var, name) - return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} - - def _should_act_as_resource_variable(self): - """Pass resource_variable_ops.is_resource_variable check.""" - pass - - # Needed to pass ResourceVariable checks. - @property - def op(self): - return self._primary_var.op - - @property - def _in_graph_mode(self): - return self._primary_var._in_graph_mode # pylint: disable=protected-access - - def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): - """Converts a variable to a tensor.""" - # pylint: disable=protected-access - if _enclosing_tpu_context() is None: - return self._get()._dense_var_to_tensor(dtype, name, as_ref) - # pylint: enable=protected-access - if dtype is not None and dtype != self.dtype: - raise NotImplementedError - if as_ref: - return self.handle - else: - return self.read_value() - - def is_initialized(self, name=None): - """Identifies if all the component variables are initialized. - - Args: - name: Name of the final `logical_and` op. - - Returns: - The op that evaluates to True or False depending on if all the - component variables are initialized. - """ - # TODO(jhseu): Do we need TPU context implementation? - - # We have to cast the self._index.values() to a `list` because when we - # use `model_to_estimator` to run tf.keras models, self._index.values() is - # of type `dict_values` and not `list`. 
- values_list = nest.flatten(self._index) - result = values_list[0].is_initialized() - # We iterate through the list of values except the last one to allow us to - # name the final `logical_and` op the same name that is passed by the user - # to the `is_initialized` op. For distributed variables, the - # `is_initialized` op is a `logical_and` op. - for v in values_list[1:-1]: - result = math_ops.logical_and(result, v.is_initialized()) - result = math_ops.logical_and(result, values_list[-1].is_initialized(), - name=name) - return result - - -# Register a conversion function which reads the value of the variable, -# allowing instances of the class to be used as tensors. -def _tensor_conversion_tpu_mirrored(var, dtype=None, name=None, as_ref=False): - return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access - - -ops.register_tensor_conversion_function(TPUMirroredVariable, - _tensor_conversion_tpu_mirrored) -ops.register_dense_tensor_like_type(TPUMirroredVariable) - - -class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): - """Class for defining how to restore a TowerLocalVariable.""" - - def __init__(self, tower_local_variable, name): - self._tower_local_variable = tower_local_variable - # We use a callable so that we don't have to evaluate this expression - # in the case where we are trying to restore instead of save. 
- def tensor(): - return distribution_strategy_context.get_distribution_strategy().read_var( - tower_local_variable) - spec = saver.BaseSaverBuilder.SaveSpec( - tensor=tensor, - slice_spec="", - name=name, - dtype=tower_local_variable.dtype) - super(_TowerLocalSaveable, self).__init__(tensor, [spec], name) - - def restore(self, restored_tensors, restored_shapes): - """Restore the same value into all variables.""" - tensor, = restored_tensors - return self._tower_local_variable.assign(tensor) - - -def _assert_tower_context(): - if not distribution_strategy_context.get_tower_context(): - raise RuntimeError( - "Tower-local variables may only be assigned in a tower context.") - - -class TowerLocalVariable(DistributedVariable, PerDevice, - checkpointable.CheckpointableBase): - """Holds a map from device to variables whose values are reduced on save.""" - - def __init__(self, index, primary_var, aggregation): - self._primary_var = primary_var - self._aggregation = aggregation - super(TowerLocalVariable, self).__init__(index) - - def assign_sub(self, *args, **kwargs): - _assert_tower_context() - return self.get().assign_sub(*args, **kwargs) - - def assign_add(self, *args, **kwargs): - _assert_tower_context() - return self.get().assign_add(*args, **kwargs) - - def assign(self, *args, **kwargs): - if distribution_strategy_context.get_cross_tower_context(): - # To preserve the sum across save and restore, we have to divide the - # total across all devices when restoring a variable that was summed - # when saving. - tensor = args[0] - if self._aggregation == vs.VariableAggregation.SUM: - tensor *= 1. 
/ len(self.devices) - return control_flow_ops.group( - [_assign_on_device(d, v, tensor) - for d, v in six.iteritems(self._index)]) - else: - _assert_tower_context() - return self.get().assign(*args, **kwargs) - - @property - def aggregation(self): - return self._aggregation - - def _get_cross_tower(self): - if self._aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER: - return self._primary_var - all_components = tuple(self._index.values()) - # TODO(josh11b): Use a strategy-specific method. - total = math_ops.add_n(all_components) - if self._aggregation == vs.VariableAggregation.MEAN: - return total * (1./ len(all_components)) - return total - - def _as_graph_element(self): - # pylint: disable=protected-access - if distribution_strategy_context.get_cross_tower_context(): - return self._get_cross_tower() - return self.get()._as_graph_element() - - def _gather_saveables_for_checkpoint(self): - """Overrides CheckpointableBase method. - - This allows both name-based and object-based save and restore of - TowerLocalVariables. - - Returns: - A dictionary mapping attribute names to `SaveableObject` factories. - """ - def _saveable_factory(name=self._common_name): - return _TowerLocalSaveable(self, name) - return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} - - -# Register a conversion function for TowerLocalVariable which allows as_ref to -# be true. 
-def _tensor_conversion_tower_local(var, dtype=None, name=None, as_ref=False): - return ops.internal_convert_to_tensor( - var.get(), dtype=dtype, name=name, as_ref=as_ref) - - -ops.register_tensor_conversion_function(TowerLocalVariable, - _tensor_conversion_tower_local) - - -def _devices_match(d1, d2): - return device_util.canonicalize(d1) == device_util.canonicalize(d2) - - -def regroup(per_device, wrap_class=PerDevice): - """Makes device->nest map into a nest of PerDevice/Mirrored values.""" - items = list(per_device.items()) - assert items - v0 = items[0][1] # First value - - if isinstance(v0, list): - for _, v in items[1:]: - assert isinstance(v, list) - assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" % - (len(v), len(v0), v, v0)) - return [regroup({k: v[i] for k, v in items}, wrap_class) - for i in range(len(v0))] - - if isinstance(v0, tuple): - for _, v in items[1:]: - assert isinstance(v, tuple) - assert len(v) == len(v0) - regrouped_tuple = tuple(regroup({k: v[i] for k, v in items}, wrap_class) - for i in range(len(v0))) - if hasattr(v0, "_fields"): - # This tuple is in fact a namedtuple! Create a new namedtuple instance - # and initialize it with the regrouped values: - assert hasattr(type(v0), "_make") - return type(v0)._make(regrouped_tuple) - else: - return regrouped_tuple - - if isinstance(v0, dict): - v0keys = set(v0.keys()) - for _, v in items[1:]: - assert isinstance(v, dict) - assert set(v.keys()) == v0keys - return {key: regroup({k: v[key] for k, v in items}, wrap_class) - for key in v0keys} - - # If exactly the same object across all devices, return it unwrapped. - same_id = True - for _, v in items[1:]: - if v is not v0: - same_id = False - break - # Consider three cases where same_id is true: - # * If v0 is a DistributedVariable (a MirroredVariable or - # TowerLocalVariable, and same_id means it is the same across all - # devices), we want to return it. 
We check DistributedVariable - # specifically since it can look like it has a - # _distributed_container member since its members do. - # * If v0 is a member of a distributed variable, in which case - # hasattr(v0, "_distributed_container") is true, we want to - # return the DistributedVariable that contains it using the - # _distributed_container logic below. This case can trigger - # same_id when there is only one device. - # * In any other situation, same_id means we return v0. - if same_id and (isinstance(v0, DistributedVariable) or - not hasattr(v0, "_distributed_container")): - return v0 - - # Detect the case where each device has a parallel component of the - # same MirroredVariable (or TowerLocalVariable). In this case we - # want to return the containing MirroredVariable, after a bunch of - # sanity checking. In particular, each component should have the - # same container, and the devices of the variables should match the - # keys of the per-device dictionary. - if hasattr(v0, "_distributed_container"): - # pylint: disable=protected-access - assert not isinstance(v0, MirroredVariable), ( - "ids = %s, items = %s" % ([id(v[1]) for v in items], items)) - assert _devices_match(v0.device, items[0][0]), ( - "v0.device = %s, items = %s" % (v0.device, items)) - distributed_container = v0._distributed_container() - assert distributed_container is not None - for d, v in items[1:]: - assert _devices_match(v.device, d), ( - "v.device = %s, d = %s, items = %s" % (v.device, d, items)) - assert distributed_container is v._distributed_container() - return distributed_container - # pylint: enable=protected-access - - return wrap_class(per_device) - - -def select_device(device, structured): - """Specialize a nest of regular & per-device values for one device.""" - def _get(x): - return x.get(device) if isinstance(x, DistributedValues) else x - - return nest.map_structure(_get, structured) - - -def select_device_mirrored(device, structured): - """Specialize a nest of 
regular & mirrored values for one device.""" - def _get_mirrored(x): - if isinstance(x, DistributedValues): - if not isinstance(x, Mirrored): - raise TypeError( - "Expected value to be mirrored across towers: %s in %s." % - (x, structured)) - return x.get(device) - else: - return x - - return nest.map_structure(_get_mirrored, structured) - - -def update_regroup(strategy, updates, should_group): - """Regroup for an update, with dependencies to ensure all updates execute.""" - regrouped = regroup(updates, Mirrored) - if not should_group: - return nest.map_structure(strategy.unwrap, regrouped) - grouped_flat = [] - for u in nest.flatten(regrouped): - if isinstance(u, DistributedValues): - g = strategy.group(u) - if u.is_tensor_like: - # Make sure we run all updates. Without this, something like - # session.run(strategy.update(...)) may only update one tower. - index = {} - for d in u.devices: - with ops.device(d), ops.control_dependencies([g]): - index[d] = array_ops.identity(u.get(d)) - g = Mirrored(index) - else: - g = u - grouped_flat.append(g) - return nest.pack_sequence_as(regrouped, grouped_flat) - - -class PerDeviceDataIterator(object): - """An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`.""" - - def __init__(self, iterator, devices, prefetch_on_device=None): - self._iterator = iterator - self._devices = devices - self._prefetch_on_device = prefetch_on_device - - @property - def initializer(self): - return self._iterator.initializer - - def get_next(self, name=None): - """Scatter the input across devices.""" - if self._prefetch_on_device: - data_list = self._iterator.get_next(name=name) - index = dict(zip(self._devices, data_list)) - else: - batch = self._iterator.get_next(name=name) - index = {} - def get_ith(i): - return lambda x: x[i] - - for i, d in enumerate(self._devices): - index[d] = nest.map_structure(get_ith(i), batch) - if context.executing_eagerly(): - with ops.device(d): - index[d] = nest.map_structure(array_ops.identity, index[d]) 
- - return regroup(index) - - -class PerDeviceDataset(object): - """Like `tf.data.Dataset` split devices, producing `PerDevice` data.""" - - def __init__(self, dataset, devices, prefetch_on_device=None): - self._devices = devices - - # Default to using prefetching in graph mode, unless specified. - # TODO(priyag): Enable prefetching in eager mode. - self._prefetch_on_device = prefetch_on_device - if self._prefetch_on_device is None: - self._prefetch_on_device = not context.executing_eagerly() - assert not (self._prefetch_on_device and context.executing_eagerly()), ( - "Prefetching is only supported in graph mode currently") - - if self._prefetch_on_device: - self._dataset = dataset.apply( - prefetching_ops_v2.prefetch_to_devices(self._devices)) - else: - # TODO(priyag): If dropping remainder is not appropriate, find another - # approach to distributing the dataset when not possible to divide evenly. - # Possibly not an issue when we start using PartitionedDataset. - self._dataset = dataset.batch(len(devices), drop_remainder=True) - - def make_one_shot_iterator(self): - """Get a one time use iterator for the distributed PerDeviceDataset.""" - dataset_iterator = self._dataset.make_one_shot_iterator() - return PerDeviceDataIterator(dataset_iterator, self._devices, - self._prefetch_on_device) - - def make_initializable_iterator(self): - """Get an initializable iterator for the distributed PerDeviceDataset.""" - dataset_iterator = self._dataset.make_initializable_iterator() - return PerDeviceDataIterator(dataset_iterator, self._devices, - self._prefetch_on_device) - - -class MultiWorkerDataIterator(object): - """An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`.""" - - def __init__(self, iterators, worker_device_map): - """Initialize the MultiWorkerDataIterator object. - - Args: - iterators: a dict mapping from each worker to an iterator for - that worker. 
- worker_device_map: a dict mapping from each worker's devices to a list of - devices that belong to this worker. - - Raises: - ValueError: if iterators and worker_device_map are not compatible. - """ - self._iterators = iterators - self._worker_device_map = worker_device_map - if set(self._iterators) != set(self._worker_device_map): - raise ValueError("iterators and worker_device_map are not compatible.") - - @property - def initializer(self): - return control_flow_ops.group( - [iterator.initializer for iterator in self._iterators.values()]) - - def get_next(self, name=None): - """Scatter the input across hosts and devices.""" - index = {} - for worker, iterator in six.iteritems(self._iterators): - if name is not None: - d = tf_device.DeviceSpec.from_string(worker) - new_name = "%s_%s_%d" % (name, d.job, d.task) - else: - new_name = None - with ops.device(worker): - data_per_worker = iterator.get_next(name=new_name) - - worker_devices = self._worker_device_map[worker] - # Ungroup these per-device value so as to get a flat map from devices to - # values. - for d in worker_devices: - v = select_device(d, data_per_worker) - if d in index: - raise ValueError("Duplicated devices in worker_device_map: %r" % v) - index[d] = v - - return regroup(index) - - -class MultiWorkerDataset(object): - """Like a `tf.data.Dataset` that distributes data to different workers. - - Each worker gets one shard of the input dataset. It is currently not working - in - eager mode. - """ - - def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None, - auto_shard=False): - """Initialize the MultiWorkerDataset object. - - Args: - dataset_fn: a function that returns a `tf.data.Dataset`. - worker_device_map: a dict mapping from each worker to a list of devices - that belong to this worker. - prefetch_on_device: whether to prefetch to devices. - auto_shard: whether to auto-shard the dataset. 
- """ - self._worker_device_map = worker_device_map - self._datasets = {} - # TODO(yuefengz, priyag): support different set of jobs for input - # processing. - for i, (worker, worker_devices) in enumerate( - six.iteritems(worker_device_map)): - with ops.device(worker): - worker_input = dataset_fn() - if auto_shard: - worker_input = input_ops.auto_shard_dataset( - worker_input, len(worker_device_map), i) - self._datasets[worker] = PerDeviceDataset( - worker_input, worker_devices, prefetch_on_device=prefetch_on_device) - - def make_one_shot_iterator(self): - iterators = {} - for worker, dataset in six.iteritems(self._datasets): - with ops.device(worker): - iterators[worker] = dataset.make_one_shot_iterator() - return MultiWorkerDataIterator(iterators, self._worker_device_map) - - def make_initializable_iterator(self): - iterators = {} - for worker, dataset in six.iteritems(self._datasets): - with ops.device(worker): - iterators[worker] = dataset.make_initializable_iterator() - return MultiWorkerDataIterator(iterators, self._worker_device_map) - - -class _PerKey(object): - """Holds data associated by keys.""" - - def __init__(self, *index): - # pylint: disable=protected-access - self._index = list(index) - - def get(self, iteration): - return array_ops.gather(self._index, iteration) - - def get_shape(self): - return self._index[-1][-1].get_shape() - - def get_dtype(self): - return self._index[-1][-1].dtype - - def __str__(self): - return "%s:%s" % (self.__class__.__name__, self._index) - - def __repr__(self): - return "%s(%r)" % (self.__class__.__name__, self._index) - - -class PerIteration(_PerKey): - """Holds input for multiple iterations at once.""" - - def __init__(self, *index): - # pylint: disable=protected-access - super(PerIteration, self).__init__(*[batch._index for batch in index]) - - -class Batches(_PerKey): - pass - - -class MultiIterator(object): - """Iterator that returns results of multiple get_next()s.""" - - def __init__(self, dataset_iterator, 
iterations, batches_per_iteration): - self._dataset_iterator = dataset_iterator - self._iterations = iterations - self._batches_per_iteration = batches_per_iteration - - def get_next(self, name=None): - """Return PerIteration with `iterations x batches_per_iteration` inputs.""" - data = [] - for _ in range(self._batches_per_iteration): - batch = [] - for _ in range(self._iterations): - batch.append(self._dataset_iterator.get_next(name=name)) - data.append(batch) - - # Here is an example. Suppose each get_next returns a tuple of two tensors. - # For 3 `iterations` and 2 `batches_per_iteration`, the `data` is: - # [[(a,z), (b,y), (c,x)], [(A,Z), (B,Y), (C,X)]] - # - # After the first `map_structure` it gets transformed to: - # [(Batches(a, A), Batches(z, Z)), - # (Batches(b, B), Batches(y, Y)), - # (Batches(c, C), Batches(x, X))] - # - # After the second `map_structure` it gets transformed to a tuple of: - # (PerIteration([Batches(a, A), Batches(b, B), Batches(c, C)]), - # PerIteration([Batches(z, Z), Batches(y, Y), Batches(x, X)])) - - data = nest.map_structure(Batches, *data) - data = nest.map_structure(PerIteration, *data) - - return data - - @property - def initializer(self): - return self._dataset_iterator.initializer - - -class PerIterationDataset(object): - """A dataset that returns MultiIterators.""" - - def __init__(self, dataset, iterations, batches_per_iteration): - self._dataset = dataset - self._iterations = iterations - self._batches_per_iteration = batches_per_iteration - - def make_one_shot_iterator(self): - iterator = self._dataset.make_one_shot_iterator() - return MultiIterator(iterator, self._iterations, - self._batches_per_iteration) - - def make_initializable_iterator(self): - iterator = self._dataset.make_initializable_iterator() - return MultiIterator(iterator, self._iterations, - self._batches_per_iteration) - - -class MapOutput(object): - """Map can result in multiple outputs per device.""" - - def __init__(self, l): - self._l = l - - def 
get(self): - return self._l - - -class MultiStepContext(object): - """A context object that can be used to capture things when running steps. - - This context object is useful when running multiple steps at a time using the - `run_steps_on_dataset` API. For e.g. it allows the user's step function to - specify which outputs to emit at what frequency. Currently it supports - capturing output from the last step, as well as capturing non tensor outputs. - In the future it will be augmented to support other use cases such as output - each N steps. - """ - - def __init__(self): - """Initializes an output context. - - Returns: - A context object. - """ - self._last_step_outputs = {} - self._last_step_outputs_aggregations = {} - self._non_tensor_outputs = {} - - @property - def last_step_outputs(self): - """A dictionary consisting of outputs to be captured on last step. - - Keys in the dictionary are names of tensors to be captured, as specified - when `set_last_step_output` is called. - Values in the dictionary are the tensors themselves. If - `set_last_step_output` was called with an `aggregation` for this output, - then the value is the aggregated value. - - Returns: - A dictionary with last step outputs. - """ - return self._last_step_outputs - - def _set_last_step_outputs(self, outputs): - """Replace the entire dictionary of last step outputs.""" - if not isinstance(outputs, dict): - raise ValueError("Need a dictionary to set last_step_outputs.") - self._last_step_outputs = outputs - - def set_last_step_output(self, name, output, - aggregation=variables_lib.VariableAggregation.NONE): - """Set `output` with `name` to be outputted from the last step. - - Args: - name: String, name to identify the output. Doesn't need to match tensor - name. - output: The tensors that should be outputted with `name`. See below for - actual types supported. - aggregation: Aggregation method to use to aggregate outputs from multiple - towers. 
Required if `set_last_step_output` is called in a tower context. - Optional in cross_tower_context. - When present, the outputs from all the towers are aggregated using the - current distribution strategy's `reduce` method. Hence, the type of - `output` must be what's supported by the corresponding `reduce` method. - For e.g. if using MirroredStrategy and aggregation is set, output - must be a `PerDevice` value. - The aggregation method is also recorded in a dictionary - `_last_step_outputs_aggregations` for later interpreting of the - outputs as already reduced or not. - - """ - if distribution_strategy_context.get_cross_tower_context(): - self._last_step_outputs_aggregations[name] = aggregation - if aggregation is variables_lib.VariableAggregation.NONE: - self._last_step_outputs[name] = output - else: - distribution = distribution_strategy_context.get_distribution_strategy() - self._last_step_outputs[name] = distribution.reduce( - aggregation, output, destinations="/device:CPU:0") - else: - assert aggregation is not variables_lib.VariableAggregation.NONE - def merge_fn(distribution, value): - self._last_step_outputs[name] = distribution.reduce( - aggregation, value, destinations="/device:CPU:0") - # Setting this inside the `merge_fn` because all towers share the same - # context object, so it's more robust to set it only once (even if all - # the towers are trying to set the same value). 
- self._last_step_outputs_aggregations[name] = aggregation - - distribution_strategy_context.get_tower_context().merge_call( - merge_fn, output) - - @property - def non_tensor_outputs(self): - """A dictionary consisting of any non tensor outputs to be captured.""" - return self._non_tensor_outputs - - def set_non_tensor_output(self, name, output): - """Set `output` with `name` to be captured as a non tensor output.""" - if distribution_strategy_context.get_cross_tower_context(): - self._non_tensor_outputs[name] = output - else: - def merge_fn(distribution, value): - # NOTE(priyag): For non tensor outputs, we simply return all the values - # in a list as aggregation doesn't make sense on non tensors. - self._non_tensor_outputs[name] = distribution.unwrap(value) - distribution_strategy_context.get_tower_context().merge_call( - merge_fn, output) - - -def value_container(val): - """Returns the container that this per-device `value` belongs to. - - Args: - val: A value returned by `call_for_each_tower()` or a variable - created in `scope()`. - - Returns: - A container that `value` belongs to. - If value does not belong to any container (including the case of - container having been destroyed), returns the value itself. - """ - # pylint: disable=protected-access - if (hasattr(val, "_distributed_container") and - # DistributedVariable has _distributed_container defined - # but we don't want to return it. - not isinstance(val, DistributedVariable)): - container = val._distributed_container() - # pylint: disable=protected-access - if container is not None: - return container - return val - - -# TODO(josh11b): Descend from Variable. -class AggregatingVariable(checkpointable.CheckpointableBase): - """A wrapper around a variable that aggregates updates across towers.""" - - def __init__(self, v, aggregation): - self._v = v - # TODO(josh11b): Set v._distributed_container? 
- # v._distributed_container = weakref.ref(self) # pylint: disable=protected-access - self._aggregation = aggregation - - def get(self): - return self._v - - def __getattr__(self, name): - return getattr(self._v, name) - - def _assign_func(self, *args, **kwargs): - f = kwargs.pop("f") - if distribution_strategy_context.get_cross_tower_context(): - update_device = distribute_lib.get_update_device() - if update_device is not None: - # We are calling an assign function in an update context. - return f(self._v, *args, **kwargs) - - # We are calling an assign function in cross tower context, wrap it in an - # update call. - return distribution_strategy_context.get_distribution_strategy().update( - self, f, *args, **kwargs) - else: - assert distribution_strategy_context.get_tower_context() - # We are calling an assign function in tower context. - # We reduce the value we want to assign/add/sub. More details about how we - # handle the different use cases can be found in the _reduce method. - # We call the function with the reduced value. 
- if self._aggregation == vs.VariableAggregation.NONE: - raise ValueError("You must specify an aggregation method to update a " - "a variable in Tower Context.") - - def merge_fn(strategy, value, *other_args, **other_kwargs): - return strategy.update( - self, f, - strategy.reduce( - aggregation=self._aggregation, value=value, destinations=self), - *other_args, **other_kwargs) - - return distribution_strategy_context.get_tower_context().merge_call( - merge_fn, *args, **kwargs) - - def assign_sub(self, *args, **kwargs): - assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw) - return self._assign_func(f=assign_sub_fn, *args, **kwargs) - - def assign_add(self, *args, **kwargs): - assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw) - return self._assign_func(f=assign_add_fn, *args, **kwargs) - - def assign(self, *args, **kwargs): - assign_fn = lambda var, *a, **kw: var.assign(*a, **kw) - return self._assign_func(f=assign_fn, *args, **kwargs) - - @property - def aggregation(self): - return self._aggregation - - @property - def name(self): - return self._v.name - - @property - def dtype(self): - return self._v.dtype - - # TODO(josh11b): Test saving & restoring. 
- def _gather_saveables_for_checkpoint(self): - return {checkpointable.VARIABLE_VALUE_KEY: self._v} - - # pylint: disable=multiple-statements - def __add__(self, o): return self._v + o - def __radd__(self, o): return o + self._v - def __sub__(self, o): return self._v - o - def __rsub__(self, o): return o - self._v - def __mul__(self, o): return self._v * o - def __rmul__(self, o): return o * self._v - def __truediv__(self, o): return self._v / o - def __rtruediv__(self, o): return o / self._v - def __floordiv__(self, o): return self._v // o - def __rfloordiv__(self, o): return o // self._v - def __mod__(self, o): return self._v % o - def __rmod__(self, o): return o % self._v - def __lt__(self, o): return self._v < o - def __le__(self, o): return self._v <= o - def __gt__(self, o): return self._v > o - def __ge__(self, o): return self._v >= o - def __and__(self, o): return self._v & o - def __rand__(self, o): return o & self._v - def __or__(self, o): return self._v | o - def __ror__(self, o): return o | self._v - def __xor__(self, o): return self._v ^ o - def __rxor__(self, o): return o ^ self._v - def __getitem__(self, o): return self._v[o] - def __pow__(self, o, modulo=None): return pow(self._v, o, modulo) - def __rpow__(self, o): return pow(o, self._v) - def __invert__(self): return ~self._v - def __neg__(self): return -self._v - def __abs__(self): return abs(self._v) - - def __div__(self, o): - try: - return self._v.__div__(o) - except AttributeError: - # See https://docs.python.org/3/library/constants.html#NotImplemented - return NotImplemented - - def __rdiv__(self, o): - try: - return self._v.__rdiv__(o) - except AttributeError: - # See https://docs.python.org/3/library/constants.html#NotImplemented - return NotImplemented - - def __matmul__(self, o): - try: - return self._v.__matmul__(o) - except AttributeError: - # See https://docs.python.org/3/library/constants.html#NotImplemented - return NotImplemented - - def __rmatmul__(self, o): - try: - return 
self._v.__rmatmul__(o) - except AttributeError: - # See https://docs.python.org/3/library/constants.html#NotImplemented - return NotImplemented - - def __str__(self): - return str(self._v) - - def __repr__(self): - return repr(self._v) - - def _should_act_as_resource_variable(self): - """Pass resource_variable_ops.is_resource_variable check.""" - pass - - -# Register a conversion function which reads the value of the variable, -# allowing instances of the class to be used as tensors. -def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False): - return ops.internal_convert_to_tensor( - var.get(), dtype=dtype, name=name, as_ref=as_ref) - - -ops.register_tensor_conversion_function( - AggregatingVariable, _tensor_conversion_aggregate) -ops.register_dense_tensor_like_type(AggregatingVariable) diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 121d2fbb3fbccd913599a581b3de9850ab33eae0..538b859f3d1ece55b460f6dbf8f01540a6013381 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -18,14 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import os +from absl.testing import parameterized -from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.estimator import model_fn 
as model_fn_lib @@ -35,10 +37,10 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.training import device_util from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import nest @@ -190,10 +192,10 @@ def _make_mirrored(): class RegroupAndSelectDeviceTest(test.TestCase): - def _is_per_device(self, result, expected, klass=values.PerDevice): + def _is_per_replica(self, result, expected, klass=values.PerReplica): self.assertIsInstance(result, klass) # We canonicalize the devices to match the device strings returned - # by PerDevice, which also does device string canonicalization. + # by PerReplica, which also does device string canonicalization. 
devices = [device_util.canonicalize(_device_str(i)) for i in range(len(expected))] self.assertEqual(set(devices), set(result.devices)) @@ -206,18 +208,18 @@ class RegroupAndSelectDeviceTest(test.TestCase): _device_str(1): _nested_value("2")}) self.assertIsInstance(result, tuple) self.assertEqual(3, len(result)) - self._is_per_device(result[0], ["a1", "a2"]) - self._is_per_device(result[2], ["h1", "h2"]) + self._is_per_replica(result[0], ["a1", "a2"]) + self._is_per_replica(result[2], ["h1", "h2"]) self.assertIsInstance(result[1], list) self.assertEqual(3, len(result[1])) - self._is_per_device(result[1][0], ["b1", "b2"]) - self._is_per_device(result[1][2], ["g1", "g2"]) + self._is_per_replica(result[1][0], ["b1", "b2"]) + self._is_per_replica(result[1][2], ["g1", "g2"]) self.assertIsInstance(result[1][1], dict) self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) - self._is_per_device(result[1][1]["c"], ["d1", "d2"]) - self._is_per_device(result[1][1]["e"], ["f1", "f2"]) + self._is_per_replica(result[1][1]["c"], ["d1", "d2"]) + self._is_per_replica(result[1][1]["e"], ["f1", "f2"]) # Also test that we can undo the merge using select_device() self.assertEqual(_nested_value("1"), @@ -238,18 +240,18 @@ class RegroupAndSelectDeviceTest(test.TestCase): values.Mirrored) self.assertIsInstance(result, tuple) self.assertEqual(3, len(result)) - self._is_per_device(result[0], ["a1", "a2"], values.Mirrored) - self._is_per_device(result[2], ["h1", "h2"], values.Mirrored) + self._is_per_replica(result[0], ["a1", "a2"], values.Mirrored) + self._is_per_replica(result[2], ["h1", "h2"], values.Mirrored) self.assertIsInstance(result[1], list) self.assertEqual(3, len(result[1])) - self._is_per_device(result[1][0], ["b1", "b2"], values.Mirrored) - self._is_per_device(result[1][2], ["g1", "g2"], values.Mirrored) + self._is_per_replica(result[1][0], ["b1", "b2"], values.Mirrored) + self._is_per_replica(result[1][2], ["g1", "g2"], values.Mirrored) self.assertIsInstance(result[1][1], 
dict) self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) - self._is_per_device(result[1][1]["c"], ["d1", "d2"], values.Mirrored) - self._is_per_device(result[1][1]["e"], ["f1", "f2"], values.Mirrored) + self._is_per_replica(result[1][1]["c"], ["d1", "d2"], values.Mirrored) + self._is_per_replica(result[1][1]["e"], ["f1", "f2"], values.Mirrored) # Also test that we can undo the merge using select_device() self.assertEqual(_nested_value("1"), @@ -275,7 +277,7 @@ class RegroupAndSelectDeviceTest(test.TestCase): _device_str(1): ("b", foo)}) self.assertIsInstance(result, tuple) self.assertEqual(2, len(result)) - self._is_per_device(result[0], ["a", "b"]) + self._is_per_replica(result[0], ["a", "b"]) self.assertIs(foo, result[1]) # Test select_device(), should undo the merge done by regroup(). @@ -325,72 +327,46 @@ class RegroupAndSelectDeviceTest(test.TestCase): self.assertTrue( isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)) - self.assertEquals(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode) + self.assertEqual(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode) for device_id in range(3): d = _device_str(device_id) - self.assertEquals(created_estimator_specs[device_id].loss, - merged_estimator_spec.loss.get(d)) - self.assertEquals(created_estimator_specs[device_id].train_op, - merged_estimator_spec.train_op.get(d)) + self.assertEqual(created_estimator_specs[device_id].loss, + merged_estimator_spec.loss.get(d)) + self.assertEqual(created_estimator_specs[device_id].train_op, + merged_estimator_spec.train_op.get(d)) # Scaffold is populated by `EstimatorSpec.__new__`. 
- self.assertEquals(created_estimator_specs[device_id].scaffold, - merged_estimator_spec.scaffold.get(d)) + self.assertEqual(created_estimator_specs[device_id].scaffold, + merged_estimator_spec.scaffold.get(d)) # Also test that we can undo the merge using select_device() - self.assertEquals(created_estimator_specs[device_id], - values.select_device(_device_str(device_id), - merged_estimator_spec)) + self.assertEqual(created_estimator_specs[device_id], + values.select_device(_device_str(device_id), + merged_estimator_spec)) -class PerDeviceDatasetTest(test.TestCase): +class PerReplicaDatasetTest(test.TestCase): config = config_pb2.ConfigProto() config.allow_soft_placement = True - def _test_iterator_no_prefetch(self, devices, dataset, expected_values): - per_device_dataset = values.PerDeviceDataset( - dataset, devices, prefetch_on_device=False) - iterator = per_device_dataset.make_one_shot_iterator() + def _test_iterator(self, devices, dataset, expected_values): + per_replica_dataset = values.PerReplicaDataset(dataset, devices) + if context.executing_eagerly(): + iterator = per_replica_dataset.make_one_shot_iterator() + else: + iterator = per_replica_dataset.make_initializable_iterator() + self.evaluate([iterator.initializer]) for expected_value in expected_values: next_element = iterator.get_next() - actual = self.evaluate([ - values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, actual) + computed_value = self.evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, computed_value) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() self.evaluate([ values.select_device(d, next_element) for d in devices]) - def _test_iterator_with_prefetch(self, devices, dataset, expected_values): - if not context.executing_eagerly(): - per_device_dataset = values.PerDeviceDataset( - dataset, devices, prefetch_on_device=True) - iterator = 
per_device_dataset.make_one_shot_iterator() - - # With prefetching, we cannot guarantee which input ends up on which - # device, so we verify that the complete set seen on all devices is - # correct, and equal numbers are distributed to each device. - combined_actual = [] - combined_expected = [] - for expected_value in expected_values: - next_element = iterator.get_next() - combined_actual.extend( - self.evaluate( - [values.select_device(d, next_element) for d in devices])) - combined_expected.extend(expected_value) - - self.assertEqual(set(combined_expected), set(combined_actual)) - - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next() - self.evaluate([ - values.select_device(d, next_element) for d in devices]) - - def _test_iterator(self, devices, dataset, expected_values): - self._test_iterator_no_prefetch(devices, dataset, expected_values) - self._test_iterator_with_prefetch(devices, dataset, expected_values) - @test_util.run_in_graph_and_eager_modes def testOneDevice(self): devices = ["/device:CPU:0"] @@ -445,9 +421,8 @@ class PerDeviceDatasetTest(test.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices( random_ops.random_uniform((10,))) - per_device_dataset = values.PerDeviceDataset( - dataset, devices, prefetch_on_device=False) - iterator = per_device_dataset.make_initializable_iterator() + per_replica_dataset = values.PerReplicaDataset(dataset, devices) + iterator = per_replica_dataset.make_initializable_iterator() self.evaluate(iterator.initializer) next_element = iterator.get_next() @@ -466,7 +441,7 @@ class PerDeviceDatasetTest(test.TestCase): class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): - def _test_iterator(self, iterator, devices, expected_values): + def _test_iterator(self, sess, iterator, devices, expected_values): next_element = iterator.get_next() for device in devices: v = values.select_device(device, next_element) @@ -475,73 +450,79 @@ class 
MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): self.assertTrue(element.device in device) for expected_value in expected_values: - actual = self.evaluate( + actual = sess.run( [values.select_device(d, next_element) for d in devices]) self.assertEqual(expected_value, actual) with self.assertRaises(errors.OutOfRangeError): - self.evaluate([values.select_device(d, next_element) for d in devices]) + sess.run([values.select_device(d, next_element) for d in devices]) - def _test_dataset(self, dataset_fn, worker_device_map, devices, - expected_values): + def _test_dataset(self, dataset_fn, worker_devices, devices, + expected_values, auto_shard=True): multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_device_map, prefetch_on_device=False) - multi_worker_iterator = multi_worker_dataset.make_one_shot_iterator() - self._test_iterator(multi_worker_iterator, devices, expected_values) + dataset_fn, worker_devices, auto_shard=auto_shard) + multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() + with self.cached_session() as sess: + sess.run(multi_worker_iterator.initializer) + self._test_iterator(sess, multi_worker_iterator, devices, expected_values) def _cpu_devices(self): - worker_device_map = collections.OrderedDict( - [("/job:worker/replica:0/task:0", - ["/job:worker/replica:0/task:0/device:CPU:0"]), - ("/job:worker/replica:0/task:1", - ["/job:worker/replica:0/task:1/device:CPU:0"])]) + worker_devices = [ + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])] devices = [ "/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:1/device:CPU:0" ] - return worker_device_map, devices + return worker_devices, devices def _cpu_and_one_gpu_devices(self): - # The worker_device_map doesn't have to be a OrderDict object, this is just - # to simplify the testing so that we can pass expected values 
as a list - # instead of a dict. - worker_device_map = collections.OrderedDict( - [("/job:worker/replica:0/task:0", [ + worker_devices = [ + ("/job:worker/replica:0/task:0", [ "/job:worker/replica:0/task:0/device:GPU:0", "/job:worker/replica:0/task:0/device:CPU:0" - ]), ("/job:worker/replica:0/task:1", [ + ]), + ("/job:worker/replica:0/task:1", [ "/job:worker/replica:0/task:1/device:GPU:0", "/job:worker/replica:0/task:1/device:CPU:0" - ])]) + ]) + ] devices = [ "/job:worker/replica:0/task:0/device:GPU:0", "/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:1/device:GPU:0", "/job:worker/replica:0/task:1/device:CPU:0" ] - return worker_device_map, devices + return worker_devices, devices def testDataDistributionOneDevicePerWorker(self): - self.skipTest("Temporarily disabled.") - worker_device_map, devices = self._cpu_devices() + worker_devices, devices = self._cpu_devices() with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_device_map, devices, + self._test_dataset(dataset_fn, worker_devices, devices, [[0, 1], [2, 3], [4, 5], [6, 7]]) + def testDataDistributionNoAutoShard(self): + worker_devices, devices = self._cpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_dataset(dataset_fn, worker_devices, devices, + [[0, 0], [1, 1], [2, 2], [3, 3]], + auto_shard=False) + def testDataDistributionTwoDevicePerWorker(self): - self.skipTest("Temporarily disabled.") if context.num_gpus() < 1: self.skipTest("A GPU is not available for this test.") - worker_device_map, devices = self._cpu_and_one_gpu_devices() + worker_devices, devices = self._cpu_and_one_gpu_devices() with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) - self._test_dataset(dataset_fn, worker_device_map, devices, + self._test_dataset(dataset_fn, worker_devices, devices, [[0, 2, 1, 3], [4, 6, 5, 7]]) def testTupleDataset(self): - 
self.skipTest("Temporarily disabled.") - worker_device_map, devices = self._cpu_devices() + worker_devices, devices = self._cpu_devices() with context.graph_mode(): @@ -553,47 +534,221 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): expected_values = [ [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 8, 2) ] - self._test_dataset(dataset_fn, worker_device_map, devices, + self._test_dataset(dataset_fn, worker_devices, devices, expected_values) def testInitializableIterator(self): - self.skipTest("Temporarily disabled.") - worker_device_map, devices = self._cpu_devices() - with context.graph_mode(): + worker_devices, devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: dataset_fn = lambda: dataset_ops.Dataset.range(8) multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_device_map, prefetch_on_device=False) + dataset_fn, worker_devices, auto_shard=True) multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() - self.evaluate(multi_worker_iterator.initializer) - self._test_iterator(multi_worker_iterator, devices, + sess.run(multi_worker_iterator.initializer) + self._test_iterator(sess, multi_worker_iterator, devices, [[0, 1], [2, 3], [4, 5], [6, 7]]) # After re-initializing the iterator, should be able to iterate again. - self.evaluate(multi_worker_iterator.initializer) - self._test_iterator(multi_worker_iterator, devices, + sess.run(multi_worker_iterator.initializer) + self._test_iterator(sess, multi_worker_iterator, devices, [[0, 1], [2, 3], [4, 5], [6, 7]]) def testValueErrorForIterator(self): - self.skipTest("Temporarily disabled.") # Incompatiable arguments. with self.assertRaises(ValueError): values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"}) # Test duplicated devices under same worker. 
- worker_device_map, _ = self._cpu_devices() - worker_device_map["/job:worker/replica:0/task:0"].append( - "/job:worker/replica:0/task:0/device:CPU:0") + worker_devices, _ = self._cpu_devices() + worker_devices[0][1].append("/job:worker/replica:0/task:0/device:CPU:0") with context.graph_mode(): dataset_fn = lambda: dataset_ops.Dataset.range(8) multi_worker_dataset = values.MultiWorkerDataset( - dataset_fn, worker_device_map, prefetch_on_device=False) + dataset_fn, worker_devices, auto_shard=True) multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() with self.assertRaises(ValueError): multi_worker_iterator.get_next() -class MirroredVariableTest(test.TestCase): +class InputIteratorTestBase(test.TestCase): + + def _test_iterator(self, input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, split_batch_by=None): + devices = nest.flatten([ds for _, ds in worker_device_pairs]) + + if input_type == "input_fn": + input_contexts = [ + distribute_lib.InputContext() for _ in worker_device_pairs] + input_fn = lambda _: dataset_fn() + iterator = values.InputFunctionIterator(input_fn, worker_device_pairs, + input_contexts) + else: + iterator = values.DatasetIterator(dataset_fn(), worker_device_pairs, + split_batch_by) + + evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) + + evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertAllEqual(expected_value, computed_value) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + evaluate([values.select_device(d, next_element) for d in devices]) + + # After re-initializing the iterator, should be able to iterate again. 
+ evaluate(control_flow_ops.group(iterator.initialize())) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertAllEqual(expected_value, computed_value) + + +class InputIteratorSingleWorkerTest(InputIteratorTestBase, + parameterized.TestCase): + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"])) + def testOneDeviceCPU(self, input_type): + worker_device_pairs = [("", ["/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesOneGPUOneCPU(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTupleDataset(self, input_type): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testUnevenDatasetBatches(self, input_type): + 
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + dataset_fn = lambda: dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + input_type=["dataset"], + split_batch_by=[None, 2], + required_gpus=1)) + def testBatchSplitting(self, input_type, split_batch_by): + worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + batch_size = 10 + dataset_fn = lambda: dataset_ops.Dataset.range(100).batch(batch_size) + + updated_batch_size = ( + batch_size // split_batch_by if split_batch_by else batch_size) + expected_values = [[range(i, i+updated_batch_size), + range(i+updated_batch_size, i+2*updated_batch_size)] + for i in range(0, 100, updated_batch_size*2)] + + self._test_iterator(input_type, dataset_fn, worker_device_pairs, + expected_values, sess=None, + split_batch_by=split_batch_by) + + +class InputIteratorMultiWorkerTest( + multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, + parameterized.TestCase): + + def _cpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])] + + def _cpu_and_one_gpu_devices(self): + return [ + ("/job:worker/replica:0/task:0", [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0" + ]), + ("/job:worker/replica:0/task:1", [ + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ]) + ] + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testOneDevicePerWorker(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + 
self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 0], [1, 1], [2, 2], [3, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"], + required_gpus=1)) + def testTwoDevicesPerWorker(self, input_type): + worker_devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(), self.cached_session() as sess: + dataset_fn = lambda: dataset_ops.Dataset.range(4) + self._test_iterator(input_type, dataset_fn, worker_devices, + [[0, 1, 0, 1], [2, 3, 2, 3]], sess) + + @combinations.generate(combinations.combine( + mode=["graph"], + input_type=["input_fn", "dataset"])) + def testTupleDataset(self, input_type): + worker_devices = self._cpu_devices() + with context.graph_mode(), self.cached_session() as sess: + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(4) + dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] + self._test_iterator(input_type, dataset_fn, worker_devices, + expected_values, sess) + + +class MirroredVariableTest(test.TestCase, parameterized.TestCase): config = config_pb2.ConfigProto() config.allow_soft_placement = True @@ -605,9 +760,9 @@ class MirroredVariableTest(test.TestCase): v, _, mirrored = _make_mirrored() - self.assertEquals(v[0].name, mirrored.name) - self.assertEquals(v[0].dtype, mirrored.dtype) - self.assertEquals(v[0].shape, mirrored.shape) + self.assertEqual(v[0].name, mirrored.name) + self.assertEqual(v[0].dtype, mirrored.dtype) + self.assertEqual(v[0].shape, mirrored.shape) @test_util.run_in_graph_and_eager_modes(config=config) def testVariableOnAnotherDevice(self): @@ -617,9 +772,9 @@ class MirroredVariableTest(test.TestCase): mirrored = values.MirroredVariable(index, v, variable_scope.VariableAggregation.MEAN) - self.assertEquals(v.name, mirrored.name) - self.assertEquals(v.dtype, mirrored.dtype) - self.assertEquals(v.shape, 
mirrored.shape) + self.assertEqual(v.name, mirrored.name) + self.assertEqual(v.dtype, mirrored.dtype) + self.assertEqual(v.shape, mirrored.shape) def _assign_mirrored(self, devices, v, new): for d, var, n in zip(devices, v, new): @@ -739,14 +894,13 @@ class MirroredVariableTest(test.TestCase): save_path = self._save_normal() self._restore_mirrored(save_path) - @test_util.run_in_graph_and_eager_modes(config=config) - def testFetchAMirroredVariable(self): - if context.num_gpus() < 1 or context.executing_eagerly(): - self.skipTest("A GPU is not available for this test or it's eager mode.") - - with self.session( - graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy( - ["/device:GPU:0"]).scope(): + @combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_one_gpu, + combinations.core_mirrored_strategy_with_one_gpu], + mode=["graph"])) + def testFetchAMirroredVariable(self, distribution): + with self.session(graph=ops.Graph()) as sess, distribution.scope(): with ops.device("/device:GPU:0"): v = variable_scope.get_variable( name="v", initializer=1., use_resource=True) @@ -760,7 +914,7 @@ class MirroredVariableTest(test.TestCase): _devices = ["/device:GPU:0", "/device:CPU:0"] -def _make_tower_local(method): +def _make_replica_local(method): v = [] index = {} for d, n, init in zip(_devices, ["v", "v/replica"], [1., 2.]): @@ -768,11 +922,11 @@ def _make_tower_local(method): v.append(variable_scope.get_variable( name=n, initializer=init, use_resource=True)) index[d] = v[-1] - tower_local = values.TowerLocalVariable(index, v[0], method) - return v, tower_local + replica_local = values.ReplicaLocalVariable(index, v[0], method) + return v, replica_local -class TowerLocalVariableTest(test.TestCase): +class ReplicaLocalVariablePropertiesTest(test.TestCase): config = config_pb2.ConfigProto() config.allow_soft_placement = True @@ -781,30 +935,51 @@ class TowerLocalVariableTest(test.TestCase): def testProperties(self): if 
context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") + v, replica_local = _make_replica_local( + variable_scope.VariableAggregation.SUM) - v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) - - self.assertEquals(v[0].name, tower_local.name) - self.assertEquals(v[0].dtype, tower_local.dtype) - self.assertEquals(v[0].shape, tower_local.shape) - self.assertEquals(variable_scope.VariableAggregation.SUM, - tower_local.aggregation) + self.assertEqual(v[0].name, replica_local.name) + self.assertEqual(v[0].dtype, replica_local.dtype) + self.assertEqual(v[0].shape, replica_local.shape) + self.assertEqual(variable_scope.VariableAggregation.SUM, + replica_local.aggregation) @test_util.run_in_graph_and_eager_modes(config=config) def testVariableOnAnotherDevice(self): v = variable_scope.get_variable( name="v", initializer=[1.], use_resource=True) index = {"/job:foo/device:CPU:0": v} - tower_local = values.TowerLocalVariable( + replica_local = values.ReplicaLocalVariable( index, v, variable_scope.VariableAggregation.MEAN) - self.assertEquals(v.name, tower_local.name) - self.assertEquals(v.dtype, tower_local.dtype) - self.assertEquals(v.shape, tower_local.shape) - self.assertEquals(variable_scope.VariableAggregation.MEAN, - tower_local.aggregation) + self.assertEqual(v.name, replica_local.name) + self.assertEqual(v.dtype, replica_local.dtype) + self.assertEqual(v.shape, replica_local.shape) + self.assertEqual(variable_scope.VariableAggregation.MEAN, + replica_local.aggregation) + + def testTensorConversion(self): + with context.graph_mode(): + _, replica_local = _make_replica_local( + variable_scope.VariableAggregation.SUM) + converted = ops.internal_convert_to_tensor(replica_local, as_ref=False) + self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, replica_local.dtype) + + converted = ops.internal_convert_to_tensor(replica_local, as_ref=True) + # Resources 
variable are converted to tensors as well when as_ref is True. + self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, replica_local.dtype) + - def _assign_tower_local(self, devices, v, new): +@combinations.generate(combinations.combine( + distribution=[ + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_gpu_and_cpu], + mode=["graph", "eager"])) +class ReplicaLocalVariableTest(test.TestCase, parameterized.TestCase): + + def _assign_replica_local(self, devices, v, new): for d, var, n in zip(devices, v, new): with ops.device(d): self.evaluate(var.assign(n)) @@ -819,86 +994,79 @@ class TowerLocalVariableTest(test.TestCase): save_path, _ = self._save_return_saver(sess, var) return save_path - def _dist_scope(self): - return mirrored_strategy.MirroredStrategy(_devices).scope() - - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveAndRestoreTowerLocalSumOneGraph(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - with self.cached_session(config=self.config) as sess: - v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) + def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution): + with self.cached_session() as sess: + v, replica_local = _make_replica_local( + variable_scope.VariableAggregation.SUM) # Overwrite the initial values. - self._assign_tower_local(_devices, v, [3., 4.]) + self._assign_replica_local(_devices, v, [3., 4.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of v[0] + v[1], 7. - save_path, saver = self._save_return_saver(sess, tower_local) + save_path, saver = self._save_return_saver(sess, replica_local) # Change the values between save and restore. - self._assign_tower_local(_devices, v, [5., 6.]) + self._assign_replica_local(_devices, v, [5., 6.]) # Restores the saved value of 7. 
which gets divided equally # between the variables. saver.restore(sess, save_path) self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveAndRestoreTowerLocalMeanOneGraph(self): + def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution): if context.num_gpus() < 1 and context.executing_eagerly(): self.skipTest("A GPU is not available for this test in eager mode.") - with self.cached_session(config=self.config) as sess: - v, tower_local = _make_tower_local( + with self.cached_session() as sess: + v, replica_local = _make_replica_local( variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. - self._assign_tower_local(_devices, v, [3., 4.]) + self._assign_replica_local(_devices, v, [3., 4.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of (v[0] + v[1])/2, 3.5. - save_path, saver = self._save_return_saver(sess, tower_local) + save_path, saver = self._save_return_saver(sess, replica_local) # Change the values between save and restore. - self._assign_tower_local(_devices, v, [5., 6.]) + self._assign_replica_local(_devices, v, [5., 6.]) # Restores the saved value of 3.5 to both variables. saver.restore(sess, save_path) self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) - def _save_tower_local_mean(self): + def _save_replica_local_mean(self, distribution): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: - v, tower_local = _make_tower_local( + v, replica_local = _make_replica_local( variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. 
- self._assign_tower_local(_devices, v, [3., 4.]) + self._assign_replica_local(_devices, v, [3., 4.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of (v[0] + v[1])/2, 3.5 - save_path = self._save(sess, tower_local) + save_path = self._save(sess, replica_local) # Change the values between save and restore. - self._assign_tower_local(_devices, v, [5., 6.]) + self._assign_replica_local(_devices, v, [5., 6.]) return save_path - def _save_tower_local_sum(self): + def _save_replica_local_sum(self, distribution): """Save variables with mirroring, returns save_path.""" with self.session(graph=ops.Graph()) as sess: - v, tower_local = _make_tower_local("sum") + v, replica_local = _make_replica_local("sum") # Overwrite the initial values. - self._assign_tower_local(_devices, v, [1.5, 2.]) + self._assign_replica_local(_devices, v, [1.5, 2.]) - with self._dist_scope(): + with distribution.scope(): # Saves the current value of v[0] + v[1], 3.5 - save_path = self._save(sess, tower_local) + save_path = self._save(sess, replica_local) # Change the values between save and restore. - self._assign_tower_local(_devices, v, [5., 6.]) + self._assign_replica_local(_devices, v, [5., 6.]) return save_path def _save_normal(self): @@ -931,94 +1099,59 @@ class TowerLocalVariableTest(test.TestCase): saver.restore(sess, save_path) self.assertEqual(3.5, self.evaluate(var)) - def _restore_tower_local_mean(self, save_path): + def _restore_replica_local_mean(self, save_path, distribution): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: - v, tower_local = _make_tower_local( + v, replica_local = _make_replica_local( variable_scope.VariableAggregation.MEAN) # Overwrite the initial values. - self._assign_tower_local(_devices, v, [7., 8.]) + self._assign_replica_local(_devices, v, [7., 8.]) - with self._dist_scope(): + with distribution.scope(): # Restores the saved value of 3.5 to both variables. 
- saver = saver_lib.Saver(var_list=[tower_local]) + saver = saver_lib.Saver(var_list=[replica_local]) saver.restore(sess, save_path) self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) - def _restore_tower_local_sum(self, save_path): + def _restore_replica_local_sum(self, save_path, distribution): """Restore to variables with mirroring in a fresh graph.""" with self.session(graph=ops.Graph()) as sess: - v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) + v, replica_local = _make_replica_local( + variable_scope.VariableAggregation.SUM) # Overwrite the initial values. - self._assign_tower_local(_devices, v, [7., 8.]) + self._assign_replica_local(_devices, v, [7., 8.]) - with self._dist_scope(): + with distribution.scope(): # Restores the saved value of 3.5 to both variables. - saver = saver_lib.Saver(var_list=[tower_local]) + saver = saver_lib.Saver(var_list=[replica_local]) saver.restore(sess, save_path) self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]])) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveTowerLocalRestoreTowerLocalMean(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - save_path = self._save_tower_local_mean() - self._restore_tower_local_mean(save_path) + def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution): + save_path = self._save_replica_local_mean(distribution) + self._restore_replica_local_mean(save_path, distribution) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveTowerLocalRestoreTowerLocalSum(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - save_path = self._save_tower_local_sum() - self._restore_tower_local_sum(save_path) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveTowerLocalMeanRestoreNormal(self): - if context.num_gpus() < 1 and 
context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") + def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution): + save_path = self._save_replica_local_sum(distribution) + self._restore_replica_local_sum(save_path, distribution) - save_path = self._save_tower_local_mean() + def testSaveReplicaLocalMeanRestoreNormal(self, distribution): + save_path = self._save_replica_local_mean(distribution) self._restore_normal(save_path) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveTowerLocalSumRestoreNormal(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - - save_path = self._save_tower_local_sum() + def testSaveReplicaLocalSumRestoreNormal(self, distribution): + save_path = self._save_replica_local_sum(distribution) self._restore_normal(save_path) - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveNormalRestoreTowerLocalMean(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - + def testSaveNormalRestoreReplicaLocalMean(self, distribution): save_path = self._save_normal() - self._restore_tower_local_mean(save_path) - - @test_util.run_in_graph_and_eager_modes(config=config) - def testSaveNormalRestoreTowerLocalSum(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") + self._restore_replica_local_mean(save_path, distribution) + def testSaveNormalRestoreReplicaLocalSum(self, distribution): save_path = self._save_normal() - self._restore_tower_local_sum(save_path) - - def testTensorConversion(self): - with context.graph_mode(): - _, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM) - converted = ops.internal_convert_to_tensor(tower_local, as_ref=False) - self.assertIsInstance(converted, ops.Tensor) 
- self.assertEqual(converted.dtype, tower_local.dtype) - - converted = ops.internal_convert_to_tensor(tower_local, as_ref=True) - # Resources variable are converted to tensors as well when as_ref is True. - self.assertIsInstance(converted, ops.Tensor) - self.assertEqual(converted.dtype, tower_local.dtype) + self._restore_replica_local_sum(save_path, distribution) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/warm_starting_util_test.py b/tensorflow/contrib/distribute/python/warm_starting_util_test.py index 5d57d144c1c16a08280970ecd89eb54f7cf1ffd4..b0bcf9b17456c938204a4892451928daf90b6743 100644 --- a/tensorflow/contrib/distribute/python/warm_starting_util_test.py +++ b/tensorflow/contrib/distribute/python/warm_starting_util_test.py @@ -44,7 +44,9 @@ class WarmStartingUtilWithDistributionStrategyTest( distribution=[combinations.default_strategy, combinations.one_device_strategy, combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], + combinations.mirrored_strategy_with_two_gpus, + combinations.core_mirrored_strategy_with_gpu_and_cpu, + combinations.core_mirrored_strategy_with_two_gpus], save_with_distribution=[True, False], restore_with_distribution=[True, False], mode=["graph"])) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 60f6b90edcb71f04bca29b90744db201e83cd545..3079175015a9aee1625404902070df8f13b2089c 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -72,7 +72,6 @@ py_library( "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:random_ops", - "//tensorflow/python:spectral_ops", "//tensorflow/python:state_ops", "//tensorflow/python:tensor_util", "//tensorflow/python:util", @@ -80,6 +79,7 @@ py_library( "//tensorflow/python:variables", "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/linalg", + "//tensorflow/python/ops/signal", 
"//third_party/py/numpy", "@six_archive//:six", ], diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 5cec93c4df2e970f203253be6342bb292f296eb0..5f6b7fe30996aa97653d97bffb007703437c3d14 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -13,74 +13,80 @@ # limitations under the License. # ============================================================================== """Classes representing statistical distributions and ops for working with them. + +Use [tfp.distributions](/probability/api_docs/python/tfp/distributions) instead. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member +from tensorflow.python.util import deprecation + + +# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member,g-import-not-at-top -from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.contrib.distributions.python.ops.autoregressive import * -from tensorflow.contrib.distributions.python.ops.batch_reshape import * -from tensorflow.contrib.distributions.python.ops.binomial import * -from tensorflow.contrib.distributions.python.ops.cauchy import * -from tensorflow.contrib.distributions.python.ops.chi2 import * -from tensorflow.contrib.distributions.python.ops.conditional_distribution import * -from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * -from tensorflow.contrib.distributions.python.ops.deterministic import * -from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular -from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse -from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform -from 
tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp -from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse -from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag -from tensorflow.contrib.distributions.python.ops.estimator import * -from tensorflow.contrib.distributions.python.ops.geometric import * -from tensorflow.contrib.distributions.python.ops.half_normal import * -from tensorflow.contrib.distributions.python.ops.independent import * -from tensorflow.contrib.distributions.python.ops.inverse_gamma import * -from tensorflow.contrib.distributions.python.ops.kumaraswamy import * -from tensorflow.contrib.distributions.python.ops.logistic import * -from tensorflow.contrib.distributions.python.ops.mixture import * -from tensorflow.contrib.distributions.python.ops.mixture_same_family import * -from tensorflow.contrib.distributions.python.ops.moving_stats import * -from tensorflow.contrib.distributions.python.ops.mvn_diag import * -from tensorflow.contrib.distributions.python.ops.mvn_diag_plus_low_rank import * -from tensorflow.contrib.distributions.python.ops.mvn_full_covariance import * -from tensorflow.contrib.distributions.python.ops.mvn_tril import * -from tensorflow.contrib.distributions.python.ops.negative_binomial import * -from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import * -from tensorflow.contrib.distributions.python.ops.onehot_categorical import * -from tensorflow.contrib.distributions.python.ops.poisson import * -from tensorflow.contrib.distributions.python.ops.poisson_lognormal import * -from tensorflow.contrib.distributions.python.ops.quantized_distribution import * -from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * -from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * -from tensorflow.contrib.distributions.python.ops.sample_stats import * -from 
tensorflow.contrib.distributions.python.ops.seed_stream import * -from tensorflow.contrib.distributions.python.ops.sinh_arcsinh import * -from tensorflow.contrib.distributions.python.ops.test_util import * -from tensorflow.contrib.distributions.python.ops.vector_diffeomixture import * -from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import * -from tensorflow.contrib.distributions.python.ops.vector_laplace_diag import * -from tensorflow.contrib.distributions.python.ops.vector_sinh_arcsinh_diag import * -from tensorflow.contrib.distributions.python.ops.wishart import * -from tensorflow.python.ops.distributions.bernoulli import * -from tensorflow.python.ops.distributions.beta import * -from tensorflow.python.ops.distributions.categorical import * -from tensorflow.python.ops.distributions.dirichlet import * -from tensorflow.python.ops.distributions.dirichlet_multinomial import * -from tensorflow.python.ops.distributions.distribution import * -from tensorflow.python.ops.distributions.exponential import * -from tensorflow.python.ops.distributions.gamma import * -from tensorflow.python.ops.distributions.kullback_leibler import * -from tensorflow.python.ops.distributions.laplace import * -from tensorflow.python.ops.distributions.multinomial import * -from tensorflow.python.ops.distributions.normal import * -from tensorflow.python.ops.distributions.student_t import * -from tensorflow.python.ops.distributions.transformed_distribution import * -from tensorflow.python.ops.distributions.uniform import * +with deprecation.silence(): + from tensorflow.contrib.distributions.python.ops import bijectors + from tensorflow.contrib.distributions.python.ops.autoregressive import * + from tensorflow.contrib.distributions.python.ops.batch_reshape import * + from tensorflow.contrib.distributions.python.ops.binomial import * + from tensorflow.contrib.distributions.python.ops.cauchy import * + from tensorflow.contrib.distributions.python.ops.chi2 import * + from 
tensorflow.contrib.distributions.python.ops.conditional_distribution import * + from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * + from tensorflow.contrib.distributions.python.ops.deterministic import * + from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular + from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse + from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform + from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp + from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse + from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag + from tensorflow.contrib.distributions.python.ops.estimator import * + from tensorflow.contrib.distributions.python.ops.geometric import * + from tensorflow.contrib.distributions.python.ops.half_normal import * + from tensorflow.contrib.distributions.python.ops.independent import * + from tensorflow.contrib.distributions.python.ops.inverse_gamma import * + from tensorflow.contrib.distributions.python.ops.kumaraswamy import * + from tensorflow.contrib.distributions.python.ops.logistic import * + from tensorflow.contrib.distributions.python.ops.mixture import * + from tensorflow.contrib.distributions.python.ops.mixture_same_family import * + from tensorflow.contrib.distributions.python.ops.moving_stats import * + from tensorflow.contrib.distributions.python.ops.mvn_diag import * + from tensorflow.contrib.distributions.python.ops.mvn_diag_plus_low_rank import * + from tensorflow.contrib.distributions.python.ops.mvn_full_covariance import * + from tensorflow.contrib.distributions.python.ops.mvn_tril import * + from tensorflow.contrib.distributions.python.ops.negative_binomial import * + from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors 
import * + from tensorflow.contrib.distributions.python.ops.onehot_categorical import * + from tensorflow.contrib.distributions.python.ops.poisson import * + from tensorflow.contrib.distributions.python.ops.poisson_lognormal import * + from tensorflow.contrib.distributions.python.ops.quantized_distribution import * + from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * + from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * + from tensorflow.contrib.distributions.python.ops.sample_stats import * + from tensorflow.contrib.distributions.python.ops.seed_stream import * + from tensorflow.contrib.distributions.python.ops.sinh_arcsinh import * + from tensorflow.contrib.distributions.python.ops.test_util import * + from tensorflow.contrib.distributions.python.ops.vector_diffeomixture import * + from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import * + from tensorflow.contrib.distributions.python.ops.vector_laplace_diag import * + from tensorflow.contrib.distributions.python.ops.vector_sinh_arcsinh_diag import * + from tensorflow.contrib.distributions.python.ops.wishart import * + from tensorflow.python.ops.distributions.bernoulli import * + from tensorflow.python.ops.distributions.beta import * + from tensorflow.python.ops.distributions.categorical import * + from tensorflow.python.ops.distributions.dirichlet import * + from tensorflow.python.ops.distributions.dirichlet_multinomial import * + from tensorflow.python.ops.distributions.distribution import * + from tensorflow.python.ops.distributions.exponential import * + from tensorflow.python.ops.distributions.gamma import * + from tensorflow.python.ops.distributions.kullback_leibler import * + from tensorflow.python.ops.distributions.laplace import * + from tensorflow.python.ops.distributions.multinomial import * + from tensorflow.python.ops.distributions.normal import * + from tensorflow.python.ops.distributions.student_t import * + from 
tensorflow.python.ops.distributions.transformed_distribution import * + from tensorflow.python.ops.distributions.uniform import * # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py index 9b9b3ce2dd9d42286d2d9657d5f00de8445261f0..99cb105d66885fd5cf8cb6a3f87e2fe82a5bf4d0 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py @@ -250,13 +250,22 @@ class DistributionTest(test.TestCase): mvn_dynamic = tfd.MultivariateNormalDiag( loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32), name="MVN2") - self.assertEqual( - ("tfp.distributions.MultivariateNormalDiag(" - "\"MVN2/\", " - "batch_shape=(?,), " # Partially known. - "event_shape=(3,), " - "dtype=float32)"), - str(mvn_dynamic)) + if mvn_dynamic.batch_shape._v2_behavior: + self.assertEqual( + ("tfp.distributions.MultivariateNormalDiag(" + "\"MVN2/\", " + "batch_shape=(None,), " # Partially known. + "event_shape=(3,), " + "dtype=float32)"), + str(mvn_dynamic)) + else: + self.assertEqual( + ("tfp.distributions.MultivariateNormalDiag(" + "\"MVN2/\", " + "batch_shape=(?,), " # Partially known. 
+ "event_shape=(3,), " + "dtype=float32)"), + str(mvn_dynamic)) def testReprWorksCorrectlyScalar(self): normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) @@ -300,13 +309,22 @@ class DistributionTest(test.TestCase): mvn_dynamic = tfd.MultivariateNormalDiag( loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32), name="MVN2") - self.assertEqual( - (""), - repr(mvn_dynamic)) + if mvn_dynamic.batch_shape._v2_behavior: + self.assertEqual( + (""), + repr(mvn_dynamic)) + else: + self.assertEqual( + (""), + repr(mvn_dynamic)) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py b/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py index 29eeaf43c5185ce5519d4a1211f66e99ce61c6ab..ab3c07172a68255f4e387e071ac2f8341e93b90c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py @@ -82,7 +82,7 @@ class NormalTest(test.TestCase): x = constant_op.constant( [[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], [2.5, -2.5, -4.0, 0.0, 1.0, -2.0]], dtype=dtypes.float32) - s = math_ops.reduce_sum(x, reduction_indices=[1]) + s = math_ops.reduce_sum(x, axis=[1]) x = array_ops.transpose(x) # Reshape to shape (6, 2) n = constant_op.constant([6] * 2) prior = distributions.Normal(loc=mu0, scale=sigma0) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py index a60056c444a3fe7262939c5b3c73673f9a7c1469..cdee30bbc42e661952a9c757d7a30ebcd393f794 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py @@ -147,14 +147,13 @@ class WishartCholeskyTest(test.TestCase): x = chol_w.sample(10000, seed=42) self.assertAllEqual((10000, 3, 3), x.get_shape()) - 
moment1_estimate = math_ops.reduce_mean(x, reduction_indices=[0]).eval() + moment1_estimate = math_ops.reduce_mean(x, axis=[0]).eval() self.assertAllClose(chol_w.mean().eval(), moment1_estimate, rtol=0.05) # The Variance estimate uses the squares rather than outer-products # because Wishart.Variance is the diagonal of the Wishart covariance # matrix. - variance_estimate = (math_ops.reduce_mean( - math_ops.square(x), reduction_indices=[0]) - + variance_estimate = (math_ops.reduce_mean(math_ops.square(x), axis=[0]) - math_ops.square(moment1_estimate)).eval() self.assertAllClose( chol_w.variance().eval(), variance_estimate, rtol=0.05) diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py index 612376efb7f43b0dfcd3ffeb5437f2a419f66f4d..d450379088813caeac6f3dca72fae99c5f886b5a 100644 --- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -429,5 +429,6 @@ def validate_init_args_statically(distribution, batch_shape): if batch_shape_static.dims is not None: if any( - dim.value is not None and dim.value < 1 for dim in batch_shape_static): + dim.value is not None and + dim.value < 1 for dim in batch_shape_static.dims): raise ValueError("`batch_shape` elements must be >=-1.") diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index e141f8b5c6423bd6cce4d09da6f49d55b3e25a24..3b17de9b8a903956bfdc4d46cf5bbfbfd8530e9f 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -14,6 +14,8 @@ # ============================================================================== """Bijector Ops. +Use [tfp.bijectors](/probability/api_docs/python/tfp/bijectors) instead. 
+ @@AbsoluteValue @@Affine @@AffineLinearOperator diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index 3e1e4fc82971b71792d193ea8518dd402e4a4d9d..2358ef5976b2f21c77130c71d5214a463d17bf0e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -168,11 +168,11 @@ class CholeskyOuterProduct(bijector.Bijector): [is_matrix, is_square, is_positive_definite], x) # Create a vector equal to: [p, p-1, ..., 2, 1]. - if x.get_shape().ndims is None or x.get_shape()[-1].value is None: + if x.get_shape().ndims is None or x.get_shape().dims[-1].value is None: p_int = array_ops.shape(x)[-1] p_float = math_ops.cast(p_int, dtype=x.dtype) else: - p_int = x.get_shape()[-1].value + p_int = x.get_shape().dims[-1].value p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype) exponents = math_ops.linspace(p_float, 1., p_int) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py index 31a9ca27e519bc312813668bf621a875838f12a0..7ae98878986eb10570b5e93a4a57d6bad6b38c0c 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops @@ -104,7 +105,8 @@ class FillTriangular(bijector.Bijector): return array_ops.zeros_like(y[..., 0, 0]) def _forward_event_shape(self, input_shape): - batch_shape, d = input_shape[:-1], input_shape[-1].value + 
batch_shape, d = (input_shape[:-1], + tensor_shape.dimension_value(input_shape[-1])) if d is None: n = None else: @@ -113,8 +115,8 @@ class FillTriangular(bijector.Bijector): def _inverse_event_shape(self, output_shape): batch_shape, n1, n2 = (output_shape[:-2], - output_shape[-2].value, - output_shape[-1].value) + tensor_shape.dimension_value(output_shape[-2]), + tensor_shape.dimension_value(output_shape[-1])) if n1 is None or n2 is None: m = None elif n1 != n2: diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 3b3d8ee6f2dc595983fd5e283d0435e8a227f2ba..c30de1f989a7b83fba1f69a83b96b8f45dea02c6 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import core as layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -237,7 +238,8 @@ class MaskedAutoregressiveFlow(bijector.Bijector): def _forward(self, x): if self._unroll_loop: - event_size = x.shape.with_rank_at_least(1)[-1].value + event_size = tensor_shape.dimension_value( + x.shape.with_rank_at_least(1)[-1]) if event_size is None: raise ValueError( "The final dimension of `x` must be known at graph construction " @@ -260,7 +262,8 @@ class MaskedAutoregressiveFlow(bijector.Bijector): # the graph compiler of the maximum number of steps. If not, # static_event_size will be None, and the maximum_iterations argument will # have no effect. 
- static_event_size = x.shape.with_rank_at_least(1)[-1].value + static_event_size = tensor_shape.dimension_value( + x.shape.with_rank_at_least(1)[-1]) y0 = array_ops.zeros_like(x, name="y0") # call the template once to ensure creation _ = self._shift_and_log_scale_fn(y0) @@ -405,7 +408,8 @@ def masked_dense(inputs, Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 """ # TODO(b/67594795): Better support of dynamic shape. - input_depth = inputs.shape.with_rank_at_least(1)[-1].value + input_depth = tensor_shape.dimension_value( + inputs.shape.with_rank_at_least(1)[-1]) if input_depth is None: raise NotImplementedError( "Rightmost dimension must be known prior to graph execution.") @@ -520,7 +524,8 @@ def masked_autoregressive_default_template( def _fn(x): """MADE parameterized via `masked_autoregressive_default_template`.""" # TODO(b/67594795): Better support of dynamic shape. - input_depth = x.shape.with_rank_at_least(1)[-1].value + input_depth = tensor_shape.dimension_value( + x.shape.with_rank_at_least(1)[-1]) if input_depth is None: raise NotImplementedError( "Rightmost dimension must be known prior to graph execution.") diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py index 0bcb08cdea7142b82af3116245306a11773ef93c..17e9b8dec9f009415a9a26c3b043afacc2c4ec72 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import core as layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -96,16 +97,18 @@ class RealNVP(bijector.Bijector): # A common choice for a normalizing flow is to use 
a Gaussian for the base # distribution. (However, any continuous distribution would work.) E.g., + num_dims = 3 + num_samples = 1 nvp = tfd.TransformedDistribution( - distribution=tfd.MultivariateNormalDiag(loc=[0., 0., 0.])), + distribution=tfd.MultivariateNormalDiag(loc=np.zeros(num_dims)), bijector=tfb.RealNVP( num_masked=2, shift_and_log_scale_fn=tfb.real_nvp_default_template( hidden_layers=[512, 512]))) - x = nvp.sample() + x = nvp.sample(num_samples) nvp.log_prob(x) - nvp.log_prob(0.) + nvp.log_prob(np.zeros([num_samples, num_dims])) ``` For more examples, see [Jang (2018)][3]. @@ -183,7 +186,8 @@ class RealNVP(bijector.Bijector): def _cache_input_depth(self, x): if self._input_depth is None: - self._input_depth = x.shape.with_rank_at_least(1)[-1].value + self._input_depth = tensor_shape.dimension_value( + x.shape.with_rank_at_least(1)[-1]) if self._input_depth is None: raise NotImplementedError( "Rightmost dimension must be known prior to graph execution.") diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py index 71ac29038fc12e7d046df8624c6e3e5bb97d3d8f..ec203e171730a1ef6de6b72c6d96c52d4010d7e6 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -46,7 +47,7 @@ __all__ = [ "instead of `tf.contrib.distributions`.", warn_once=True) def _static_ndims_from_shape(shape): - return shape.shape.with_rank_at_least(1)[0].value + return tensor_shape.dimension_value(shape.shape.with_rank_at_least(1)[0]) 
@deprecation.deprecated( diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index 20ee0d340833d5c5275e2ab52a89dcdf7198add1..74765f19e584c5de07c6aee4a36ec4e85438f862 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -110,7 +110,7 @@ class SoftmaxCentered(bijector.Bijector): # Set shape hints. if x.shape.ndims is not None: - shape = x.shape[:-1].concatenate(x.shape[-1] + 1) + shape = x.shape[:-1].concatenate(x.shape.dims[-1] + 1) y.shape.assert_is_compatible_with(shape) y.set_shape(shape) @@ -135,7 +135,7 @@ class SoftmaxCentered(bijector.Bijector): # Set shape hints. if y.shape.ndims is not None: - shape = y.shape[:-1].concatenate(y.shape[-1] - 1) + shape = y.shape[:-1].concatenate(y.shape.dims[-1] - 1) x.shape.assert_is_compatible_with(shape) x.set_shape(shape) @@ -168,7 +168,7 @@ class SoftmaxCentered(bijector.Bijector): # log_normalization = 1 + reduce_sum(exp(logits)) # -log_normalization + reduce_sum(logits - log_normalization) log_normalization = nn_ops.softplus( - math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) + math_ops.reduce_logsumexp(x, axis=-1, keepdims=True)) return array_ops.squeeze( (-log_normalization + math_ops.reduce_sum( x - log_normalization, axis=-1, keepdims=True)), axis=-1) diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index b4ad33cf6dbf073419a27f378c8eefdba97c5af7..1415f85e5cb5598e99c4d6b8e6c6a2d254503db0 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from 
tensorflow.python.framework import smart_cond +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -315,7 +316,7 @@ def shapes_from_loc_and_scale(loc, scale, name="shapes_from_loc_and_scale"): # Static check that event shapes match. if loc is not None: - loc_event_size = loc.get_shape()[-1].value + loc_event_size = tensor_shape.dimension_value(loc.get_shape()[-1]) if loc_event_size is not None and event_size_const is not None: if loc_event_size != 1 and loc_event_size != event_size_const: raise ValueError( diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index e1cfff3c66a2bcbc98af8a257dbdea2d916270e2..cf15deebb78b6c92865c34f61d806bc9e9ab3ee1 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -166,8 +166,10 @@ class Independent(distribution_lib.Distribution): def _batch_shape_tensor(self): with ops.control_dependencies(self._runtime_assertions): batch_shape = self.distribution.batch_shape_tensor() - batch_ndims = (batch_shape.shape[0].value - if batch_shape.shape.with_rank_at_least(1)[0].value + dim0 = tensor_shape.dimension_value( + batch_shape.shape.with_rank_at_least(1)[0]) + batch_ndims = (dim0 + if dim0 is not None else array_ops.shape(batch_shape)[0]) return batch_shape[:batch_ndims - self.reinterpreted_batch_ndims] @@ -182,8 +184,10 @@ class Independent(distribution_lib.Distribution): def _event_shape_tensor(self): with ops.control_dependencies(self._runtime_assertions): batch_shape = self.distribution.batch_shape_tensor() - batch_ndims = (batch_shape.shape[0].value - if batch_shape.shape.with_rank_at_least(1)[0].value + dim0 = tensor_shape.dimension_value( + batch_shape.shape.with_rank_at_least(1)[0]) + batch_ndims = (dim0 + if dim0 is not None 
else array_ops.shape(batch_shape)[0]) return array_ops.concat([ batch_shape[batch_ndims - self.reinterpreted_batch_ndims:], @@ -239,9 +243,11 @@ class Independent(distribution_lib.Distribution): static_reinterpreted_batch_ndims, batch_ndims)) elif validate_args: batch_shape = distribution.batch_shape_tensor() + dim0 = tensor_shape.dimension_value( + batch_shape.shape.with_rank_at_least(1)[0]) batch_ndims = ( - batch_shape.shape[0].value - if batch_shape.shape.with_rank_at_least(1)[0].value is not None + dim0 + if dim0 is not None else array_ops.shape(batch_shape)[0]) assertions.append(check_ops.assert_less_equal( reinterpreted_batch_ndims, batch_ndims, diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index f4d394ff29f072907a019afb52bd8dc5d244e955..f34317f5abfed1c71b516c5fde42baca614d7f9b 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util as distribution_utils from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -147,8 +148,9 @@ class MixtureSameFamily(distribution.Distribution): self._runtime_assertions = [] s = components_distribution.event_shape_tensor() - self._event_ndims = (s.shape[0].value - if s.shape.with_rank_at_least(1)[0].value is not None + s_dim0 = tensor_shape.dimension_value(s.shape[0]) + self._event_ndims = (s_dim0 + if s_dim0 is not None else array_ops.shape(s)[0]) if not mixture_distribution.dtype.is_integer: @@ -186,8 +188,10 @@ class MixtureSameFamily(distribution.Distribution): "`mixture_distribution.batch_shape` is not " "compatible with 
`components_distribution.batch_shape`"))] - km = mixture_distribution.logits.shape.with_rank_at_least(1)[-1].value - kc = components_distribution.batch_shape.with_rank_at_least(1)[-1].value + km = tensor_shape.dimension_value( + mixture_distribution.logits.shape.with_rank_at_least(1)[-1]) + kc = tensor_shape.dimension_value( + components_distribution.batch_shape.with_rank_at_least(1)[-1]) if km is not None and kc is not None and km != kc: raise ValueError("`mixture_distribution components` ({}) does not " "equal `components_distribution.batch_shape[-1]` " diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index aa680a92be64cf0f099acd335369f2a1610c5953..978e627d6638ddeea9df288d389354f0ac53d115 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -29,8 +29,8 @@ from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import spectral_ops from tensorflow.python.ops.distributions import util +from tensorflow.python.ops.signal import fft_ops __all__ = [ "auto_correlation", @@ -157,11 +157,11 @@ def auto_correlation( dtype.real_dtype.as_numpy_dtype(0.)) # Autocorrelation is IFFT of power-spectral density (up to some scaling). - fft_x_rotated_pad = spectral_ops.fft(x_rotated_pad) + fft_x_rotated_pad = fft_ops.fft(x_rotated_pad) spectral_density = fft_x_rotated_pad * math_ops.conj(fft_x_rotated_pad) # shifted_product is R[m] from above detailed explanation. # It is the inner product sum_n X[n] * Conj(X[n - m]). - shifted_product = spectral_ops.ifft(spectral_density) + shifted_product = fft_ops.ifft(spectral_density) # Cast back to real-valued if x was real to begin with. 
shifted_product = math_ops.cast(shifted_product, dtype) diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index a3d178357b79b9d0d15c738603d5019321eff112..a648d61ac8dd5c1d368cf41505b85827dfeb63e1 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -183,7 +183,7 @@ def quadrature_scheme_softmaxnormal_quantiles( def _get_final_shape(qs): """Helper to build `TensorShape`.""" bs = dist.batch_shape.with_rank_at_least(1) - num_components = bs[-1].value + num_components = tensor_shape.dimension_value(bs[-1]) if num_components is not None: num_components += 1 tail = tensor_shape.TensorShape([num_components, qs]) @@ -791,7 +791,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): def _expand_mix_distribution_probs(self): p = self.mixture_distribution.probs # [B, deg] - deg = p.shape.with_rank_at_least(1)[-1].value + deg = tensor_shape.dimension_value(p.shape.with_rank_at_least(1)[-1]) if deg is None: deg = array_ops.shape(p)[-1] event_ndims = self.event_shape.ndims @@ -831,10 +831,12 @@ def maybe_check_quadrature_param(param, name, validate_args): # TODO(jvdillon): Remove once we support k-mixtures. 
if param.shape.with_rank_at_least(1)[-1] is not None: - if param.shape[-1].value != 1: + if tensor_shape.dimension_value(param.shape[-1]) != 1: raise NotImplementedError("Currently only bimixtures are supported; " "{}.shape[-1]={} is not 1.".format( - name, param.shape[-1].value)) + name, + tensor_shape.dimension_value( + param.shape[-1]))) elif validate_args: assertions.append(check_ops.assert_equal( array_ops.shape(param)[-1], 1, @@ -905,7 +907,7 @@ def interpolate_loc(grid, loc): if len(loc) != 2: raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(loc))) - deg = grid.shape.with_rank_at_least(1)[-1].value + deg = tensor_shape.dimension_value(grid.shape.with_rank_at_least(1)[-1]) if deg is None: raise ValueError("Num quadrature grid points must be known prior " "to graph execution.") @@ -939,7 +941,7 @@ def interpolate_scale(grid, scale): if len(scale) != 2: raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(scale))) - deg = grid.shape.with_rank_at_least(1)[-1].value + deg = tensor_shape.dimension_value(grid.shape.with_rank_at_least(1)[-1]) if deg is None: raise ValueError("Num quadrature grid points must be known prior " "to graph execution.") diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index ee2fc58864d4ac528ebae3d681d2e4922fb60771..2d83f0c13f14a8e5d1eee4fa1436bd05991e934e 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -136,13 +136,13 @@ class _WishartLinearOperator(distribution.Distribution): contrib_tensor_util.assert_same_float_dtype( (self._df, self._scale_operator)) if (self._scale_operator.shape.ndims is None or - self._scale_operator.shape[-1].value is None): + self._scale_operator.shape.dims[-1].value is None): self._dimension = math_ops.cast( 
self._scale_operator.domain_dimension_tensor(), dtype=self._scale_operator.dtype, name="dimension") else: self._dimension = ops.convert_to_tensor( - self._scale_operator.shape[-1].value, + self._scale_operator.shape.dims[-1].value, dtype=self._scale_operator.dtype, name="dimension") df_val = tensor_util.constant_value(self._df) dim_val = tensor_util.constant_value(self._dimension) diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 33a1d572a20e68479d3ec1147d4892449e7beb8a..77052a75a70bec1162feb2b126d247924b3a2e36 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -28,6 +28,7 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:execution_callbacks", "//tensorflow/python/eager:function", ], @@ -249,11 +250,10 @@ py_library( ], ) -py_test( +cuda_py_test( name = "remote_test", srcs = ["remote_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":parameter_server", ":remote", "//tensorflow/contrib/eager/python:tfe", diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 3aed121233be1268531495a2fa83fd323412e1fd..34614b86a75b93ab93cf844c645c211b1329c6d5 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -52,12 +52,6 @@ class Iterator(iterator_ops.EagerIterator): TypeError: If `dataset` is an unsupported type. RuntimeError: When invoked without eager execution enabled. """ - if isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset): # pylint: disable=protected-access - raise TypeError( - "`tf.data.experimental.prefetch_to_device()` is not compatible with " - "`tf.contrib.eager.Iterator`. Use `for ... 
in dataset:` to iterate " - "over the dataset instead.") - if not context.context().device_spec.device_type: is_remote_device = False else: diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 6a508fc6ba98740c4d441a064dc8a3e2b321f585..257d02057ae0d280074559aa9e97725bf5cc3fd0 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -26,7 +26,6 @@ import numpy as np from tensorflow.contrib import lookup from tensorflow.contrib.eager.python import datasets from tensorflow.python.data import Dataset -from tensorflow.python.data.experimental.ops import prefetching_ops from tensorflow.python.data.experimental.ops import threadpool from tensorflow.python.data.experimental.ops import unique from tensorflow.python.eager import test @@ -208,18 +207,6 @@ class IteratorTest(test.TestCase): y = math_ops.add(x, x) self.assertAllEqual([0., 2.], y.numpy()) - def testTensorsExplicitPrefetchToDevice(self): - ds = Dataset.from_tensor_slices([0., 1.]) - ds = ds.apply(prefetching_ops.prefetch_to_device(test.gpu_device_name())) - - with self.assertRaisesRegexp(TypeError, 'prefetch_to_device'): - datasets.Iterator(ds) - - for i, x in enumerate(ds): - with ops.device(test.gpu_device_name()): - x = math_ops.add(x, x) - self.assertEqual(float(i) + float(i), x.numpy()) - def testOverrideThreadPool(self): def get_thread_id(_): diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index 7949a3f6da293abdd85512209242bae76ab4d816..51443d24829bdc31a41813e0ff50ad7102422112 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -22,6 +22,7 @@ import six from tensorflow.contrib.eager.python import datasets from tensorflow.contrib.eager.python import metrics +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from 
tensorflow.python.eager import function from tensorflow.python.framework import errors_impl @@ -164,8 +165,8 @@ class Evaluator(object): self.__call__(example, *args, **kwargs) return self.all_metric_results(summary_logdir) # Graph construction - call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), *args, - **kwargs) + call_op = self.__call__( + dataset_ops.make_one_shot_iterator(dataset).get_next(), *args, **kwargs) init_op = self.init_variables() results_op = self.all_metric_results(summary_logdir) return (init_op, call_op, results_op) diff --git a/tensorflow/contrib/eager/python/examples/densenet/BUILD b/tensorflow/contrib/eager/python/examples/densenet/BUILD index 2dc196f550a10367066730f6f042c4ed69533ec3..e2154fcc5fcf774dcd52285d9442dfd5073a4992 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/BUILD +++ b/tensorflow/contrib/eager/python/examples/densenet/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_binary") py_binary( name = "densenet", diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py index 4b3cb624bc947a1d1956eff6accb6d4da3bf3b87..24f6b007b526b29157011f3b1e9abdbd50bacc8e 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_graph_test.py @@ -119,7 +119,8 @@ class DensenetBenchmark(tf.test.Benchmark): with tf.Graph().as_default(): np_images, np_labels = random_batch(batch_size) dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat() - (images, labels) = dataset.make_one_shot_iterator().get_next() + (images, labels) = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() model = densenet.DenseNet(self.depth, self.growth_rate, self.num_blocks, 
self.output_classes, diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py index e5058bfd9480e25b3cf040f0d96bf21242a147b8..a9fb0035d299d64b35d756eaf1ae5f7034ff5599 100644 --- a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py +++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py @@ -228,6 +228,7 @@ class DensenetBenchmark(tf.test.Benchmark): weight_decay=1e-4, dropout_rate=0, pool_initial=True, include_top=True) if defun: + # TODO(apassos) enable tfe.function here model.call = tfe.defun(model.call) batch_size = 64 num_burn = 5 diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py index 12b39b0cde49d4c017acfa74572c725036c54eff..e73841fbf724e05eaa3be90cc8650f795d3e1ccf 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py @@ -42,7 +42,8 @@ class MnistGraphGanBenchmark(tf.test.Benchmark): # Generate some random data. 
images_data = np.random.randn(batch_size, 784).astype(np.float32) dataset = tf.data.Dataset.from_tensors(images_data) - images = dataset.repeat().make_one_shot_iterator().get_next() + images = tf.compat.v1.data.make_one_shot_iterator( + dataset.repeat()).get_next() # Create the models and optimizers generator = mnist.Generator(data_format()) diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb index ca27a85a229d41a85fa26ecdc982da478fe9e202..1a08cc0fd06516be4af5c2b0b46a3ffcf9101e95 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/cvae.ipynb @@ -470,7 +470,7 @@ "\n", " if epoch % 1 == 0:\n", " loss = tfe.metrics.Mean()\n", - " for test_x in test_dataset.make_one_shot_iterator():\n", + " for test_x in test_dataset:\n", " loss(compute_loss(model, test_x))\n", " elbo = -loss.result()\n", " display.clear_output(wait=False)\n", diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb index 5621d6a358e8969ea1a6663c1c770987de41ce0c..78fcd397087fd1fd64aebed7ac3b5c6b2f45c450 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb @@ -1,324 +1,405 @@ { + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "dcgan.ipynb", + "version": "0.3.2", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python2", + "display_name": "Python 2" + }, + "accelerator": "GPU" + }, "cells": [ { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0TD5ZrvEMbhZ" }, + "cell_type": "markdown", "source": [ - "##### Copyright 2018 The TensorFlow Authors.\n", + "**Copyright 2018 The TensorFlow Authors**.\n", "\n", 
"Licensed under the Apache License, Version 2.0 (the \"License\").\n", "\n", - "# DCGAN: An example with tf.keras and eager\n", + "# Generating Handwritten Digits with DCGAN\n", "\n", - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb\"\u003e\n", - " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", - "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + "
\n", + "\n", + " Run in Google Colab \n", + "\n", + "View source on GitHub
" ] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ITZuApL56Mny" }, + "cell_type": "markdown", + "source": [ + "This tutorial demonstrates how to generate images of handwritten digits using a Deep Convolutional Generative Adversarial Network ([DCGAN](https://arxiv.org/pdf/1511.06434.pdf)). The code is written in [tf.keras](https://www.tensorflow.org/programmers_guide/keras) with [eager execution](https://www.tensorflow.org/programmers_guide/eager) enabled. " + ] + }, + { + "metadata": { + "colab_type": "toc", + "id": "x2McrO9bMyLN" + }, + "cell_type": "markdown", + "source": [ + ">[Generating Handwritten Digits with DCGAN](#scrollTo=0TD5ZrvEMbhZ)\n", + "\n", + ">>[What are GANs?](#scrollTo=2MbKJY38Puy9)\n", + "\n", + ">>>[Import TensorFlow and enable eager execution](#scrollTo=e1_Y75QXJS6h)\n", + "\n", + ">>>[Load the dataset](#scrollTo=iYn4MdZnKCey)\n", + "\n", + ">>>[Use tf.data to create batches and shuffle the dataset](#scrollTo=PIGN6ouoQxt3)\n", + "\n", + ">>[Create the models](#scrollTo=THY-sZMiQ4UV)\n", + "\n", + ">>>[The Generator Model](#scrollTo=-tEyxE-GMC48)\n", + "\n", + ">>>[The Discriminator model](#scrollTo=D0IKnaCtg6WE)\n", + "\n", + ">>[Define the loss functions and the optimizer](#scrollTo=0FMYgY_mPfTi)\n", + "\n", + ">>>[Generator loss](#scrollTo=Jd-3GCUEiKtv)\n", + "\n", + ">>>[Discriminator loss](#scrollTo=PKY_iPSPNWoj)\n", + "\n", + ">>[Set up GANs for Training](#scrollTo=Rw1fkAczTQYh)\n", + "\n", + ">>[Train the GANs](#scrollTo=dZrd4CdjR-Fp)\n", + "\n", + ">>[Generated images](#scrollTo=P4M_vIbUi7c0)\n", + "\n", + ">>[Learn more about GANs](#scrollTo=k6qC-SbjK0yW)\n", + "\n" + ] + }, + { + "metadata": { + "colab_type": "text", + "id": "2MbKJY38Puy9" + }, + "cell_type": "markdown", "source": [ - "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). 
To do so, we use Deep Convolutional Generative Adverserial Networks ([DCGAN](https://arxiv.org/pdf/1511.06434.pdf)).\n", + "## What are GANs?\n", + "GANs, or [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661), are a framework for estimating generative models. Two models are trained simultaneously by an adversarial process: a Generator, which is responsible for generating data (say, images), and a Discriminator, which is responsible for estimating the probability that an image was drawn from the training data (the image is real), or was produced by the Generator (the image is fake). During training, the Generator becomes progressively better at generating images, until the Discriminator is no longer able to distinguish real images from fake. \n", "\n", - "This model takes about ~30 seconds per epoch (using tf.contrib.eager.defun to create graph functions) to train on a single Tesla K80 on Colab, as of July 2018.\n", + "![alt text](https://github.com/margaretmz/tensorflow/blob/margaret-dcgan/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png?raw=1)\n", "\n", - "Below is the output generated after training the generator and discriminator models for 150 epochs.\n", + "We will demonstrate this process end-to-end on MNIST. Below is an animation that shows a series of images produced by the Generator as it was trained for 50 epochs. Over time, the generated images become increasingly difficult to distinguish from the training set.\n", + "\n", + "To learn more about GANs, we recommend MIT's [Intro to Deep Learning](http://introtodeeplearning.com/) course, which includes a lecture on Deep Generative Models ([video](https://youtu.be/JVb54xhEw6Y) | [slides](http://introtodeeplearning.com/materials/2018_6S191_Lecture4.pdf)).
Now, let's head to the code!\n", "\n", "![sample output](https://tensorflow.org/images/gan/dcgan.gif)" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "u_2z-B3piVsw" + "id": "u_2z-B3piVsw", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ - "# to generate gifs\n", + "# Install imageio in order to generate an animated gif showing the image generating process\n", "!pip install imageio" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "e1_Y75QXJS6h" }, + "cell_type": "markdown", "source": [ - "## Import TensorFlow and enable eager execution" + "### Import TensorFlow and enable eager execution" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "YfIk2es3hJEd" + "id": "YfIk2es3hJEd", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ - "from __future__ import absolute_import, division, print_function\n", - "\n", - "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", "import tensorflow as tf\n", "tf.enable_eager_execution()\n", "\n", - "import os\n", - "import time\n", - "import numpy as np\n", "import glob\n", + "import imageio\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import os\n", "import PIL\n", - "import imageio\n", + "import time\n", + "\n", "from IPython import display" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "iYn4MdZnKCey" }, + "cell_type": "markdown", "source": [ - "## Load the dataset\n", + "### Load the dataset\n", "\n", - "We are going to use the MNIST dataset to train the generator and the discriminator. The generator will then generate handwritten digits." + "We are going to use the MNIST dataset to train the generator and the discriminator. 
The generator will generate handwritten digits resembling the MNIST data." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "a4fYMGxGhrna" + "id": "a4fYMGxGhrna", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "NFC2ghIdiZYE" + "id": "NFC2ghIdiZYE", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", - "# We are normalizing the images to the range of [-1, 1]\n", - "train_images = (train_images - 127.5) / 127.5" - ] + "train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]" + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "S4PIDhoDLbsZ" + "id": "S4PIDhoDLbsZ", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "BUFFER_SIZE = 60000\n", "BATCH_SIZE = 256" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "PIGN6ouoQxt3" }, + "cell_type": "markdown", "source": [ - "## Use tf.data to create batches and shuffle the dataset" + "### Use tf.data to create batches and shuffle the dataset" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "-yKCCQOoJ7cn" + "id": "-yKCCQOoJ7cn", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": 
"THY-sZMiQ4UV" }, + "cell_type": "markdown", "source": [ - "## Write the generator and discriminator models\n", + "## Create the models\n", "\n", - "* **Generator** \n", - " * It is responsible for **creating convincing images that are good enough to fool the discriminator**.\n", - " * It consists of Conv2DTranspose (Upsampling) layers. We start with a fully connected layer and upsample the image 2 times so as to reach the desired image size (mnist image size) which is (28, 28, 1). \n", - " * We use **leaky relu** activation except for the **last layer** which uses **tanh** activation.\n", - " \n", - "* **Discriminator**\n", - " * **The discriminator is responsible for classifying the fake images from the real images.**\n", - " * In other words, the discriminator is given generated images (from the generator) and the real MNIST images. The job of the discriminator is to classify these images into fake (generated) and real (MNIST images).\n", - " * **Basically the generator should be good enough to fool the discriminator that the generated images are real**." + "We will use tf.keras [Sequential API](https://www.tensorflow.org/guide/keras#sequential_model) to define the generator and discriminator models." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, + "colab_type": "text", + "id": "-tEyxE-GMC48" + }, + "cell_type": "markdown", + "source": [ + "### The Generator Model\n", + "\n", + "The generator is responsible for creating convincing images that are good enough to fool the discriminator. The network architecture for the generator consists of [Conv2DTranspose](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2DTranspose) (Upsampling) layers. We start with a fully connected layer and upsample the image two times in order to reach the desired image size of 28x28x1. We increase the width and height, and reduce the depth as we move through the layers in the network. 
We use [Leaky ReLU](https://www.tensorflow.org/api_docs/python/tf/keras/layers/LeakyReLU) activation for each layer except for the last one where we use a tanh activation." + ] + }, + { + "metadata": { + "id": "6bpTcDqoLWjY", "colab_type": "code", - "id": "VGLbvBEmjK0a" + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ - "class Generator(tf.keras.Model):\n", - " def __init__(self):\n", - " super(Generator, self).__init__()\n", - " self.fc1 = tf.keras.layers.Dense(7*7*64, use_bias=False)\n", - " self.batchnorm1 = tf.keras.layers.BatchNormalization()\n", - " \n", - " self.conv1 = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(1, 1), padding='same', use_bias=False)\n", - " self.batchnorm2 = tf.keras.layers.BatchNormalization()\n", - " \n", - " self.conv2 = tf.keras.layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False)\n", - " self.batchnorm3 = tf.keras.layers.BatchNormalization()\n", + "def make_generator_model():\n", + " model = tf.keras.Sequential()\n", + " model.add(tf.keras.layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))\n", + " model.add(tf.keras.layers.BatchNormalization())\n", + " model.add(tf.keras.layers.LeakyReLU())\n", + " \n", + " model.add(tf.keras.layers.Reshape((7, 7, 256)))\n", + " assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size\n", " \n", - " self.conv3 = tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False)\n", - "\n", - " def call(self, x, training=True):\n", - " x = self.fc1(x)\n", - " x = self.batchnorm1(x, training=training)\n", - " x = tf.nn.relu(x)\n", - "\n", - " x = tf.reshape(x, shape=(-1, 7, 7, 64))\n", + " model.add(tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))\n", + " assert model.output_shape == (None, 7, 7, 128) \n", + " model.add(tf.keras.layers.BatchNormalization())\n", + " model.add(tf.keras.layers.LeakyReLU())\n", "\n", - " x = self.conv1(x)\n", - " x = 
self.batchnorm2(x, training=training)\n", - " x = tf.nn.relu(x)\n", + " model.add(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))\n", + " assert model.output_shape == (None, 14, 14, 64) \n", + " model.add(tf.keras.layers.BatchNormalization())\n", + " model.add(tf.keras.layers.LeakyReLU())\n", "\n", - " x = self.conv2(x)\n", - " x = self.batchnorm3(x, training=training)\n", - " x = tf.nn.relu(x)\n", + " model.add(tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))\n", + " assert model.output_shape == (None, 28, 28, 1)\n", + " \n", + " return model" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "D0IKnaCtg6WE" + }, + "cell_type": "markdown", + "source": [ + "### The Discriminator model\n", "\n", - " x = tf.nn.tanh(self.conv3(x)) \n", - " return x" + "The discriminator is responsible for distinguishing fake images from real images. It's similar to a regular CNN-based image classifier." 
] }, { + "metadata": { + "id": "dw2tPLmk2pEP", + "colab_type": "code", + "colab": {} + }, "cell_type": "code", + "source": [ + "def make_discriminator_model():\n", + " model = tf.keras.Sequential()\n", + " model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))\n", + " model.add(tf.keras.layers.LeakyReLU())\n", + " model.add(tf.keras.layers.Dropout(0.3))\n", + " \n", + " model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))\n", + " model.add(tf.keras.layers.LeakyReLU())\n", + " model.add(tf.keras.layers.Dropout(0.3))\n", + " \n", + " model.add(tf.keras.layers.Flatten())\n", + " model.add(tf.keras.layers.Dense(1))\n", + " \n", + " return model" + ], "execution_count": 0, + "outputs": [] + }, + { "metadata": { - "colab": {}, "colab_type": "code", - "id": "bkOfJxk5j5Hi" - }, - "outputs": [], - "source": [ - "class Discriminator(tf.keras.Model):\n", - " def __init__(self):\n", - " super(Discriminator, self).__init__()\n", - " self.conv1 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')\n", - " self.conv2 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')\n", - " self.dropout = tf.keras.layers.Dropout(0.3)\n", - " self.flatten = tf.keras.layers.Flatten()\n", - " self.fc1 = tf.keras.layers.Dense(1)\n", - "\n", - " def call(self, x, training=True):\n", - " x = tf.nn.leaky_relu(self.conv1(x))\n", - " x = self.dropout(x, training=training)\n", - " x = tf.nn.leaky_relu(self.conv2(x))\n", - " x = self.dropout(x, training=training)\n", - " x = self.flatten(x)\n", - " x = self.fc1(x)\n", - " return x" + "id": "gDkA05NE6QMs", + "colab": {} + }, + "cell_type": "code", + "source": [ + "generator = make_generator_model()\n", + "discriminator = make_discriminator_model()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "0FMYgY_mPfTi" + }, + "cell_type": "markdown", + "source": [ + "## Define the loss functions and the optimizer\n", + "\n", + 
"Let's define the loss functions and the optimizers for the generator and the discriminator.\n" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, - "colab_type": "code", - "id": "gDkA05NE6QMs" + "colab_type": "text", + "id": "Jd-3GCUEiKtv" }, - "outputs": [], + "cell_type": "markdown", "source": [ - "generator = Generator()\n", - "discriminator = Discriminator()" + "### Generator loss\n", + "The generator loss is a sigmoid cross entropy loss of the generated images and an array of ones, since the generator is trying to generate fake images that resemble the real images." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "k1HpMSLImuRi" + "id": "90BIcCKcDMxz", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ - "# Defun gives 10 secs/epoch performance boost\n", - "generator.call = tf.contrib.eager.defun(generator.call)\n", - "discriminator.call = tf.contrib.eager.defun(discriminator.call)" - ] + "def generator_loss(generated_output):\n", + " return tf.losses.sigmoid_cross_entropy(tf.ones_like(generated_output), generated_output)" + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "0FMYgY_mPfTi" + "id": "PKY_iPSPNWoj" }, + "cell_type": "markdown", "source": [ - "## Define the loss functions and the optimizer\n", - "\n", - "* **Discriminator loss**\n", - " * The discriminator loss function takes 2 inputs; **real images, generated images**\n", - " * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones (since these are the real images)**\n", - " * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros (since these are the fake images)**\n", - " * Then the total_loss is the sum of real_loss and the generated_loss\n", - " \n", - "* **Generator loss**\n", - " * It is a sigmoid cross entropy loss of the generated 
images and an **array of ones**\n", - " \n", + "### Discriminator loss\n", "\n", - "* The discriminator and the generator optimizers are different since we will train them separately." + "The discriminator loss function takes two inputs: real images, and generated images. Here is how to calculate the discriminator loss:\n", + "1. Calculate real_loss which is a sigmoid cross entropy loss of the real images and an array of ones (since these are the real images).\n", + "2. Calculate generated_loss which is a sigmoid cross entropy loss of the generated images and an array of zeros (since these are the fake images).\n", + "3. Calculate the total_loss as the sum of real_loss and generated_loss." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "wkMNfBWlT-PV" + "id": "wkMNfBWlT-PV", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "def discriminator_loss(real_output, generated_output):\n", - " # [1,1,...,1] with real output since it is true and we want\n", - " # our generated examples to look like it\n", + " # [1,1,...,1] with real output since it is true and we want our generated examples to look like it\n", " real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real_output), logits=real_output)\n", "\n", " # [0,0,...,0] with generated images since they are fake\n", @@ -327,55 +408,51 @@ " total_loss = real_loss + generated_loss\n", "\n", " return total_loss" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, - "colab_type": "code", - "id": "90BIcCKcDMxz" + "colab_type": "text", + "id": "MgIc7i0th_Iu" }, - "outputs": [], + "cell_type": "markdown", "source": [ - "def generator_loss(generated_output):\n", - " return tf.losses.sigmoid_cross_entropy(tf.ones_like(generated_output), generated_output)" + "The discriminator and the generator optimizers are different since we will train two 
networks separately." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "iWCn_PVdEJZ7" + "id": "iWCn_PVdEJZ7", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ - "discriminator_optimizer = tf.train.AdamOptimizer(1e-4)\n", - "generator_optimizer = tf.train.AdamOptimizer(1e-4)" - ] + "generator_optimizer = tf.train.AdamOptimizer(1e-4)\n", + "discriminator_optimizer = tf.train.AdamOptimizer(1e-4)" + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "mWtinsGDPJlV" }, + "cell_type": "markdown", "source": [ - "## Checkpoints (Object-based saving)" + "**Checkpoints (Object-based saving)**" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "CA1w-7s2POEy" + "id": "CA1w-7s2POEy", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "checkpoint_dir = './training_checkpoints'\n", "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", @@ -383,93 +460,85 @@ " discriminator_optimizer=discriminator_optimizer,\n", " generator=generator,\n", " discriminator=discriminator)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "Rw1fkAczTQYh" + }, + "cell_type": "markdown", + "source": [ + "## Set up GANs for Training\n", + "\n" ] }, { + "metadata": { + "colab_type": "text", + "id": "5QC5BABamh_c" + }, "cell_type": "markdown", + "source": [ + "Now it's time to put together the generator and discriminator to set up the Generative Adversarial Networks, as you see in the diagram at the beginning of the tutorial." 
+ ] + }, + { "metadata": { "colab_type": "text", - "id": "Rw1fkAczTQYh" + "id": "Ff6oN6PZX27n" }, + "cell_type": "markdown", "source": [ - "## Training\n", - "\n", - "* We start by iterating over the dataset\n", - "* The generator is given **noise as an input** which when passed through the generator model will output a image looking like a handwritten digit\n", - "* The discriminator is given the **real MNIST images as well as the generated images (from the generator)**.\n", - "* Next, we calculate the generator and the discriminator loss.\n", - "* Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables (inputs) and apply those to the optimizer.\n", - "\n", - "## Generate Images\n", - "\n", - "* After training, its time to generate some images!\n", - "* We start by creating noise array as an input to the generator\n", - "* The generator will then convert the noise into handwritten images.\n", - "* Last step is to plot the predictions and **voila!**" + "**Define training parameters**" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "NS2GWywBbAWo" + "id": "NS2GWywBbAWo", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ - "EPOCHS = 150\n", + "EPOCHS = 50\n", "noise_dim = 100\n", "num_examples_to_generate = 16\n", "\n", - "# keeping the random vector constant for generation (prediction) so\n", - "# it will be easier to see the improvement of the gan.\n", + "# We'll re-use this random vector used to seed the generator so\n", + "# it will be easier to see the improvement over time.\n", "random_vector_for_generation = tf.random_normal([num_examples_to_generate,\n", " noise_dim])" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, - "colab_type": "code", - "id": "RmdVsmvhPxyy" + "colab_type": "text", + "id": "jylSonrqSWfi" }, - "outputs": [], + "cell_type": 
"markdown", "source": [ - "def generate_and_save_images(model, epoch, test_input):\n", - " # make sure the training parameter is set to False because we\n", - " # don't want to train the batchnorm layer when doing inference.\n", - " predictions = model(test_input, training=False)\n", + "**Define training method**\n", "\n", - " fig = plt.figure(figsize=(4,4))\n", - " \n", - " for i in range(predictions.shape[0]):\n", - " plt.subplot(4, 4, i+1)\n", - " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", - " plt.axis('off')\n", - " \n", - " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", - " plt.show()" + "We start by iterating over the dataset. The generator is given a random vector as an input which is processed to output an image looking like a handwritten digit. The discriminator is then shown the real MNIST images as well as the generated images.\n", + "\n", + "Next, we calculate the generator and the discriminator loss. Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables." 
] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, + "id": "3t5ibNo05jCB", "colab_type": "code", - "id": "2M7LmLtGEMQJ" + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ - "def train(dataset, epochs, noise_dim): \n", - " for epoch in range(epochs):\n", - " start = time.time()\n", - " \n", - " for images in dataset:\n", - " # generating noise from a uniform distribution\n", + "def train_step(images):\n", + " # generating noise from a normal distribution\n", " noise = tf.random_normal([BATCH_SIZE, noise_dim])\n", " \n", " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", @@ -477,7 +546,7 @@ " \n", " real_output = discriminator(images, training=True)\n", " generated_output = discriminator(generated_images, training=True)\n", - " \n", + " \n", " gen_loss = generator_loss(generated_output)\n", " disc_loss = discriminator_loss(real_output, generated_output)\n", " \n", @@ -485,12 +554,54 @@ " gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)\n", " \n", " generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))\n", - " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))\n", + " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "6TSZgwc2BUQ-" + }, + "cell_type": "markdown", + "source": [ "\n", - " \n", - " if epoch % 1 == 0:\n", - " display.clear_output(wait=True)\n", - " generate_and_save_images(generator,\n", + "This model takes about ~30 seconds per epoch to train on a single Tesla K80 on Colab, as of October 2018. \n", + "\n", + "Eager execution can be slower than executing the equivalent graph as it can't benefit from whole-program optimizations on the graph, and also incurs overheads of interpreting Python code. 
By using [tf.contrib.eager.defun](https://www.tensorflow.org/api_docs/python/tf/contrib/eager/defun) to create graph functions, we get a ~20 secs/epoch performance boost (from ~50 secs/epoch down to ~30 secs/epoch). This way we get the best of both eager execution (easier for debugging) and graph mode (better performance)." + ] + }, + { + "metadata": { + "id": "Iwya07_j5p2A", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "train_step = tf.contrib.eager.defun(train_step)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "code", + "id": "2M7LmLtGEMQJ", + "colab": {} + }, + "cell_type": "code", + "source": [ + "def train(dataset, epochs): \n", + " for epoch in range(epochs):\n", + " start = time.time()\n", + " \n", + " for images in dataset:\n", + " train_step(images)\n", + "\n", + " display.clear_output(wait=True)\n", + " generate_and_save_images(generator,\n", " epoch + 1,\n", " random_vector_for_generation)\n", " \n", @@ -505,111 +616,167 @@ " generate_and_save_images(generator,\n", " epochs,\n", " random_vector_for_generation)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "2aFF7Hk3XdeW" + }, + "cell_type": "markdown", + "source": [ + "**Generate and save images**\n", + "\n" ] }, { + "metadata": { + "colab_type": "code", + "id": "RmdVsmvhPxyy", + "colab": {} + }, "cell_type": "code", + "source": [ + "def generate_and_save_images(model, epoch, test_input):\n", + " # make sure the training parameter is set to False because we\n", + " # don't want to train the batchnorm layer when doing inference.\n", + " predictions = model(test_input, training=False)\n", + "\n", + " fig = plt.figure(figsize=(4,4))\n", + " \n", + " for i in range(predictions.shape[0]):\n", + " plt.subplot(4, 4, i+1)\n", + " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", + " plt.axis('off')\n", + " \n", + " 
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", + " plt.show()" + ], "execution_count": 0, + "outputs": [] + }, + { "metadata": { - "colab": {}, - "colab_type": "code", - "id": "Ly3UN0SLLY2l" + "colab_type": "text", + "id": "dZrd4CdjR-Fp" }, - "outputs": [], + "cell_type": "markdown", "source": [ - "train(train_dataset, EPOCHS, noise_dim)" + "## Train the GANs\n", + "We will call the train() method defined above to train the generator and discriminator simultaneously. Note, training GANs can be tricky. It's important that the generator and discriminator do not overpower each other (e.g., that they train at a similar rate).\n", + "\n", + "At the beginning of the training, the generated images look like random noise. As training progresses, you can see the generated digits look increasingly real. After 50 epochs, they look very much like the MNIST digits." ] }, { - "cell_type": "markdown", + "metadata": { + "colab_type": "code", + "id": "Ly3UN0SLLY2l", + "colab": {} + }, + "cell_type": "code", + "source": [ + "%%time\n", + "train(train_dataset, EPOCHS)" + ], + "execution_count": 0, + "outputs": [] + }, + { "metadata": { "colab_type": "text", "id": "rfM4YcPVPkNO" }, + "cell_type": "markdown", "source": [ - "## Restore the latest checkpoint" + "**Restore the latest checkpoint**" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "XhXsd0srPo8c" + "id": "XhXsd0srPo8c", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# restoring the latest checkpoint in checkpoint_dir\n", "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "P4M_vIbUi7c0" }, + "cell_type": "markdown", "source": [ - "## Display an image using the epoch number" + "## Generated images \n" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, - 
"colab_type": "code", - "id": "WfO5wCdclHGL" + "colab_type": "text", + "id": "mLskt7EfXAjr" }, - "outputs": [], + "cell_type": "markdown", "source": [ - "def display_image(epoch_no):\n", - " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" + "\n", + "After training, its time to generate some images! \n", + "The last step is to plot the generated images and voila!\n" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "5x3q9_Oe5q0A" + "id": "WfO5wCdclHGL", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ - "display_image(EPOCHS)" - ] + "# Display a single image using the epoch number\n", + "def display_image(epoch_no):\n", + " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "NywiH3nL8guF" + "colab_type": "code", + "id": "5x3q9_Oe5q0A", + "colab": {} }, + "cell_type": "code", "source": [ - "## Generate a GIF of all the saved images." - ] + "display_image(EPOCHS)" + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "xmO0Dmu2WICn" + "id": "NywiH3nL8guF" }, + "cell_type": "markdown", "source": [ - "\u003c!-- TODO(markdaoust): Remove the hack when Ipython version is updated --\u003e\n" + "**Generate a GIF of all the saved images**\n", + "\n", + "We will use imageio to create an animated gif using all the images saved during training." 
] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "IGKQgENQ8lEI" + "id": "IGKQgENQ8lEI", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "with imageio.get_writer('dcgan.gif', mode='I') as writer:\n", " filenames = glob.glob('image*.png')\n", @@ -617,7 +784,7 @@ " last = -1\n", " for i,filename in enumerate(filenames):\n", " frame = 2*(i**0.5)\n", - " if round(frame) \u003e round(last):\n", + " if round(frame) > round(last):\n", " last = frame\n", " else:\n", " continue\n", @@ -628,67 +795,84 @@ " \n", "# this is a hack to display the gif inside the notebook\n", "os.system('cp dcgan.gif dcgan.gif.png')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "cGhC3-fMWSwl" + }, + "cell_type": "markdown", + "source": [ + "Display the animated gif with all the images generated during the training of GANs." 
] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "4UJjSnIMOzOJ" + "id": "4UJjSnIMOzOJ", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "#from google.colab import files\n", "#files.download('dcgan.gif')" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "dcgan.ipynb", - "private_outputs": true, - "provenance": [ - { - "file_id": "1eb0NOTQapkYs3X0v-zL1x5_LFKgDISnp", - "timestamp": 1527173385672 - } ], - "toc_visible": true, - "version": "0.3.2" + "execution_count": 0, + "outputs": [] }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" + { + "metadata": { + "colab_type": "text", + "id": "k6qC-SbjK0yW" + }, + "cell_type": "markdown", + "source": [ + "## Learn more about GANs\n" + ] + }, + { + "metadata": { + "colab_type": "text", + "id": "xjjkT9KAK6H7" + }, + "cell_type": "markdown", + "source": [ + "We hope this tutorial was helpful! As a next step, you might like to experiment with a different dataset, for example the Large-scale Celeb Faces Attributes (CelebA) dataset [available on Kaggle](https://www.kaggle.com/jessicali9530/celeba-dataset/home).\n", + "\n", + "To learn more about GANs:\n", + "\n", + "* Check out MIT's lecture (linked above), or [this](http://cs231n.stanford.edu/slides/2018/cs231n_2018_lecture12.pdf) lecture from Stanford's CS231n. 
\n", + "\n", + "* We also recommend the [CVPR 2018 Tutorial on GANs](https://sites.google.com/view/cvpr2018tutorialongans/), and the [NIPS 2016 Tutorial: Generative Adversarial Networks](https://arxiv.org/abs/1701.00160).\n" + ] } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png b/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png new file mode 100644 index 0000000000000000000000000000000000000000..b715bd83ef117641c6429e0ac173dbe9b8d5fd88 Binary files /dev/null and b/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png differ diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index 3acecd283cda83992bab0c37cf0b8037ed2cf27a..12c5eff2b4aa901bdab52bf545e95b1e4dce7468 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -1,1184 +1,1174 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "K2s1A9eLRPEj" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Cffg2i257iMS" + }, + "source": [ + "# Image Captioning with Attention\n", + "\n", + "
\n", + "\n", + " Run in Google Colab \n", + "\n", + "View source on GitHub
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "QASbY_HGo4Lq" + }, + "source": [ + "Image captioning is the task of generating a caption for an image. Given an image like this:\n", + "\n", + "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", + "\n", + "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", + "\n", + "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", + "\n", + "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", + "\n", + "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n", + "\n", + "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n", + "\n", + "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n", + "\n", + "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n", + "\n", + "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. 
We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { "colab": { - "name": "image_captioning_with_attention.ipynb", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [ - { - "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", - "timestamp": 1530222436922 - } - ], - "private_outputs": true, - "collapsed_sections": [], - "toc_visible": true + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "accelerator": "GPU" + "colab_type": "code", + "id": "U8l4RJ0XRPEm" + }, + "outputs": [], + "source": [ + "# Import TensorFlow and enable eager execution\n", + "# This code requires TensorFlow version >=1.9\n", + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "\n", + "# We'll generate plots of attention in order to see which parts of an image\n", + "# our model focuses on during captioning\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Scikit-learn includes many helpful utilities\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.utils import shuffle\n", + "\n", + "import re\n", + "import numpy as np\n", + "import os\n", + "import time\n", + "import json\n", + "from glob import glob\n", + "from PIL import Image\n", + "import pickle" + ] }, - "cells": [ - { - "metadata": { - "id": "K2s1A9eLRPEj", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "##### Copyright 2018 The TensorFlow Authors.\n", - "\n", - "Licensed under the Apache License, Version 2.0 (the \"License\").\n" - ] - }, - { - "metadata": { - "id": "Cffg2i257iMS", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Image Captioning with Attention\n", - "\n", - "
\n", - "\n", - " Run in Google Colab \n", - "\n", - "View source on GitHub
" - ] - }, - { - "metadata": { - "id": "QASbY_HGo4Lq", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "Image captioning is the task of generating a caption for an image. Given an image like this:\n", - "\n", - "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n", - "\n", - "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n", - "\n", - "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n", - "\n", - "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n", - "\n", - "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n", - "\n", - "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n", - "\n", - "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n", - "\n", - "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n", - "\n", - "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. 
We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n" - ] - }, - { - "metadata": { - "id": "U8l4RJ0XRPEm", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Import TensorFlow and enable eager execution\n", - "# This code requires TensorFlow version >=1.9\n", - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "# We'll generate plots of attention in order to see which parts of an image\n", - "# our model focuses on during captioning\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# Scikit-learn includes many helpful utilities\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.utils import shuffle\n", - "\n", - "import re\n", - "import numpy as np\n", - "import os\n", - "import time\n", - "import json\n", - "from glob import glob\n", - "from PIL import Image\n", - "import pickle" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "b6qbGw8MRPE5", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Download and prepare the MS-COCO dataset\n", - "\n", - "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n", - "\n", - "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." 
- ] - }, - { - "metadata": { - "id": "krQuPYTtRPE7", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "annotation_zip = tf.keras.utils.get_file('captions.zip', \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", - " extract = True)\n", - "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", - "\n", - "name_of_zip = 'train2014.zip'\n", - "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n", - " image_zip = tf.keras.utils.get_file(name_of_zip, \n", - " cache_subdir=os.path.abspath('.'),\n", - " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", - " extract = True)\n", - " PATH = os.path.dirname(image_zip)+'/train2014/'\n", - "else:\n", - " PATH = os.path.abspath('.')+'/train2014/'" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "aANEzb5WwSzg", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Optionally, limit the size of the training set for faster training\n", - "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data." 
- ] - }, - { - "metadata": { - "id": "4G3b8x8_RPFD", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# read the json file\n", - "with open(annotation_file, 'r') as f:\n", - " annotations = json.load(f)\n", - "\n", - "# storing the captions and the image name in vectors\n", - "all_captions = []\n", - "all_img_name_vector = []\n", - "\n", - "for annot in annotations['annotations']:\n", - " caption = ' ' + annot['caption'] + ' '\n", - " image_id = annot['image_id']\n", - " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", - " \n", - " all_img_name_vector.append(full_coco_image_path)\n", - " all_captions.append(caption)\n", - "\n", - "# shuffling the captions and image_names together\n", - "# setting a random state\n", - "train_captions, img_name_vector = shuffle(all_captions,\n", - " all_img_name_vector,\n", - " random_state=1)\n", - "\n", - "# selecting the first 30000 captions from the shuffled set\n", - "num_examples = 30000\n", - "train_captions = train_captions[:num_examples]\n", - "img_name_vector = img_name_vector[:num_examples]" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "mPBMgK34RPFL", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "len(train_captions), len(all_captions)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "8cSW4u-ORPFQ", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Preprocess the images using InceptionV3\n", - "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. 
\n", - "\n", - "First, we will need to convert the images into the format inceptionV3 expects by:\n", - "* Resizing the image to (299, 299)\n", - "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)." - ] - }, - { - "metadata": { - "id": "zXR0217aRPFR", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def load_image(image_path):\n", - " img = tf.read_file(image_path)\n", - " img = tf.image.decode_jpeg(img, channels=3)\n", - " img = tf.image.resize_images(img, (299, 299))\n", - " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", - " return img, image_path" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "MDvIu4sXRPFV", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", - "\n", - "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n", - "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n", - "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n", - "* We avoid doing this during training so it does not become a bottleneck. \n", - "* After all the images are passed through the network, we pickle the dictionary and save it to disk." 
- ] - }, - { - "metadata": { - "id": "RD3vW4SsRPFW", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "image_model = tf.keras.applications.InceptionV3(include_top=False, \n", - " weights='imagenet')\n", - "new_input = image_model.input\n", - "hidden_layer = image_model.layers[-1].output\n", - "\n", - "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "rERqlR3WRPGO", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Caching the features extracted from InceptionV3\n", - "\n", - "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n", - "\n", - "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n", - "\n", - "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n", - "\n", - "```for img, path in image_dataset:``` \n", - "\n", - "to:\n", - "\n", - "```for img, path in tqdm(image_dataset):```." 
- ] - }, - { - "metadata": { - "id": "Dx_fvbVgRPGQ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# getting the unique images\n", - "encode_train = sorted(set(img_name_vector))\n", - "\n", - "# feel free to change the batch_size according to your system configuration\n", - "image_dataset = tf.data.Dataset.from_tensor_slices(\n", - " encode_train).map(load_image).batch(16)\n", - "\n", - "for img, path in image_dataset:\n", - " batch_features = image_features_extract_model(img)\n", - " batch_features = tf.reshape(batch_features, \n", - " (batch_features.shape[0], -1, batch_features.shape[3]))\n", - "\n", - " for bf, p in zip(batch_features, path):\n", - " path_of_feature = p.numpy().decode(\"utf-8\")\n", - " np.save(path_of_feature, bf.numpy())" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "nyqH3zFwRPFi", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Preprocess and tokenize the captions\n", - "\n", - "* First, we'll tokenize the captions (e.g., by splitting on spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n", - "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n", - "* Finally, we create a word --> index mapping and vice-versa.\n", - "* We will then pad all sequences to the be same length as the longest one. 
" - ] - }, - { - "metadata": { - "id": "HZfK8RhQRPFj", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# This will find the maximum length of any caption in our dataset\n", - "def calc_max_length(tensor):\n", - " return max(len(t) for t in tensor)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "oJGE34aiRPFo", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# The steps above is a general process of dealing with text processing\n", - "\n", - "# choosing the top 5000 words from the vocabulary\n", - "top_k = 5000\n", - "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n", - " oov_token=\"\", \n", - " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", - "tokenizer.fit_on_texts(train_captions)\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ], - "execution_count": 0, - "outputs": [] + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "b6qbGw8MRPE5" + }, + "source": [ + "## Download and prepare the MS-COCO dataset\n", + "\n", + "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n", + "\n", + "**Caution: large download ahead**. We'll use the training set, it's a 13GB file." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "8Q44tNQVRPFt", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "tokenizer.word_index = {key:value for key, value in tokenizer.word_index.items() if value <= top_k}\n", - "# putting token in the word2idx dictionary\n", - "tokenizer.word_index[tokenizer.oov_token] = top_k + 1\n", - "tokenizer.word_index[''] = 0" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "krQuPYTtRPE7" + }, + "outputs": [], + "source": [ + "annotation_zip = tf.keras.utils.get_file('captions.zip', \n", + " cache_subdir=os.path.abspath('.'),\n", + " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n", + " extract = True)\n", + "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n", + "\n", + "name_of_zip = 'train2014.zip'\n", + "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n", + " image_zip = tf.keras.utils.get_file(name_of_zip, \n", + " cache_subdir=os.path.abspath('.'),\n", + " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n", + " extract = True)\n", + " PATH = os.path.dirname(image_zip)+'/train2014/'\n", + "else:\n", + " PATH = os.path.abspath('.')+'/train2014/'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "aANEzb5WwSzg" + }, + "source": [ + "## Optionally, limit the size of the training set for faster training\n", + "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "0fpJb5ojRPFv", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# creating the tokenized vectors\n", - "train_seqs = tokenizer.texts_to_sequences(train_captions)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "4G3b8x8_RPFD" + }, + "outputs": [], + "source": [ + "# read the json file\n", + "with open(annotation_file, 'r') as f:\n", + " annotations = json.load(f)\n", + "\n", + "# storing the captions and the image name in vectors\n", + "all_captions = []\n", + "all_img_name_vector = []\n", + "\n", + "for annot in annotations['annotations']:\n", + " caption = ' ' + annot['caption'] + ' '\n", + " image_id = annot['image_id']\n", + " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n", + " \n", + " all_img_name_vector.append(full_coco_image_path)\n", + " all_captions.append(caption)\n", + "\n", + "# shuffling the captions and image_names together\n", + "# setting a random state\n", + "train_captions, img_name_vector = shuffle(all_captions,\n", + " all_img_name_vector,\n", + " random_state=1)\n", + "\n", + "# selecting the first 30000 captions from the shuffled set\n", + "num_examples = 30000\n", + "train_captions = train_captions[:num_examples]\n", + "img_name_vector = img_name_vector[:num_examples]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "olQArbgbRPF1", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# creating a reverse mapping (index -> word)\n", - "index_word = {value:key for key, value in 
tokenizer.word_index.items()}" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "mPBMgK34RPFL" + }, + "outputs": [], + "source": [ + "len(train_captions), len(all_captions)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "8cSW4u-ORPFQ" + }, + "source": [ + "## Preprocess the images using InceptionV3\n", + "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. \n", + "\n", + "First, we will need to convert the images into the format inceptionV3 expects by:\n", + "* Resizing the image to (299, 299)\n", + "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "AidglIZVRPF4", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# padding each vector to the max_length of the captions\n", - "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n", - "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "zXR0217aRPFR" + }, + "outputs": [], + "source": [ + "def load_image(image_path):\n", + " img = tf.read_file(image_path)\n", + " img = tf.image.decode_jpeg(img, channels=3)\n", + " img = tf.image.resize_images(img, (299, 299))\n", + " img = tf.keras.applications.inception_v3.preprocess_input(img)\n", + " return img, image_path" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": 
"MDvIu4sXRPFV" + }, + "source": [ + "## Initialize InceptionV3 and load the pretrained Imagenet weights\n", + "\n", + "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n", + "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n", + "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n", + "* We avoid doing this during training so it does not become a bottleneck. \n", + "* After all the images are passed through the network, we pickle the dictionary and save it to disk." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "gL0wkttkRPGA", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# calculating the max_length \n", - "# used to store the attention weights\n", - "max_length = calc_max_length(train_seqs)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "RD3vW4SsRPFW" + }, + "outputs": [], + "source": [ + "image_model = tf.keras.applications.InceptionV3(include_top=False, \n", + " weights='imagenet')\n", + "new_input = image_model.input\n", + "hidden_layer = image_model.layers[-1].output\n", + "\n", + "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "rERqlR3WRPGO" + }, + "source": [ + "## Caching the features extracted from InceptionV3\n", + "\n", + "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. 
At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n", + "\n", + "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n", + "\n", + "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n", + "\n", + "```for img, path in image_dataset:``` \n", + "\n", + "to:\n", + "\n", + "```for img, path in tqdm(image_dataset):```." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "M3CD75nDpvTI", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Split the data into training and testing" - ] + "colab_type": "code", + "id": "Dx_fvbVgRPGQ" + }, + "outputs": [], + "source": [ + "# getting the unique images\n", + "encode_train = sorted(set(img_name_vector))\n", + "\n", + "# feel free to change the batch_size according to your system configuration\n", + "image_dataset = tf.data.Dataset.from_tensor_slices(\n", + " encode_train).map(load_image).batch(16)\n", + "\n", + "for img, path in image_dataset:\n", + " batch_features = image_features_extract_model(img)\n", + " batch_features = tf.reshape(batch_features, \n", + " (batch_features.shape[0], -1, batch_features.shape[3]))\n", + "\n", + " for bf, p in zip(batch_features, path):\n", + " path_of_feature = p.numpy().decode(\"utf-8\")\n", + " np.save(path_of_feature, bf.numpy())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "nyqH3zFwRPFi" + }, + "source": [ + "## Preprocess and tokenize the captions\n", + "\n", + "* First, we'll tokenize the captions (e.g., by splitting on 
spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n", + "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n", + "* Finally, we create a word --> index mapping and vice-versa.\n", + "* We will then pad all sequences to be the same length as the longest one. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "iS7DDMszRPGF", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Create training and validation sets using 80-20 split\n", - "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n", - " cap_vector, \n", - " test_size=0.2, \n", - " random_state=0)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "HZfK8RhQRPFj" + }, + "outputs": [], + "source": [ + "# This will find the maximum length of any caption in our dataset\n", + "def calc_max_length(tensor):\n", + " return max(len(t) for t in tensor)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "XmViPkRFRPGH", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "oJGE34aiRPFo" + }, + "outputs": [], + "source": [ + "# The steps above is a general process of dealing with text processing\n", + "\n", + "# choosing the top 5000 words from the vocabulary\n", + "top_k = 5000\n", + "tokenizer = 
tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n", + " oov_token=\"\", \n", + " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n", + "tokenizer.fit_on_texts(train_captions)\n", + "train_seqs = tokenizer.texts_to_sequences(train_captions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "uEWM9xrYcg45", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Our images and captions are ready! Next, let's create a tf.data dataset to use for training our model.\n", - "\n" - ] + "colab_type": "code", + "id": "8Q44tNQVRPFt" + }, + "outputs": [], + "source": [ + "tokenizer.word_index[''] = 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Q3TnZ1ToRPGV", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# feel free to change these parameters according to your system's configuration\n", - "\n", - "BATCH_SIZE = 64\n", - "BUFFER_SIZE = 1000\n", - "embedding_dim = 256\n", - "units = 512\n", - "vocab_size = len(tokenizer.word_index)\n", - "# shape of the vector extracted from InceptionV3 is (64, 2048)\n", - "# these two variables represent that\n", - "features_shape = 2048\n", - "attention_features_shape = 64" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "0fpJb5ojRPFv" + }, + "outputs": [], + "source": [ + "# creating the tokenized vectors\n", + "train_seqs = tokenizer.texts_to_sequences(train_captions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "SmZS2N0bXG3T", - "colab_type": "code", - "colab": { - "autoexec": { - 
"startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# loading the numpy files \n", - "def map_func(img_name, cap):\n", - " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", - " return img_tensor, cap" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "AidglIZVRPF4" + }, + "outputs": [], + "source": [ + "# padding each vector to the max_length of the captions\n", + "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n", + "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "FDF_Nm3tRPGZ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", - "\n", - "# using map to load the numpy files in parallel\n", - "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n", - "# https://www.tensorflow.org/api_docs/python/tf/py_func\n", - "dataset = dataset.map(lambda item1, item2: tf.py_func(\n", - " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n", - "\n", - "# shuffling and batching\n", - "dataset = dataset.shuffle(BUFFER_SIZE)\n", - "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n", - "dataset = dataset.batch(BATCH_SIZE)\n", - "dataset = dataset.prefetch(1)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "gL0wkttkRPGA" + }, + "outputs": [], + "source": [ + "# calculating the max_length \n", + "# used to store the attention weights\n", + "max_length = calc_max_length(train_seqs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", 
+ "id": "M3CD75nDpvTI" + }, + "source": [ + "## Split the data into training and testing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "nrvoDphgRPGd", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Model\n", - "\n", - "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", - "\n", - "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", - "\n", - "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n", - "* We squash that to a shape of (64, 2048).\n", - "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n", - "* The RNN(here GRU) attends over the image to predict the next word." 
- ] + "colab_type": "code", + "id": "iS7DDMszRPGF" + }, + "outputs": [], + "source": [ + "# Create training and validation sets using 80-20 split\n", + "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n", + " cap_vector, \n", + " test_size=0.2, \n", + " random_state=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "AAppCGLKRPGd", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def gru(units):\n", - " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n", - " # significant speedup).\n", - " if tf.test.is_gpu_available():\n", - " return tf.keras.layers.CuDNNGRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_initializer='glorot_uniform')\n", - " else:\n", - " return tf.keras.layers.GRU(units, \n", - " return_sequences=True, \n", - " return_state=True, \n", - " recurrent_activation='sigmoid', \n", - " recurrent_initializer='glorot_uniform')" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "XmViPkRFRPGH" + }, + "outputs": [], + "source": [ + "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "uEWM9xrYcg45" + }, + "source": [ + "## Our images and captions are ready! 
Next, let's create a tf.data dataset to use for training our model.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "ja2LFTMSdeV3", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class BahdanauAttention(tf.keras.Model):\n", - " def __init__(self, units):\n", - " super(BahdanauAttention, self).__init__()\n", - " self.W1 = tf.keras.layers.Dense(units)\n", - " self.W2 = tf.keras.layers.Dense(units)\n", - " self.V = tf.keras.layers.Dense(1)\n", - " \n", - " def call(self, features, hidden):\n", - " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", - " \n", - " # hidden shape == (batch_size, hidden_size)\n", - " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", - " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", - " \n", - " # score shape == (batch_size, 64, hidden_size)\n", - " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", - " \n", - " # attention_weights shape == (batch_size, 64, 1)\n", - " # we get 1 at the last axis because we are applying score to self.V\n", - " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", - " \n", - " # context_vector shape after sum == (batch_size, hidden_size)\n", - " context_vector = attention_weights * features\n", - " context_vector = tf.reduce_sum(context_vector, axis=1)\n", - " \n", - " return context_vector, attention_weights" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "Q3TnZ1ToRPGV" + }, + "outputs": [], + "source": [ + "# feel free to change these parameters according to your system's configuration\n", + "\n", + "BATCH_SIZE = 64\n", + "BUFFER_SIZE = 1000\n", + "embedding_dim = 256\n", + "units = 512\n", + "vocab_size = len(tokenizer.word_index)\n", + "# shape of the 
vector extracted from InceptionV3 is (64, 2048)\n", + "# these two variables represent that\n", + "features_shape = 2048\n", + "attention_features_shape = 64" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "AZ7R1RxHRPGf", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class CNN_Encoder(tf.keras.Model):\n", - " # Since we have already extracted the features and dumped it using pickle\n", - " # This encoder passes those features through a Fully connected layer\n", - " def __init__(self, embedding_dim):\n", - " super(CNN_Encoder, self).__init__()\n", - " # shape after fc == (batch_size, 64, embedding_dim)\n", - " self.fc = tf.keras.layers.Dense(embedding_dim)\n", - " \n", - " def call(self, x):\n", - " x = self.fc(x)\n", - " x = tf.nn.relu(x)\n", - " return x" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "SmZS2N0bXG3T" + }, + "outputs": [], + "source": [ + "# loading the numpy files \n", + "def map_func(img_name, cap):\n", + " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n", + " return img_tensor, cap" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "V9UbGQmERPGi", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class RNN_Decoder(tf.keras.Model):\n", - " def __init__(self, embedding_dim, units, vocab_size):\n", - " super(RNN_Decoder, self).__init__()\n", - " self.units = units\n", - "\n", - " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", - " self.gru = gru(self.units)\n", - " self.fc1 = tf.keras.layers.Dense(self.units)\n", - " self.fc2 
= tf.keras.layers.Dense(vocab_size)\n", - " \n", - " self.attention = BahdanauAttention(self.units)\n", - " \n", - " def call(self, x, features, hidden):\n", - " # defining attention as a separate model\n", - " context_vector, attention_weights = self.attention(features, hidden)\n", - " \n", - " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", - " x = self.embedding(x)\n", - " \n", - " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", - " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", - " \n", - " # passing the concatenated vector to the GRU\n", - " output, state = self.gru(x)\n", - " \n", - " # shape == (batch_size, max_length, hidden_size)\n", - " x = self.fc1(output)\n", - " \n", - " # x shape == (batch_size * max_length, hidden_size)\n", - " x = tf.reshape(x, (-1, x.shape[2]))\n", - " \n", - " # output shape == (batch_size * max_length, vocab)\n", - " x = self.fc2(x)\n", - "\n", - " return x, state, attention_weights\n", - "\n", - " def reset_state(self, batch_size):\n", - " return tf.zeros((batch_size, self.units))" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "FDF_Nm3tRPGZ" + }, + "outputs": [], + "source": [ + "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n", + "\n", + "# using map to load the numpy files in parallel\n", + "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n", + "# https://www.tensorflow.org/api_docs/python/tf/py_func\n", + "dataset = dataset.map(lambda item1, item2: tf.py_func(\n", + " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n", + "\n", + "# shuffling and batching\n", + "dataset = dataset.shuffle(BUFFER_SIZE)\n", + "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n", + "dataset = dataset.batch(BATCH_SIZE)\n", + "dataset = dataset.prefetch(1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + 
"colab_type": "text", + "id": "nrvoDphgRPGd" + }, + "source": [ + "## Model\n", + "\n", + "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n", + "\n", + "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n", + "\n", + "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n", + "* We squash that to a shape of (64, 2048).\n", + "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n", + "* The RNN(here GRU) attends over the image to predict the next word." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Qs_Sr03wRPGk", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "encoder = CNN_Encoder(embedding_dim)\n", - "decoder = RNN_Decoder(embedding_dim, units, vocab_size)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "AAppCGLKRPGd" + }, + "outputs": [], + "source": [ + "def gru(units):\n", + " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n", + " # significant speedup).\n", + " if tf.test.is_gpu_available():\n", + " return tf.keras.layers.CuDNNGRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_initializer='glorot_uniform')\n", + " else:\n", + " return tf.keras.layers.GRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_activation='sigmoid', \n", + " recurrent_initializer='glorot_uniform')" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "-bYN7xA0RPGl", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "optimizer = tf.train.AdamOptimizer()\n", - "\n", - "# We are masking the loss calculated for padding\n", - "def loss_function(real, pred):\n", - " mask = 1 - np.equal(real, 0)\n", - " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", - " return tf.reduce_mean(loss_)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "ja2LFTMSdeV3" + }, + "outputs": [], + "source": [ + "class BahdanauAttention(tf.keras.Model):\n", + " def __init__(self, units):\n", + " super(BahdanauAttention, self).__init__()\n", + " self.W1 = tf.keras.layers.Dense(units)\n", + " self.W2 = tf.keras.layers.Dense(units)\n", + " self.V = tf.keras.layers.Dense(1)\n", + " \n", + " def call(self, features, hidden):\n", + " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n", + " \n", + " # hidden shape == (batch_size, hidden_size)\n", + " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n", + " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", + " \n", + " # score shape == (batch_size, 64, hidden_size)\n", + " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n", + " \n", + " # attention_weights shape == (batch_size, 64, 1)\n", + " # we get 1 at the last axis because we are applying score to self.V\n", + " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", + " \n", + " # context_vector shape after sum == (batch_size, hidden_size)\n", + " context_vector = attention_weights * features\n", + " context_vector = tf.reduce_sum(context_vector, axis=1)\n", + " \n", + " return context_vector, attention_weights" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "PHod7t72RPGn", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Training\n", - "\n", - "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", - "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n", - "* The decoder returns the predictions and the decoder hidden state.\n", - "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", - "* Use teacher forcing to decide the next input to the decoder.\n", - "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", - "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" - ] + "colab_type": "code", + "id": "AZ7R1RxHRPGf" + }, + "outputs": [], + "source": [ + "class CNN_Encoder(tf.keras.Model):\n", + " # Since we have already extracted the features and dumped it using pickle\n", + " # This encoder passes those features through a Fully connected layer\n", + " def __init__(self, embedding_dim):\n", + " super(CNN_Encoder, self).__init__()\n", + " # shape after fc == (batch_size, 64, embedding_dim)\n", + " self.fc = tf.keras.layers.Dense(embedding_dim)\n", + " \n", + " def call(self, x):\n", + " x = self.fc(x)\n", + " x = tf.nn.relu(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Vt4WZ5mhJE-E", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# adding this in a separate cell because if you run the training cell 
\n", - "# many times, the loss_plot array will be reset\n", - "loss_plot = []" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "V9UbGQmERPGi" + }, + "outputs": [], + "source": [ + "class RNN_Decoder(tf.keras.Model):\n", + " def __init__(self, embedding_dim, units, vocab_size):\n", + " super(RNN_Decoder, self).__init__()\n", + " self.units = units\n", + "\n", + " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", + " self.gru = gru(self.units)\n", + " self.fc1 = tf.keras.layers.Dense(self.units)\n", + " self.fc2 = tf.keras.layers.Dense(vocab_size)\n", + " \n", + " self.attention = BahdanauAttention(self.units)\n", + " \n", + " def call(self, x, features, hidden):\n", + " # defining attention as a separate model\n", + " context_vector, attention_weights = self.attention(features, hidden)\n", + " \n", + " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", + " x = self.embedding(x)\n", + " \n", + " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", + " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", + " \n", + " # passing the concatenated vector to the GRU\n", + " output, state = self.gru(x)\n", + " \n", + " # shape == (batch_size, max_length, hidden_size)\n", + " x = self.fc1(output)\n", + " \n", + " # x shape == (batch_size * max_length, hidden_size)\n", + " x = tf.reshape(x, (-1, x.shape[2]))\n", + " \n", + " # output shape == (batch_size * max_length, vocab)\n", + " x = self.fc2(x)\n", + "\n", + " return x, state, attention_weights\n", + "\n", + " def reset_state(self, batch_size):\n", + " return tf.zeros((batch_size, self.units))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "UlA4VIQpRPGo", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - 
"cell_type": "code", - "source": [ - "EPOCHS = 20\n", - "\n", - "for epoch in range(EPOCHS):\n", - " start = time.time()\n", - " total_loss = 0\n", - " \n", - " for (batch, (img_tensor, target)) in enumerate(dataset):\n", - " loss = 0\n", - " \n", - " # initializing the hidden state for each batch\n", - " # because the captions are not related from image to image\n", - " hidden = decoder.reset_state(batch_size=target.shape[0])\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']] * BATCH_SIZE, 1)\n", - " \n", - " with tf.GradientTape() as tape:\n", - " features = encoder(img_tensor)\n", - " \n", - " for i in range(1, target.shape[1]):\n", - " # passing the features through the decoder\n", - " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", - "\n", - " loss += loss_function(target[:, i], predictions)\n", - " \n", - " # using teacher forcing\n", - " dec_input = tf.expand_dims(target[:, i], 1)\n", - " \n", - " total_loss += (loss / int(target.shape[1]))\n", - " \n", - " variables = encoder.variables + decoder.variables\n", - " \n", - " gradients = tape.gradient(loss, variables) \n", - " \n", - " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", - " \n", - " if batch % 100 == 0:\n", - " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n", - " batch, \n", - " loss.numpy() / int(target.shape[1])))\n", - " # storing the epoch end loss value to plot later\n", - " loss_plot.append(total_loss / len(cap_vector))\n", - " \n", - " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n", - " total_loss/len(cap_vector)))\n", - " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "Qs_Sr03wRPGk" + }, + "outputs": [], + "source": [ + "encoder = CNN_Encoder(embedding_dim)\n", + "decoder = RNN_Decoder(embedding_dim, units, vocab_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "1Wm83G-ZBPcC", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "plt.plot(loss_plot)\n", - "plt.xlabel('Epochs')\n", - "plt.ylabel('Loss')\n", - "plt.title('Loss Plot')\n", - "plt.show()" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "-bYN7xA0RPGl" + }, + "outputs": [], + "source": [ + "optimizer = tf.train.AdamOptimizer()\n", + "\n", + "# We are masking the loss calculated for padding\n", + "def loss_function(real, pred):\n", + " mask = 1 - np.equal(real, 0)\n", + " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", + " return tf.reduce_mean(loss_)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PHod7t72RPGn" + }, + "source": [ + "## Training\n", + "\n", + "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n", + "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n", + "* The decoder returns the predictions and the decoder hidden state.\n", + "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", + "* Use teacher forcing to decide the next input to the decoder.\n", + "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n", + "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "xGvOcLQKghXN", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - 
"## Caption!\n", - "\n", - "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", - "* Stop predicting when the model predicts the end token.\n", - "* And store the attention weights for every time step." - ] + "colab_type": "code", + "id": "Vt4WZ5mhJE-E" + }, + "outputs": [], + "source": [ + "# adding this in a separate cell because if you run the training cell \n", + "# many times, the loss_plot array will be reset\n", + "loss_plot = []" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "RCWpDtyNRPGs", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def evaluate(image):\n", - " attention_plot = np.zeros((max_length, attention_features_shape))\n", - "\n", - " hidden = decoder.reset_state(batch_size=1)\n", - "\n", - " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", - " img_tensor_val = image_features_extract_model(temp_input)\n", - " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", - "\n", - " features = encoder(img_tensor_val)\n", - "\n", - " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", - " result = []\n", - "\n", - " for i in range(max_length):\n", - " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", - "\n", - " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", - "\n", - " predicted_id = tf.argmax(predictions[0]).numpy()\n", - " result.append(index_word[predicted_id])\n", - "\n", - " if index_word[predicted_id] == '':\n", - " return result, attention_plot\n", - "\n", - " dec_input = tf.expand_dims([predicted_id], 0)\n", - "\n", - " 
attention_plot = attention_plot[:len(result), :]\n", - " return result, attention_plot" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "UlA4VIQpRPGo" + }, + "outputs": [], + "source": [ + "EPOCHS = 20\n", + "\n", + "for epoch in range(EPOCHS):\n", + " start = time.time()\n", + " total_loss = 0\n", + " \n", + " for (batch, (img_tensor, target)) in enumerate(dataset):\n", + " loss = 0\n", + " \n", + " # initializing the hidden state for each batch\n", + " # because the captions are not related from image to image\n", + " hidden = decoder.reset_state(batch_size=target.shape[0])\n", + "\n", + " dec_input = tf.expand_dims([tokenizer.word_index['']] * BATCH_SIZE, 1)\n", + " \n", + " with tf.GradientTape() as tape:\n", + " features = encoder(img_tensor)\n", + " \n", + " for i in range(1, target.shape[1]):\n", + " # passing the features through the decoder\n", + " predictions, hidden, _ = decoder(dec_input, features, hidden)\n", + "\n", + " loss += loss_function(target[:, i], predictions)\n", + " \n", + " # using teacher forcing\n", + " dec_input = tf.expand_dims(target[:, i], 1)\n", + " \n", + " total_loss += (loss / int(target.shape[1]))\n", + " \n", + " variables = encoder.variables + decoder.variables\n", + " \n", + " gradients = tape.gradient(loss, variables) \n", + " \n", + " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", + " \n", + " if batch % 100 == 0:\n", + " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n", + " batch, \n", + " loss.numpy() / int(target.shape[1])))\n", + " # storing the epoch end loss value to plot later\n", + " loss_plot.append(total_loss / len(cap_vector))\n", + " \n", + " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n", + " total_loss/len(cap_vector)))\n", + " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": 
false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "fD_y7PD6RPGt", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def plot_attention(image, result, attention_plot):\n", - " temp_image = np.array(Image.open(image))\n", - "\n", - " fig = plt.figure(figsize=(10, 10))\n", - " \n", - " len_result = len(result)\n", - " for l in range(len_result):\n", - " temp_att = np.resize(attention_plot[l], (8, 8))\n", - " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", - " ax.set_title(result[l])\n", - " img = ax.imshow(temp_image)\n", - " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", - "\n", - " plt.tight_layout()\n", - " plt.show()" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "1Wm83G-ZBPcC" + }, + "outputs": [], + "source": [ + "plt.plot(loss_plot)\n", + "plt.xlabel('Epochs')\n", + "plt.ylabel('Loss')\n", + "plt.title('Loss Plot')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xGvOcLQKghXN" + }, + "source": [ + "## Caption!\n", + "\n", + "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", + "* Stop predicting when the model predicts the end token.\n", + "* And store the attention weights for every time step." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "io7ws3ReRPGv", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# captions on the validation set\n", - "rid = np.random.randint(0, len(img_name_val))\n", - "image = img_name_val[rid]\n", - "real_caption = ' '.join([index_word[i] for i in cap_val[rid] if i not in [0]])\n", - "result, attention_plot = evaluate(image)\n", - "\n", - "print ('Real Caption:', real_caption)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - "plot_attention(image, result, attention_plot)\n", - "# opening the image\n", - "Image.open(img_name_val[rid])" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "RCWpDtyNRPGs" + }, + "outputs": [], + "source": [ + "def evaluate(image):\n", + " attention_plot = np.zeros((max_length, attention_features_shape))\n", + "\n", + " hidden = decoder.reset_state(batch_size=1)\n", + "\n", + " temp_input = tf.expand_dims(load_image(image)[0], 0)\n", + " img_tensor_val = image_features_extract_model(temp_input)\n", + " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n", + "\n", + " features = encoder(img_tensor_val)\n", + "\n", + " dec_input = tf.expand_dims([tokenizer.word_index['']], 0)\n", + " result = []\n", + "\n", + " for i in range(max_length):\n", + " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n", + "\n", + " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", + "\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", + " result.append(tokenizer.index_word[predicted_id])\n", + "\n", + " if tokenizer.index_word[predicted_id] == '':\n", + " return result, attention_plot\n", + "\n", + " dec_input = tf.expand_dims([predicted_id], 0)\n", 
+ "\n", + " attention_plot = attention_plot[:len(result), :]\n", + " return result, attention_plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "Rprk3HEvZuxb", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Try it on your own images\n", - "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" - ] + "colab_type": "code", + "id": "fD_y7PD6RPGt" + }, + "outputs": [], + "source": [ + "def plot_attention(image, result, attention_plot):\n", + " temp_image = np.array(Image.open(image))\n", + "\n", + " fig = plt.figure(figsize=(10, 10))\n", + " \n", + " len_result = len(result)\n", + " for l in range(len_result):\n", + " temp_att = np.resize(attention_plot[l], (8, 8))\n", + " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n", + " ax.set_title(result[l])\n", + " img = ax.imshow(temp_image)\n", + " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n", + "\n", + " plt.tight_layout()\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, - { - "metadata": { - "id": "9Psd1quzaAWg", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "image_url = 'https://tensorflow.org/images/surf.jpg'\n", - "image_extension = image_url[-4:]\n", - "image_path = tf.keras.utils.get_file('image'+image_extension, \n", - " origin=image_url)\n", - "\n", - "result, attention_plot = evaluate(image_path)\n", - "print ('Prediction Caption:', ' '.join(result))\n", - 
"plot_attention(image_path, result, attention_plot)\n", - "# opening the image\n", - "Image.open(image_path)" - ], - "execution_count": 0, - "outputs": [] + "colab_type": "code", + "id": "io7ws3ReRPGv" + }, + "outputs": [], + "source": [ + "# captions on the validation set\n", + "rid = np.random.randint(0, len(img_name_val))\n", + "image = img_name_val[rid]\n", + "real_caption = ' '.join([tokenizer.index_word[i] for i in cap_val[rid] if i not in [0]])\n", + "result, attention_plot = evaluate(image)\n", + "\n", + "print ('Real Caption:', real_caption)\n", + "print ('Prediction Caption:', ' '.join(result))\n", + "plot_attention(image, result, attention_plot)\n", + "# opening the image\n", + "Image.open(img_name_val[rid])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Rprk3HEvZuxb" + }, + "source": [ + "## Try it on your own images\n", + "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } }, + "colab_type": "code", + "id": "9Psd1quzaAWg" + }, + "outputs": [], + "source": [ + "image_url = 'https://tensorflow.org/images/surf.jpg'\n", + "image_extension = image_url[-4:]\n", + "image_path = tf.keras.utils.get_file('image'+image_extension, \n", + " origin=image_url)\n", + "\n", + "result, attention_plot = evaluate(image_path)\n", + "print ('Prediction Caption:', ' '.join(result))\n", + "plot_attention(image_path, result, attention_plot)\n", + "# opening the image\n", + "Image.open(image_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "VJZXyJco6uLO" + }, + "source": [ + "# Next steps\n", + "\n", + "Congrats! 
You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "image_captioning_with_attention.ipynb", + "private_outputs": true, + "provenance": [ { - "metadata": { - "id": "VJZXyJco6uLO", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Next steps\n", - "\n", - "Congrats! You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset." 
- ] + "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg", + "timestamp": 1530222436922 } - ] + ], + "toc_visible": true, + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index e0d5e494d432b365b0d1dcff6b634de2e6213a43..bda9e77085e45ae31a228142135425e22a1c6780 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -567,7 +567,7 @@ "\n", "* We get predictions using the start_string and the hidden state\n", "\n", - "* Then we use a multinomial distribution to calculate the index of the predicted word. **We use this predicted word as our next input to the model**\n", + "* Then we use argmax to calculate the index of the predicted word. 
**We use this predicted word as our next input to the model**\n", "\n", "* **The hidden state returned by the model is fed back into the model so that it now has more context rather than just one word.** After we predict the next word, the modified hidden states are again fed back into the model, which is how it learns as it gets more context from the previously predicted words.\n", "\n", @@ -598,19 +598,13 @@ "# empty string to store our results\n", "text_generated = ''\n", "\n", - "# low temperatures results in more predictable text.\n", - "# higher temperatures results in more surprising text\n", - "# experiment to find the best setting\n", - "temperature = 1.0\n", - "\n", "# hidden state shape == (batch_size, number of rnn units); here batch size == 1\n", "hidden = [tf.zeros((1, units))]\n", "for i in range(num_generate):\n", " predictions, hidden = model(input_eval, hidden)\n", "\n", - " # using a multinomial distribution to predict the word returned by the model\n", - " predictions = predictions / temperature\n", - " predicted_id = tf.argmax(predictions[0]).numpy()\n", + " # using argmax to predict the word returned by the model\n", + " predicted_id = tf.argmax(predictions[-1]).numpy()\n", " \n", " # We pass the predicted word as the next input to the model\n", " # along with the previous hidden state\n", @@ -632,7 +626,6 @@ "\n", "* Change the start string to a different character, or the start of a sentence.\n", "* Experiment with training on a different, or with different parameters. 
[Project Gutenberg](http://www.gutenberg.org/ebooks/100), for example, contains a large collection of books.\n", - "* Experiment with the temperature parameter.\n", "* Add another RNN layer.\n" ] }, diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/main.py b/tensorflow/contrib/eager/python/examples/l2hmc/main.py index 45e1f98429f48749d374c2aefd8874690c3830ad..98fcb2ba10aa4148dc1d4bd7ddfb6fa9c8c4537c 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/main.py +++ b/tensorflow/contrib/eager/python/examples/l2hmc/main.py @@ -71,7 +71,7 @@ def main(_): # Training if FLAGS.use_defun: # Use `tfe.deun` to boost performance when there are lots of small ops - loss_fn = tfe.defun(l2hmc.compute_loss) + loss_fn = tfe.function(l2hmc.compute_loss) else: loss_fn = l2hmc.compute_loss @@ -104,7 +104,7 @@ def main(_): # Evaluation if FLAGS.use_defun: # Use tfe.deun to boost performance when there are lots of small ops - apply_transition = tfe.defun(dynamics.apply_transition) + apply_transition = tfe.function(dynamics.apply_transition) else: apply_transition = dynamics.apply_transition diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py index 557ad42752144243ae3da61b955b31398cba846e..d412b25b368260b81256fd58034330b884261b2b 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py @@ -36,7 +36,7 @@ class GraphLinearRegressionBenchmark(tf.test.Benchmark): noise_level=0.01, batch_size=batch_size, num_batches=num_batches) - iterator = dataset.make_initializable_iterator() + iterator = tf.compat.v1.data.make_initializable_iterator(dataset) x, y = iterator.get_next() model = linear_regression.LinearModel() diff --git 
a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 560fc8c5a22a0e7acf1f37cf7daf7790dc14de19..66d52a74943d0d81fde05ce51b019558b327978d 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -352,7 +352,7 @@ "And the pseudo-code:\n", "\n", "* `score = FC(tanh(FC(EO) + FC(H)))`\n", - "* `attention weights = softmax(score, axis = 1)`. Softmax by default is applied on the last axis but here we want to apply it on the *1st axis*, since the shape of score is *(batch_size, max_length, hidden_size)*. `Max_length` is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.\n", + "* `attention weights = softmax(score, axis = 1)`. Softmax by default is applied on the last axis but here we want to apply it on the *1st axis*, since the shape of score is *(batch_size, max_length, 1)*. `Max_length` is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.\n", "* `context vector = sum(attention weights * EO, axis = 1)`. 
Same reason as above for choosing axis as 1.\n", "* `embedding output` = The input to the decoder X is passed through an embedding layer.\n", "* `merged vector = concat(embedding output, context vector)`\n", @@ -446,12 +446,12 @@ " # we are doing this to perform addition to calculate the score\n", " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", " \n", - " # score shape == (batch_size, max_length, hidden_size)\n", - " score = tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis))\n", + " # score shape == (batch_size, max_length, 1)\n", + " # we get 1 at the last axis because we are applying tanh(FC(EO) + FC(H)) to self.V\n", + " score = self.V(tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis)))\n", " \n", " # attention_weights shape == (batch_size, max_length, 1)\n", - " # we get 1 at the last axis because we are applying score to self.V\n", - " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", + " attention_weights = tf.nn.softmax(score, axis=1)\n", " \n", " # context_vector shape after sum == (batch_size, hidden_size)\n", " context_vector = attention_weights * enc_output\n", @@ -768,7 +768,7 @@ }, "outputs": [], "source": [ - "translate('hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { @@ -781,7 +781,7 @@ }, "outputs": [], "source": [ - "translate('esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { @@ -794,7 +794,7 @@ }, "outputs": [], "source": [ - "translate('¿todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { @@ -808,7 +808,7 @@ "outputs": 
[], "source": [ "# wrong translation\n", - "translate('trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + "translate(u'trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" ] }, { diff --git a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb index 8fae622e12864ddeee0cedd3cf99be8ea5e4bc48..446e3401184ded6bc34ed64cdd720e29a2851855 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb @@ -65,7 +65,7 @@ "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\n", " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] } ], diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD index 68a84d5fbb4f13e4ebe0d71e3f5caebe97e2101c..f3135a9668fc0dc7faa93a5f119b53f3efd34c6e 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD +++ 
b/tensorflow/contrib/eager/python/examples/resnet50/BUILD @@ -35,6 +35,12 @@ cuda_py_test( "//tensorflow/contrib/eager/python:tfe", "//tensorflow:tensorflow_py", ], + tags = [ + "noasan", # Fix b/118130911 + "nomsan", # Fix b/118130911 + "notsan", # Fix b/118130911 + "optonly", + ], ) cuda_py_test( diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index f3bb978875e226f58d6a00e09154191673a97415..fb7975d8fe867711cff31d627788a2d62a520aa9 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -142,7 +142,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): with tf.Graph().as_default(): np_images, np_labels = random_batch(batch_size) dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat() - (images, labels) = dataset.make_one_shot_iterator().get_next() + images, labels = tf.compat.v1.data.make_one_shot_iterator( + dataset).get_next() model = resnet50.ResNet50(data_format()) logits = model(images, training=True) diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index d265169b5eff685f7b79fb221b9bd52be37ead9c..fb81979d7bd8d17a55b8c448008765268dd07d1d 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -77,7 +77,7 @@ class ResNet50Test(tf.test.TestCase): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) if defun: - model.call = tfe.defun(model.call) + model.call = tfe.function(model.call) with tf.device(device), tfe.execution_mode(execution_mode): images, _ = random_batch(2, data_format) output = model(images, training=False) @@ -221,7 +221,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): device, 
data_format = device_and_format model = resnet50.ResNet50(data_format) if defun: - model.call = tfe.defun(model.call) + model.call = tfe.function(model.call) batch_size = 64 num_burn = 5 num_iters = 30 @@ -266,8 +266,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): optimizer = tf.train.GradientDescentOptimizer(0.1) apply_grads = apply_gradients if defun: - model.call = tfe.defun(model.call) - apply_grads = tfe.defun(apply_gradients) + model.call = tfe.function(model.call) + apply_grads = tfe.function(apply_gradients) num_burn = 3 num_iters = 10 diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py index b702e91f92220c2a9003a1b82411131332012a9e..9585f3565f83af724b6336e466d3671443ba2361 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -72,14 +72,11 @@ def main(_): train_one_iter(model, x, y, optimizer, global_step=global_step) if global_step.numpy() % config.log_every == 0: - it_test = ds_test.make_one_shot_iterator() - acc_test, loss_test = evaluate(model, it_test) + acc_test, loss_test = evaluate(model, ds_test) if FLAGS.validate: - it_train = ds_train_one_shot.make_one_shot_iterator() - it_validation = ds_validation.make_one_shot_iterator() - acc_train, loss_train = evaluate(model, it_train) - acc_validation, loss_validation = evaluate(model, it_validation) + acc_train, loss_train = evaluate(model, ds_train_one_shot) + acc_validation, loss_validation = evaluate(model, ds_validation) print("Iter {}, " "training set accuracy {:.4f}, loss {:.4f}; " "validation set accuracy {:.4f}, loss {:.4f}; " @@ -218,11 +215,11 @@ def train_one_iter(model, inputs, labels, optimizer, global_step=None): return logits, loss -def evaluate(model, iterator): +def evaluate(model, dataset): """Compute accuracy with the given dataset iterator.""" mean_loss = tfe.metrics.Mean() accuracy = tfe.metrics.Accuracy() - for x, y in iterator: 
+ for x, y in dataset: logits, _ = model(x, training=False) loss = model.compute_loss(logits=logits, labels=y) accuracy( diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py index 3a17eb30da3b989acb0b33f2fcb730da76546c18..125adbb9de6e4febbb4284bfe3a31f257e2e8037 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator.py @@ -173,7 +173,7 @@ def main(_): input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ "image": inputs }) - revnet_estimator.export_savedmodel(FLAGS.model_dir, input_fn) + revnet_estimator.export_saved_model(FLAGS.model_dir, input_fn) if __name__ == "__main__": diff --git a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py index 8520cf5b71af503be35d5415707a283fb363a476..b0676916a8da276704de741a50f40cd7d9525228 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py +++ b/tensorflow/contrib/eager/python/examples/revnet/main_estimator_tpu.py @@ -307,7 +307,7 @@ def main(_): # The guide to serve an exported TensorFlow model is at: # https://www.tensorflow.org/serving/serving_basic tf.logging.info("Starting to export model.") - revnet_classifier.export_savedmodel( + revnet_classifier.export_saved_model( export_dir_base=FLAGS.export_dir, serving_input_receiver_fn=imagenet_input.image_serving_input_fn) diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py index 4f4cc3af6f1d5c626b3e2ea7939ecad0ee2d41f1..971aa44f3034692dfb0d03ed3dabf4d6e911eb9f 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -127,6 +127,8 @@ class RevNetTest(tf.test.TestCase): def 
test_compute_gradients_defun(self): """Test `compute_gradients` function with defun.""" + # TODO(apassos): make cond support returning None to let this happen with + # tf.function. compute_gradients = tfe.defun(self.model.compute_gradients) _, saved_hidden = self.model(self.x) grads, _ = compute_gradients(saved_hidden=saved_hidden, labels=self.t) @@ -235,6 +237,7 @@ class RevNetBenchmark(tf.test.Benchmark): device, data_format = device_and_format model = revnet.RevNet(config=config) if defun: + # TODO(apassos): reenable after cond lets you return None model.call = tfe.defun(model.call) batch_size = 64 num_burn = 5 @@ -282,7 +285,7 @@ class RevNetBenchmark(tf.test.Benchmark): model = revnet.RevNet(config=config) optimizer = tf.train.GradientDescentOptimizer(0.1) if defun: - model.call = tfe.defun(model.call) + model.call = tfe.function(model.call) num_burn = 3 num_iters = 10 diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py index 63b5c4c54d13e9c2448ec1f572ca1389f2443bef..770484abed96e540cf75cc5368a1410c31a8d2d0 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py @@ -82,7 +82,7 @@ class PTBBenchmark(tf.test.Benchmark): tf.ones( [PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE], dtype=tf.int64)).repeat(num_iters + num_warmup) - inputs = dataset.make_one_shot_iterator().get_next() + inputs = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() with tf.device(tf.test.gpu_device_name()): outputs = model(inputs, training=True) @@ -124,7 +124,8 @@ class PTBBenchmark(tf.test.Benchmark): dtype=tf.int64)).repeat(num_iters + num_warmup) # inputs and labels have the same shape dataset = tf.data.Dataset.zip((dataset, dataset)) - (inputs, labels) = dataset.make_one_shot_iterator().get_next() + (inputs, labels) = tf.compat.v1.data.make_one_shot_iterator( + 
dataset).get_next() with tf.device(tf.test.gpu_device_name()): optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 930e62b68096b468846a01b9674c669a8b8e9a53..566246de4957c1dc5919c10e22146706f9e50be8 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -24,6 +24,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -347,16 +348,17 @@ class Mean(Metric): Raises: ValueError: if the optional argument is not bool """ - # Convert the boolean to tensor for tf.cond, if it is not. + # Convert the boolean to tensor for tf.cond, if it is not. if not isinstance(write_summary, ops.Tensor): write_summary = ops.convert_to_tensor(write_summary) t = self.numer / self.denom def write_summary_f(): summary_ops.scalar(name=self.name, tensor=t) return t - control_flow_ops.cond(write_summary, + smart_cond.smart_cond(write_summary, write_summary_f, - lambda: t) + lambda: t, + name="") return t @@ -487,6 +489,8 @@ class BinaryAccuracy(Mean): message="Shapes of labels and predictions are unequal") predictions = ops.convert_to_tensor(predictions) predictions = predictions > self.threshold + # Convert labels to bool to match predictions. 
+ labels = math_ops.cast(labels, dtypes.bool) matches = math_ops.equal(labels, predictions) matches = math_ops.cast(matches, self.dtype) super(BinaryAccuracy, self).call(matches, weights=weights) diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 9d2d172752c7f3f3ee6eaa11ab8952313a4a3543..39e5957f5d1760613f2c33607c0bdb163040efb4 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -49,18 +49,6 @@ class MetricsTest(test.TestCase): self.assertEqual(dtypes.float64, m.dtype) self.assertEqual(dtypes.float64, m.result().dtype) - def testSummaryArg(self): - m = metrics.Mean() - m([1, 10, 100]) - m(1000) - m([10000.0, 100000.0]) - self.assertEqual(111111.0/6, m.result(write_summary=True).numpy()) - self.assertEqual(111111.0/6, m.result(write_summary=False).numpy()) - with self.assertRaises(ValueError): - m.result(write_summary=5) - with self.assertRaises(ValueError): - m.result(write_summary=[True]) - def testVariableCollections(self): with context.graph_mode(), ops.Graph().as_default(): m = metrics.Mean() diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index f801d9a47b2f831a48d9b6335c69612c1356d800..5cc0c4f23d9d641ff1452c7cc9c1fcde612a33a2 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -24,7 +24,7 @@ import weakref from tensorflow.python.eager import context from tensorflow.python.framework import ops -from tensorflow.python.keras.engine import base_layer as keras_base_layer +from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.layers import base from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging @@ -220,7 +220,7 @@ class Network(base.Layer): avoid_names = parent_network._owned_layers name_uid_map = parent_network._sub_layer_name_uids else: - 
name_uid_map = keras_base_layer.get_default_graph_uid_map() + name_uid_map = base_layer_utils.get_default_graph_uid_map() # Figure out which names we have to avoid based on which variable scope # we're nested in. strip_name = self._default_parent_variable_scope.name diff --git a/tensorflow/contrib/eager/python/parameter_server.py b/tensorflow/contrib/eager/python/parameter_server.py index 3a9e7b027ed68935f2bc0ddbd27a1821a663850d..7803a6799bb64441fab881bf6ca986d5cf3851a8 100644 --- a/tensorflow/contrib/eager/python/parameter_server.py +++ b/tensorflow/contrib/eager/python/parameter_server.py @@ -56,12 +56,7 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode): # shape inference doesn't run in eager mode we copy this data here for when # the handle is captured by an eager mode function. # pylint: disable=protected-access - if ops._USE_C_SHAPES: - handle._handle_data = resource_variable_ops.get_resource_handle_data(h) - else: - if h._handle_data is None: - ops.set_shape_and_handle_data_for_outputs(h.op) - handle._handle_data = h._handle_data + handle._handle_data = resource_variable_ops.get_resource_handle_data(h) # pylint: enable=protected-access # Clean up op->graph->op reference cycles. 
ops.dismantle_graph(graph) diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py index 7aa4b598b833c3419af501b49f1509d18f3530d5..3926de15e71c9917f88fc3f58740b8c75354ab26 100644 --- a/tensorflow/contrib/eager/python/remote_test.py +++ b/tensorflow/contrib/eager/python/remote_test.py @@ -206,6 +206,33 @@ class RemoteExecutionTest(test.TestCase): y = math_ops.matmul(x1, x2) np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + @run_sync_and_async + def testContextDeviceUpdated(self): + """Tests that the context device is correctly updated.""" + + with ops.device("cpu:0"): + x1 = array_ops.ones([2, 2]) + x2 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + + # `y` is placed on the local CPU as expected. + self.assertEqual(y.device, + "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME) + + @run_sync_and_async + def testGPUToRemoteCopy(self): + """Tests that the remote copy happens satisfactorily.""" + if not context.context().num_gpus(): + self.skipTest("No GPUs.") + + x1 = array_ops.ones([2, 2]).gpu() + + with ops.device("/job:remote_device/replica:0/task:1/device:CPU:0"): + x2 = x1._copy() # pylint: disable=protected-access + + np.testing.assert_array_equal(x1.numpy(), x2.numpy()) + if __name__ == "__main__": ops.enable_eager_execution() diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index f9c716360c5755ee1902b576545d776725f9966f..1d0d6c6c14ce4a8e454206e0be9fea4724f09192 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -115,6 +115,11 @@ def restore_variables_on_create(save_path, map_func=None): class Saver(object): """A tf.train.Saver adapter for use when eager execution is enabled. + + `Saver`'s name-based checkpointing strategy is fragile. 
Please switch to + `tf.train.Checkpoint` or `tf.keras.Model.save_weights`, which perform a more + robust object-based saving. These APIs will load checkpoints written by + `Saver`. """ def __init__(self, var_list): diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index f5b8d95e4fc7fe5cd90d658eda49590e0b330bb0..33c988fd9065e7fbe7b9aeb85cad82eb3c119f76 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -25,6 +25,7 @@ To use, at program startup, call `tf.enable_eager_execution()`. @@py_func @@defun +@@function @@make_template @@implicit_gradients @@implicit_value_and_gradients @@ -101,7 +102,7 @@ from tensorflow.contrib.eager.python.saver import get_optimizer_variables from tensorflow.contrib.eager.python.saver import restore_variables_on_create from tensorflow.contrib.eager.python.saver import Saver from tensorflow.python.eager import backprop -from tensorflow.python.eager import function +from tensorflow.python.eager import function as _function_lib from tensorflow.python.eager.context import DEVICE_PLACEMENT_EXPLICIT from tensorflow.python.eager.context import DEVICE_PLACEMENT_WARN from tensorflow.python.eager.context import DEVICE_PLACEMENT_SILENT @@ -115,6 +116,7 @@ from tensorflow.python.eager.context import SYNC from tensorflow.python.eager.context import ASYNC from tensorflow.python.eager.context import num_gpus from tensorflow.python.eager.context import set_server_def +from tensorflow.python.eager.def_function import function from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks from tensorflow.python.eager.execution_callbacks import inf_callback @@ -138,7 +140,7 @@ from tensorflow.python.training.checkpointable.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented py_func = script_ops.eager_py_func -defun = function.defun +defun 
= _function_lib.defun make_template = template.make_template_internal implicit_gradients = backprop.implicit_grad implicit_value_and_gradients = backprop.implicit_val_and_grad diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index 4454abfb9667f824b9de0100bb81bae24ad5f7a6..8c35dddb5a515aa09cc70c173a9f0605e8567e82 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -87,8 +87,8 @@ class TFETest(test_util.TensorFlowTestCase): x += 1. # Without a device context, heuristics are used to place ops. # In this case, ops.reduce_mean runs on the GPU. - reduction_indices = range(x.shape.ndims) - m = math_ops.reduce_mean(x, reduction_indices) + axis = range(x.shape.ndims) + m = math_ops.reduce_mean(x, axis) # m is on GPU, bring it back to CPU and compare. self.assertEqual(3.5, m.cpu().numpy()) diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 1ea00fb7f3c6a19824abc8eb80726bb3bba183aa..a888379f13e79d1c246d4cd6d19a225c065692a2 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -8,61 +8,29 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") +# PLACEHOLDER PIP REQUIREMENTS py_library( name = "estimator_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ - ":baseline", ":boosted_trees", - ":dnn", - ":dnn_linear_combined", ":dnn_with_layer_annotations", ":early_stopping", + ":expect_tensorflow_estimator_installed", ":export", ":exporter", ":extenders", ":head", ":hooks", - ":linear", ":logit_fns", ":multi_head", ":replicate_model_fn", ":rnn", ":saved_model_estimator", "//tensorflow:tensorflow_py_no_contrib", - ], -) - -py_library( - name = "baseline", - srcs = ["python/estimator/baseline.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python/estimator", - 
"//tensorflow/python/estimator:baseline", - ], -) - -py_test( - name = "baseline_test", - size = "small", - srcs = ["python/estimator/baseline_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "notsan", - ], - deps = [ - ":baseline", - ":head", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator:export_export", - "//tensorflow/python/estimator:metric_keys", - "//tensorflow/python/estimator:numpy_io", - "//third_party/py/numpy", - "@six_archive//:six", + "//tensorflow/python/estimator:estimator_py", ], ) @@ -71,67 +39,18 @@ py_library( srcs = ["python/estimator/boosted_trees.py"], srcs_version = "PY2AND3", deps = [ + ":expect_tensorflow_estimator_installed", "//tensorflow/python/estimator", "//tensorflow/python/estimator:boosted_trees", ], ) -py_test( - name = "boosted_trees_test", - size = "medium", - srcs = ["python/estimator/boosted_trees_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "notsan", - ], - deps = [ - ":boosted_trees", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator:numpy_io", - "//third_party/py/numpy", - ], -) - -py_library( - name = "dnn", - srcs = ["python/estimator/dnn.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:dnn", - ], -) - -py_test( - name = "dnn_test", - size = "medium", - srcs = ["python/estimator/dnn_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "notsan", - "optonly", # times out http://b/79220679 - ], - deps = [ - ":dnn", - ":head", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator:dnn_testing_utils", - "//tensorflow/python/estimator:export_export", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/estimator:prediction_keys", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - py_library( name = "dnn_with_layer_annotations", srcs = 
["python/estimator/dnn_with_layer_annotations.py"], srcs_version = "PY2AND3", deps = [ + ":expect_tensorflow_estimator_installed", "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:head", @@ -140,64 +59,6 @@ py_library( ], ) -py_test( - name = "dnn_with_layer_annotations_test", - size = "medium", - srcs = ["python/estimator/dnn_with_layer_annotations_test.py"], - shard_count = 4, - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "notsan", # b/67510291 - ], - deps = [ - ":dnn_with_layer_annotations", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator:dnn", - "//tensorflow/python/estimator:dnn_testing_utils", - "//tensorflow/python/estimator:export_export", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/estimator:pandas_io", - "//tensorflow/python/estimator:prediction_keys", - "@six_archive//:six", - ], -) - -py_library( - name = "dnn_linear_combined", - srcs = ["python/estimator/dnn_linear_combined.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:dnn_linear_combined", - ], -) - -py_test( - name = "dnn_linear_combined_test", - size = "medium", - srcs = ["python/estimator/dnn_linear_combined_test.py"], - shard_count = 3, - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "notsan", - ], - deps = [ - ":dnn_linear_combined", - ":head", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator:dnn_testing_utils", - "//tensorflow/python/estimator:export_export", - "//tensorflow/python/estimator:linear_testing_utils", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/estimator:prediction_keys", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - py_library( name = "extenders", srcs = [ @@ -205,6 +66,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":expect_tensorflow_estimator_installed", 
"//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", "//tensorflow/python/estimator:model_fn", @@ -213,23 +75,6 @@ py_library( ], ) -py_test( - name = "extenders_test", - size = "medium", - srcs = ["python/estimator/extenders_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], # b/62863147 - deps = [ - ":extenders", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/predictor", - "//tensorflow/python/estimator:estimator_py", - "//tensorflow/python/estimator:linear", - "//third_party/py/numpy", - ], -) - py_library( name = "export", srcs = [ @@ -237,22 +82,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python/estimator:model_fn", - ], -) - -py_test( - name = "export_test", - size = "medium", - srcs = ["python/estimator/export_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], # b/62863147 - deps = [ - ":export", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:export_export", - "//tensorflow/python/estimator:export_output", + ":expect_tensorflow_estimator_installed", "//tensorflow/python/estimator:model_fn", ], ) @@ -264,24 +94,12 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":expect_tensorflow_estimator_installed", "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:exporter", ], ) -py_test( - name = "exporter_test", - size = "medium", - srcs = ["python/estimator/exporter_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":exporter", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:exporter", - ], -) - py_library( name = "head", srcs = [ @@ -289,6 +107,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":expect_tensorflow_estimator_installed", "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", @@ 
-298,22 +117,6 @@ py_library( ], ) -py_test( - name = "head_test", - size = "medium", - srcs = ["python/estimator/head_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":head", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator:metric_keys", - "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/estimator:prediction_keys", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - py_library( name = "hooks", srcs = [ @@ -321,58 +124,12 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":expect_tensorflow_estimator_installed", "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:estimator_py", ], ) -py_test( - name = "hooks_test", - size = "medium", - srcs = ["python/estimator/hooks_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":hooks", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator:estimator_py", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - -py_library( - name = "linear", - srcs = ["python/estimator/linear.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:linear", - ], -) - -py_test( - name = "linear_test", - size = "medium", - srcs = ["python/estimator/linear_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "notsan", - ], - deps = [ - ":head", - ":linear", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator:export_export", - "//tensorflow/python/estimator:linear_testing_utils", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/estimator:prediction_keys", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - py_library( name = "logit_fns", srcs = [ @@ -380,24 +137,13 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":expect_tensorflow_estimator_installed", "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:linear", ], ) -py_test( - name = 
"logit_fns_test", - size = "small", - srcs = ["python/estimator/logit_fns_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":logit_fns", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator:model_fn", - ], -) - py_library( name = "multi_head", srcs = [ @@ -405,6 +151,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":expect_tensorflow_estimator_installed", "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", @@ -414,23 +161,6 @@ py_library( ], ) -py_test( - name = "multi_head_test", - size = "small", - srcs = ["python/estimator/multi_head_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":head", - ":multi_head", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator:metric_keys", - "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/estimator:prediction_keys", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - py_library( name = "replicate_model_fn", srcs = [ @@ -438,6 +168,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":expect_tensorflow_estimator_installed", "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:model_fn", @@ -446,35 +177,12 @@ py_library( ], ) -cuda_py_test( - name = "replicate_model_fn_test", - size = "medium", - srcs = ["python/estimator/replicate_model_fn_test.py"], - additional_deps = [ - "@absl_py//absl/testing:parameterized", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:dnn", - "//tensorflow/python/estimator:export_export", - "//tensorflow/python/estimator:export_output", - "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/estimator:optimizers", - "//tensorflow/python/estimator:prediction_keys", - ":replicate_model_fn", - ], - tags = [ - "manual", - "multi_gpu", - "notap", - ], -) - 
py_library( name = "rnn", srcs = ["python/estimator/rnn.py"], srcs_version = "PY2AND3", deps = [ + ":expect_tensorflow_estimator_installed", ":extenders", "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/contrib/feature_column:feature_column_py", @@ -485,55 +193,22 @@ py_library( ], ) -py_test( - name = "rnn_test", - size = "medium", - srcs = ["python/estimator/rnn_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "noasan", # times out - "notsan", - "optonly", # times out http://b/79220679 - ], - deps = [ - ":head", - ":rnn", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/contrib/data", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/estimator:parsing_utils", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - py_library( name = "early_stopping", srcs = ["python/estimator/early_stopping.py"], srcs_version = "PY2AND3", deps = [ + ":expect_tensorflow_estimator_installed", "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", ], ) -py_test( - name = "early_stopping_test", - srcs = ["python/estimator/early_stopping_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":early_stopping", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator", - "@absl_py//absl/testing:parameterized", - ], -) - py_library( name = "saved_model_estimator", srcs = ["python/estimator/saved_model_estimator.py"], deps = [ + ":expect_tensorflow_estimator_installed", ":export", "//tensorflow:tensorflow_py_no_contrib", "//tensorflow/python/estimator", @@ -542,21 +217,9 @@ py_library( ], ) -py_test( - name = "saved_model_estimator_test", - size = "medium", - srcs = ["python/estimator/saved_model_estimator_test.py"], - srcs_version = "PY2AND3", - tags = [ - "notsan", - ], - deps = [ - ":export", - ":saved_model_estimator", - "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:export_export", - "//tensorflow/python/estimator:export_output", - 
"//tensorflow/python/estimator:model_fn", - ], +py_library( + name = "expect_tensorflow_estimator_installed", + # This is a dummy rule used as a dependency in open-source. + # We expect tensorflow_estimator to already be installed. + visibility = ["//visibility:public"], ) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 419609b1af7b19dc9cf2960e96e71d54d8eb0c9b..7d61247e7ef26d3777843cd3be20684583e9058c 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,33 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Experimental utilities re:tf.estimator.*.""" +"""estimator python module. + +Importing from tensorflow.python.estimator +is unsupported and will soon break! 
+""" + +# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.estimator.python.estimator.baseline import * -from tensorflow.contrib.estimator.python.estimator.boosted_trees import * -from tensorflow.contrib.estimator.python.estimator.dnn import * -from tensorflow.contrib.estimator.python.estimator.dnn_with_layer_annotations import * -from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * -from tensorflow.contrib.estimator.python.estimator.early_stopping import * -from tensorflow.contrib.estimator.python.estimator.export import * -from tensorflow.contrib.estimator.python.estimator.extenders import * -from tensorflow.contrib.estimator.python.estimator.head import * -from tensorflow.contrib.estimator.python.estimator.hooks import * -from tensorflow.contrib.estimator.python.estimator.linear import * -from tensorflow.contrib.estimator.python.estimator.logit_fns import * -from tensorflow.contrib.estimator.python.estimator.multi_head import * -from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import * -from tensorflow.contrib.estimator.python.estimator.rnn import * -from tensorflow.contrib.estimator.python.estimator.saved_model_estimator import * -from tensorflow.python.estimator.export.export import * +# Importing from tensorflow.python.estimator +# is unsupported and will soon break! + +from tensorflow_estimator.contrib import estimator + +# Fixes remove_undocumented not working as intended. +# +# Problem is that when the below import happens (for first time, +# Python only imports things once), Python sets attribute named +# 'python' to this package. If this first import happens +# after the call to remove_undocumented, then the 'python' +# attribute won't be removed. 
+import tensorflow.contrib.estimator.python + +# Include attrs that start with single underscore. +_HAS_DYNAMIC_ATTRIBUTES = True +estimator.__all__ = [s for s in dir(estimator) if not s.startswith('__')] +from tensorflow_estimator.contrib.estimator import * from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ 'add_metrics', @@ -53,10 +58,6 @@ _allowed_symbols = [ 'multi_label_head', 'poisson_regression_head', 'regression_head', - 'BaselineEstimator', - 'DNNEstimator', - 'DNNLinearCombinedEstimator', - 'LinearEstimator', 'boosted_trees_classifier_train_in_memory', 'boosted_trees_regressor_train_in_memory', 'call_logit_fn', diff --git a/tensorflow/contrib/estimator/python/estimator/baseline.py b/tensorflow/contrib/estimator/python/estimator/baseline.py deleted file mode 100644 index beffbee73064b9ef425b115317c43e29477b19af..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/estimator/python/estimator/baseline.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Baseline estimators.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.estimator import estimator -from tensorflow.python.estimator.canned import baseline - - -class BaselineEstimator(estimator.Estimator): - """An estimator that can establish a simple baseline. - - The estimator uses a user-specified head. - - This estimator ignores feature values and will learn to predict the average - value of each label. E.g. for single-label classification problems, this will - predict the probability distribution of the classes as seen in the labels. - For multi-label classification problems, it will predict the ratio of examples - that contain each class. - - Example: - - ```python - - # Build baseline multi-label classifier. - estimator = BaselineEstimator( - head=tf.contrib.estimator.multi_label_head(n_classes=3)) - - # Input builders - def input_fn_train: # returns x, y (where y represents label's class index). - pass - - def input_fn_eval: # returns x, y (where y represents label's class index). - pass - - # Fit model. - estimator.train(input_fn=input_fn_train) - - # Evaluates cross entropy between the test and train labels. - loss = classifier.evaluate(input_fn=input_fn_eval)["loss"] - - # For each class, predicts the ratio of training examples that contain the - # class. - predictions = classifier.predict(new_samples) - - ``` - - Input of `train` and `evaluate` should have following features, - otherwise there will be a `KeyError`: - - * if `weight_column` passed to the `head` constructor is not `None`, a feature - with `key=weight_column` whose value is a `Tensor`. - """ - - def __init__(self, - head, - model_dir=None, - optimizer='Ftrl', - config=None): - """Initializes a BaselineEstimator instance. 
- - Args: - head: A `_Head` instance constructed with a method such as - `tf.contrib.estimator.multi_label_head`. - model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator to - continue training a previously saved model. - optimizer: String, `tf.Optimizer` object, or callable that creates the - optimizer to use for training. If not specified, will use - `FtrlOptimizer` with a default learning rate of 0.3. - config: `RunConfig` object to configure the runtime settings. - """ - def _model_fn(features, labels, mode, config): - return baseline._baseline_model_fn( # pylint: disable=protected-access - features=features, - labels=labels, - mode=mode, - head=head, - optimizer=optimizer, - config=config) - super(BaselineEstimator, self).__init__( - model_fn=_model_fn, - model_dir=model_dir, - config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py deleted file mode 100644 index 513feb03b6fb7b0806d2a5fb560b1e3394d4094c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py +++ /dev/null @@ -1,436 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for baseline.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import shutil -import tempfile - -import numpy as np -import six - -from tensorflow.contrib.estimator.python.estimator import baseline -from tensorflow.contrib.estimator.python.estimator import head as head_lib -from tensorflow.python.client import session as tf_session -from tensorflow.python.estimator.canned import metric_keys -from tensorflow.python.estimator.export import export -from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column as feature_column_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variables -from tensorflow.python.ops.losses import losses -from tensorflow.python.platform import gfile -from tensorflow.python.platform import test -from tensorflow.python.summary.writer import writer_cache -from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import optimizer -from tensorflow.python.training import saver - -# Names of variables created by model. 
-BIAS_NAME = 'baseline/bias' - - -def assert_close(expected, actual, rtol=1e-04, name='assert_close'): - with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope: - expected = ops.convert_to_tensor(expected, name='expected') - actual = ops.convert_to_tensor(actual, name='actual') - rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected) - rtol = ops.convert_to_tensor(rtol, name='rtol') - return check_ops.assert_less( - rdiff, - rtol, - data=('Condition expected =~ actual did not hold element-wise:' - 'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff, - 'rtol = ', rtol,), - name=scope) - - -def save_variables_to_ckpt(model_dir): - init_all_op = [variables.global_variables_initializer()] - with tf_session.Session() as sess: - sess.run(init_all_op) - saver.Saver().save(sess, os.path.join(model_dir, 'model.ckpt')) - - -def _baseline_estimator_fn( - weight_column=None, label_dimension=1, *args, **kwargs): - """Returns a BaselineEstimator that uses regression_head.""" - return baseline.BaselineEstimator( - head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension, - # Tests in core (from which this test inherits) test the sum loss. 
- loss_reduction=losses.Reduction.SUM), - *args, **kwargs) - - -class BaselineEstimatorEvaluationTest(test.TestCase): - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) - - def test_evaluation_batch(self): - """Tests evaluation for batch_size==2.""" - with ops.Graph().as_default(): - variables.Variable([13.0], name=BIAS_NAME) - variables.Variable( - 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) - save_variables_to_ckpt(self._model_dir) - - baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir) - eval_metrics = baseline_estimator.evaluate( - input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1) - - # Logit is bias = 13, while label is 10. - # Loss per example is 3**2 = 9. - # Training loss is the sum over batch = 9 + 9 = 18 - # Average loss is the average over batch = 9 - self.assertDictEqual({ - metric_keys.MetricKeys.LOSS: 18., - metric_keys.MetricKeys.LOSS_MEAN: 9., - metric_keys.MetricKeys.PREDICTION_MEAN: 13., - metric_keys.MetricKeys.LABEL_MEAN: 10., - ops.GraphKeys.GLOBAL_STEP: 100 - }, eval_metrics) - - def test_evaluation_weights(self): - """Tests evaluation with weights.""" - with ops.Graph().as_default(): - variables.Variable([13.0], name=BIAS_NAME) - variables.Variable( - 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) - save_variables_to_ckpt(self._model_dir) - - def _input_fn(): - features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))} - labels = ((10.,), (10.,)) - return features, labels - - baseline_estimator = _baseline_estimator_fn( - weight_column='weights', - model_dir=self._model_dir) - eval_metrics = baseline_estimator.evaluate(input_fn=_input_fn, steps=1) - - # Logit is bias = 13, while label is 10. - # Loss per example is 3**2 = 9. 
- # Training loss is the weighted sum over batch = 9 + 2*9 = 27 - # average loss is the weighted average = 9 + 2*9 / (1 + 2) = 9 - self.assertDictEqual({ - metric_keys.MetricKeys.LOSS: 27., - metric_keys.MetricKeys.LOSS_MEAN: 9., - metric_keys.MetricKeys.PREDICTION_MEAN: 13., - metric_keys.MetricKeys.LABEL_MEAN: 10., - ops.GraphKeys.GLOBAL_STEP: 100 - }, eval_metrics) - - def test_evaluation_for_multi_dimensions(self): - label_dim = 2 - with ops.Graph().as_default(): - variables.Variable([46.0, 58.0], name=BIAS_NAME) - variables.Variable(100, name='global_step', dtype=dtypes.int64) - save_variables_to_ckpt(self._model_dir) - - baseline_estimator = _baseline_estimator_fn( - label_dimension=label_dim, - model_dir=self._model_dir) - input_fn = numpy_io.numpy_input_fn( - x={ - 'age': np.array([[2., 4., 5.]]), - }, - y=np.array([[46., 58.]]), - batch_size=1, - num_epochs=None, - shuffle=False) - eval_metrics = baseline_estimator.evaluate(input_fn=input_fn, steps=1) - - self.assertItemsEqual( - (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN, - metric_keys.MetricKeys.PREDICTION_MEAN, - metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP), - eval_metrics.keys()) - - # Logit is bias which is [46, 58] - self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS]) - - -class BaselineEstimatorPredictTest(test.TestCase): - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) - - def test_1d(self): - """Tests predict when all variables are one-dimensional.""" - with ops.Graph().as_default(): - variables.Variable([.2], name=BIAS_NAME) - variables.Variable(100, name='global_step', dtype=dtypes.int64) - save_variables_to_ckpt(self._model_dir) - - baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir) - - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': np.array([[2.]])}, - y=None, - batch_size=1, - 
num_epochs=1, - shuffle=False) - predictions = baseline_estimator.predict(input_fn=predict_input_fn) - predicted_scores = list([x['predictions'] for x in predictions]) - # x * weight + bias = 2. * 10. + .2 = 20.2 - self.assertAllClose([[.2]], predicted_scores) - - def testMultiDim(self): - """Tests predict when all variables are multi-dimenstional.""" - batch_size = 2 - label_dimension = 3 - with ops.Graph().as_default(): - variables.Variable( # shape=[label_dimension] - [.2, .4, .6], name=BIAS_NAME) - variables.Variable(100, name='global_step', dtype=dtypes.int64) - save_variables_to_ckpt(self._model_dir) - - baseline_estimator = _baseline_estimator_fn( - label_dimension=label_dimension, - model_dir=self._model_dir) - - predict_input_fn = numpy_io.numpy_input_fn( - # x shape=[batch_size, x_dim] - x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])}, - y=None, - batch_size=batch_size, - num_epochs=1, - shuffle=False) - predictions = baseline_estimator.predict(input_fn=predict_input_fn) - predicted_scores = list([x['predictions'] for x in predictions]) - # score = bias, shape=[batch_size, label_dimension] - self.assertAllClose([[0.2, 0.4, 0.6], [0.2, 0.4, 0.6]], - predicted_scores) - - -class BaselineEstimatorIntegrationTest(test.TestCase): - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) - - def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn, - input_dimension, label_dimension, prediction_length): - feature_columns = [ - feature_column_lib.numeric_column('x', shape=(input_dimension,)) - ] - est = _baseline_estimator_fn( - label_dimension=label_dimension, - model_dir=self._model_dir) - - # TRAIN - # learn y = x - est.train(train_input_fn, steps=200) - - # EVALUTE - scores = est.evaluate(eval_input_fn) - self.assertEqual(200, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn(metric_keys.MetricKeys.LOSS, 
six.iterkeys(scores)) - - # PREDICT - predictions = np.array( - [x['predictions'] for x in est.predict(predict_input_fn)]) - self.assertAllEqual((prediction_length, label_dimension), predictions.shape) - - # EXPORT - feature_spec = feature_column_lib.make_parse_example_spec(feature_columns) - serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( - feature_spec) - export_dir = est.export_savedmodel(tempfile.mkdtemp(), - serving_input_receiver_fn) - self.assertTrue(gfile.Exists(export_dir)) - - def test_numpy_input_fn(self): - """Tests complete flow with numpy_input_fn.""" - label_dimension = 2 - input_dimension = label_dimension - batch_size = 10 - prediction_length = batch_size - data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) - data = data.reshape(batch_size, label_dimension) - - train_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size, - num_epochs=None, - shuffle=True) - eval_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size, - num_epochs=1, - shuffle=False) - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - y=None, - batch_size=batch_size, - num_epochs=1, - shuffle=False) - - self._test_complete_flow( - train_input_fn=train_input_fn, - eval_input_fn=eval_input_fn, - predict_input_fn=predict_input_fn, - input_dimension=input_dimension, - label_dimension=label_dimension, - prediction_length=prediction_length) - - -class BaselineEstimatorTrainingTest(test.TestCase): - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) - - def _mock_optimizer(self, expected_loss=None): - expected_var_names = [ - '%s:0' % BIAS_NAME - ] - - def _minimize(loss, global_step=None, var_list=None): - trainable_vars = var_list or ops.get_collection( - ops.GraphKeys.TRAINABLE_VARIABLES) - self.assertItemsEqual(expected_var_names, - 
[var.name for var in trainable_vars]) - - # Verify loss. We can't check the value directly, so we add an assert op. - self.assertEquals(0, loss.shape.ndims) - if expected_loss is None: - if global_step is not None: - return state_ops.assign_add(global_step, 1).op - return control_flow_ops.no_op() - assert_loss = assert_close( - math_ops.to_float(expected_loss, name='expected'), - loss, - name='assert_loss') - with ops.control_dependencies((assert_loss,)): - if global_step is not None: - return state_ops.assign_add(global_step, 1).op - return control_flow_ops.no_op() - - mock_optimizer = test.mock.NonCallableMock( - spec=optimizer.Optimizer, - wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) - mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) - - # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. - # So, return mock_optimizer itself for deepcopy. - mock_optimizer.__deepcopy__ = lambda _: mock_optimizer - return mock_optimizer - - def _assert_checkpoint(self, - label_dimension, - expected_global_step, - expected_bias=None): - shapes = { - name: shape - for (name, shape) in checkpoint_utils.list_variables(self._model_dir) - } - - self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) - self.assertEqual(expected_global_step, - checkpoint_utils.load_variable(self._model_dir, - ops.GraphKeys.GLOBAL_STEP)) - - self.assertEqual([label_dimension], shapes[BIAS_NAME]) - if expected_bias is not None: - self.assertEqual(expected_bias, - checkpoint_utils.load_variable(self._model_dir, - BIAS_NAME)) - - def testFromScratch(self): - # Create BaselineRegressor. - label = 5. - age = 17 - # loss = (logits - label)^2 = (0 - 5.)^2 = 25. - mock_optimizer = self._mock_optimizer(expected_loss=25.) - baseline_estimator = _baseline_estimator_fn( - model_dir=self._model_dir, - optimizer=mock_optimizer) - self.assertEqual(0, mock_optimizer.minimize.call_count) - - # Train for a few steps, and validate optimizer and final checkpoint. 
- num_steps = 10 - baseline_estimator.train( - input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) - self.assertEqual(1, mock_optimizer.minimize.call_count) - self._assert_checkpoint( - label_dimension=1, - expected_global_step=num_steps, - expected_bias=[0.]) - - def testFromCheckpoint(self): - # Create initial checkpoint. - bias = 7.0 - initial_global_step = 100 - with ops.Graph().as_default(): - variables.Variable([bias], name=BIAS_NAME) - variables.Variable( - initial_global_step, - name=ops.GraphKeys.GLOBAL_STEP, - dtype=dtypes.int64) - save_variables_to_ckpt(self._model_dir) - - # logits = bias = 6. - # loss = (logits - label)^2 = (7 - 5)^2 = 4 - mock_optimizer = self._mock_optimizer(expected_loss=4.) - baseline_estimator = _baseline_estimator_fn( - model_dir=self._model_dir, - optimizer=mock_optimizer) - self.assertEqual(0, mock_optimizer.minimize.call_count) - - # Train for a few steps, and validate optimizer and final checkpoint. - num_steps = 10 - baseline_estimator.train( - input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps) - self.assertEqual(1, mock_optimizer.minimize.call_count) - self._assert_checkpoint( - label_dimension=1, - expected_global_step=initial_global_step + num_steps, - expected_bias=[bias]) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py index a1f1c5f3d7a25ad28c58e9c215b862b6d51f4cd8..4cb66883a50621297518e34bf2c70bbdee146733 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py @@ -12,414 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Boosted Trees estimators.""" +"""boosted_trees python module. 
+ +Importing from tensorflow.python.estimator is unsupported +and will soon break! +""" +# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import + from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.estimator import estimator -from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees -from tensorflow.python.estimator.canned import head as head_lib - - -def _validate_input_fn_and_repeat_dataset(train_input_fn): - """Validates whether the input_fn is valid, and repeat() if tf.Dataset.""" - def _input_fn(): - result_input_fn = train_input_fn() - if isinstance(result_input_fn, dataset_ops.Dataset): - return result_input_fn.repeat() - return result_input_fn - - return _input_fn - - -def _is_classification_head(head): - """Infers if the head is a classification head.""" - # Check using all classification heads defined in canned/head.py. However, it - # is not a complete list - it does not check for other classification heads - # not defined in the head library. - # pylint: disable=protected-access - return isinstance(head, - (head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss, - head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss)) - # pylint: enable=protected-access - - -class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase): # pylint: disable=protected-access - """An Estimator for Tensorflow Boosted Trees models.""" - - def __init__(self, - feature_columns, - n_batches_per_layer, - head, - model_dir=None, - weight_column=None, - n_trees=100, - max_depth=6, - learning_rate=0.1, - l1_regularization=0., - l2_regularization=0., - tree_complexity=0., - min_node_weight=0., - config=None, - center_bias=False, - pruning_mode='none'): - """Initializes a `BoostedTreesEstimator` instance. 
- - Args: - feature_columns: An iterable containing all the feature columns used by - the model. All items in the set should be instances of classes derived - from `FeatureColumn`. - n_batches_per_layer: the number of batches to collect statistics per - layer. - head: the `Head` instance defined for Estimator. - model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator - to continue training a previously saved model. - weight_column: A string or a `_NumericColumn` created by - `tf.feature_column.numeric_column` defining feature column representing - weights. It is used to downweight or boost examples during training. It - will be multiplied by the loss of the example. If it is a string, it is - used as a key to fetch weight tensor from the `features`. If it is a - `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, - then weight_column.normalizer_fn is applied on it to get weight tensor. - n_trees: number trees to be created. - max_depth: maximum depth of the tree to grow. - learning_rate: shrinkage parameter to be used when a tree added to the - model. - l1_regularization: regularization multiplier applied to the absolute - weights of the tree leafs. - l2_regularization: regularization multiplier applied to the square weights - of the tree leafs. - tree_complexity: regularization factor to penalize trees with more leaves. - min_node_weight: minimum hessian a node must have for a split to be - considered. The value will be compared with sum(leaf_hessian)/ - (batch_size * n_batches_per_layer). - config: `RunConfig` object to configure the runtime settings. - center_bias: Whether bias centering needs to occur. Bias centering refers - to the first node in the very first tree returning the prediction that - is aligned with the original labels distribution. For example, for - regression problems, the first node will return the mean of the labels. 
- For binary classification problems, it will return a logit for a prior - probability of label 1. - pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- - pruning (do not split a node if not enough gain is observed) and post - pruning (build the tree up to a max depth and then prune branches with - negative gain). For pre and post pruning, you MUST provide - tree_complexity >0. - - Raises: - ValueError: when wrong arguments are given or unsupported functionalities - are requested. - """ - # HParams for the model. - # pylint: disable=protected-access - tree_hparams = canned_boosted_trees._TreeHParams( - n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight, center_bias, pruning_mode) - - def _model_fn(features, labels, mode, config): - return canned_boosted_trees._bt_model_fn( - features, - labels, - mode, - head, - feature_columns, - tree_hparams, - n_batches_per_layer, - config=config) - - super(_BoostedTreesEstimator, self).__init__( - model_fn=_model_fn, - model_dir=model_dir, - config=config, - feature_columns=feature_columns, - head=head, - center_bias=center_bias, - is_classification=_is_classification_head(head)) - # pylint: enable=protected-access - - -def boosted_trees_classifier_train_in_memory( - train_input_fn, - feature_columns, - model_dir=None, - n_classes=canned_boosted_trees._HOLD_FOR_MULTI_CLASS_SUPPORT, - weight_column=None, - label_vocabulary=None, - n_trees=100, - max_depth=6, - learning_rate=0.1, - l1_regularization=0., - l2_regularization=0., - tree_complexity=0., - min_node_weight=0., - config=None, - train_hooks=None, - center_bias=False, - pruning_mode='none'): - """Trains a boosted tree classifier with in memory dataset. 
- - Example: - - ```python - bucketized_feature_1 = bucketized_column( - numeric_column('feature_1'), BUCKET_BOUNDARIES_1) - bucketized_feature_2 = bucketized_column( - numeric_column('feature_2'), BUCKET_BOUNDARIES_2) - - def train_input_fn(): - dataset = create-dataset-from-training-data - # This is tf.data.Dataset of a tuple of feature dict and label. - # e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}), - # Dataset.from_tensors(label_array))) - # The returned Dataset shouldn't be batched. - # If Dataset repeats, only the first repetition would be used for training. - return dataset - - classifier = boosted_trees_classifier_train_in_memory( - train_input_fn, - feature_columns=[bucketized_feature_1, bucketized_feature_2], - n_trees=100, - ... - ) - - def input_fn_eval(): - ... - return dataset - - metrics = classifier.evaluate(input_fn=input_fn_eval, steps=10) - ``` - - Args: - train_input_fn: the input function returns a dataset containing a single - epoch of *unbatched* features and labels. - feature_columns: An iterable containing all the feature columns used by - the model. All items in the set should be instances of classes derived - from `FeatureColumn`. - model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator - to continue training a previously saved model. - n_classes: number of label classes. Default is binary classification. - Multiclass support is not yet implemented. - weight_column: A string or a `_NumericColumn` created by - `tf.feature_column.numeric_column` defining feature column representing - weights. It is used to downweight or boost examples during training. It - will be multiplied by the loss of the example. If it is a string, it is - used as a key to fetch weight tensor from the `features`. If it is a - `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, - then weight_column.normalizer_fn is applied on it to get weight tensor. 
- label_vocabulary: A list of strings represents possible label values. If - given, labels must be string type and have any value in - `label_vocabulary`. If it is not given, that means labels are - already encoded as integer or float within [0, 1] for `n_classes=2` and - encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . - Also there will be errors if vocabulary is not provided and labels are - string. - n_trees: number trees to be created. - max_depth: maximum depth of the tree to grow. - learning_rate: shrinkage parameter to be used when a tree added to the - model. - l1_regularization: regularization multiplier applied to the absolute - weights of the tree leafs. - l2_regularization: regularization multiplier applied to the square weights - of the tree leafs. - tree_complexity: regularization factor to penalize trees with more leaves. - min_node_weight: minimum hessian a node must have for a split to be - considered. The value will be compared with sum(leaf_hessian)/ - (batch_size * n_batches_per_layer). - config: `RunConfig` object to configure the runtime settings. - train_hooks: a list of Hook instances to be passed to estimator.train() - center_bias: Whether bias centering needs to occur. Bias centering refers - to the first node in the very first tree returning the prediction that - is aligned with the original labels distribution. For example, for - regression problems, the first node will return the mean of the labels. - For binary classification problems, it will return a logit for a prior - probability of label 1. - pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- - pruning (do not split a node if not enough gain is observed) and post - pruning (build the tree up to a max depth and then prune branches with - negative gain). For pre and post pruning, you MUST provide - tree_complexity >0. 
- - Returns: - a `BoostedTreesClassifier` instance created with the given arguments and - trained with the data loaded up on memory from the input_fn. - - Raises: - ValueError: when wrong arguments are given or unsupported functionalities - are requested. - """ - # pylint: disable=protected-access - # TODO(nponomareva): Support multi-class cases. - if n_classes == canned_boosted_trees._HOLD_FOR_MULTI_CLASS_SUPPORT: - n_classes = 2 - head, closed_form = ( - canned_boosted_trees._create_classification_head_and_closed_form( - n_classes, weight_column, label_vocabulary=label_vocabulary)) - - # HParams for the model. - tree_hparams = canned_boosted_trees._TreeHParams( - n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight, center_bias, pruning_mode) - - def _model_fn(features, labels, mode, config): - return canned_boosted_trees._bt_model_fn( - features, - labels, - mode, - head, - feature_columns, - tree_hparams, - n_batches_per_layer=1, - config=config, - closed_form_grad_and_hess_fn=closed_form, - train_in_memory=True) - - in_memory_classifier = estimator.Estimator( - model_fn=_model_fn, model_dir=model_dir, config=config) - - in_memory_classifier.train( - input_fn=_validate_input_fn_and_repeat_dataset(train_input_fn), - hooks=train_hooks) - - return in_memory_classifier - # pylint: enable=protected-access - - -def boosted_trees_regressor_train_in_memory( - train_input_fn, - feature_columns, - model_dir=None, - label_dimension=canned_boosted_trees._HOLD_FOR_MULTI_DIM_SUPPORT, - weight_column=None, - n_trees=100, - max_depth=6, - learning_rate=0.1, - l1_regularization=0., - l2_regularization=0., - tree_complexity=0., - min_node_weight=0., - config=None, - train_hooks=None, - center_bias=False, - pruning_mode='none'): - """Trains a boosted tree regressor with in memory dataset. 
- - Example: - - ```python - bucketized_feature_1 = bucketized_column( - numeric_column('feature_1'), BUCKET_BOUNDARIES_1) - bucketized_feature_2 = bucketized_column( - numeric_column('feature_2'), BUCKET_BOUNDARIES_2) - - def train_input_fn(): - dataset = create-dataset-from-training-data - # This is tf.data.Dataset of a tuple of feature dict and label. - # e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}), - # Dataset.from_tensors(label_array))) - # The returned Dataset shouldn't be batched. - # If Dataset repeats, only the first repetition would be used for training. - return dataset - - regressor = boosted_trees_regressor_train_in_memory( - train_input_fn, - feature_columns=[bucketized_feature_1, bucketized_feature_2], - n_trees=100, - ... - ) - - def input_fn_eval(): - ... - return dataset - - metrics = regressor.evaluate(input_fn=input_fn_eval, steps=10) - ``` - - Args: - train_input_fn: the input function returns a dataset containing a single - epoch of *unbatched* features and labels. - feature_columns: An iterable containing all the feature columns used by - the model. All items in the set should be instances of classes derived - from `FeatureColumn`. - model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator - to continue training a previously saved model. - label_dimension: Number of regression targets per example. - Multi-dimensional support is not yet implemented. - weight_column: A string or a `_NumericColumn` created by - `tf.feature_column.numeric_column` defining feature column representing - weights. It is used to downweight or boost examples during training. It - will be multiplied by the loss of the example. If it is a string, it is - used as a key to fetch weight tensor from the `features`. If it is a - `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, - then weight_column.normalizer_fn is applied on it to get weight tensor. 
- n_trees: number trees to be created. - max_depth: maximum depth of the tree to grow. - learning_rate: shrinkage parameter to be used when a tree added to the - model. - l1_regularization: regularization multiplier applied to the absolute - weights of the tree leafs. - l2_regularization: regularization multiplier applied to the square weights - of the tree leafs. - tree_complexity: regularization factor to penalize trees with more leaves. - min_node_weight: minimum hessian a node must have for a split to be - considered. The value will be compared with sum(leaf_hessian)/ - (batch_size * n_batches_per_layer). - config: `RunConfig` object to configure the runtime settings. - train_hooks: a list of Hook instances to be passed to estimator.train(). - center_bias: Whether bias centering needs to occur. Bias centering refers - to the first node in the very first tree returning the prediction that - is aligned with the original labels distribution. For example, for - regression problems, the first node will return the mean of the labels. - For binary classification problems, it will return a logit for a prior - probability of label 1. - pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre- - pruning (do not split a node if not enough gain is observed) and post - pruning (build the tree up to a max depth and then prune branches with - negative gain). For pre and post pruning, you MUST provide - tree_complexity >0. - - Returns: - a `BoostedTreesClassifier` instance created with the given arguments and - trained with the data loaded up on memory from the input_fn. - - Raises: - ValueError: when wrong arguments are given or unsupported functionalities - are requested. - """ - # pylint: disable=protected-access - # TODO(nponomareva): Extend it to multi-dimension cases. 
- if label_dimension == canned_boosted_trees._HOLD_FOR_MULTI_DIM_SUPPORT: - label_dimension = 1 - head = canned_boosted_trees._create_regression_head(label_dimension, - weight_column) - - # HParams for the model. - tree_hparams = canned_boosted_trees._TreeHParams( - n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity, min_node_weight, center_bias, pruning_mode) - - def _model_fn(features, labels, mode, config): - return canned_boosted_trees._bt_model_fn( - features, - labels, - mode, - head, - feature_columns, - tree_hparams, - n_batches_per_layer=1, - config=config, - train_in_memory=True) - - in_memory_regressor = estimator.Estimator( - model_fn=_model_fn, model_dir=model_dir, config=config) +from tensorflow_estimator.contrib.estimator.python.estimator import boosted_trees - in_memory_regressor.train( - input_fn=_validate_input_fn_and_repeat_dataset(train_input_fn), - hooks=train_hooks) +# Include attrs that start with single underscore. +_HAS_DYNAMIC_ATTRIBUTES = True +boosted_trees.__all__ = [ + s for s in dir(boosted_trees) if not s.startswith('__') +] - return in_memory_regressor - # pylint: enable=protected-access +from tensorflow_estimator.contrib.estimator.python.estimator.boosted_trees import * diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py deleted file mode 100644 index e23d9c0fc4c32ce0ce23dcf4be518577795dd35f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py +++ /dev/null @@ -1,438 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests boosted_trees estimators.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.estimator.python.estimator import boosted_trees -from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2 -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees -from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.platform import googletest -from tensorflow.python.training import checkpoint_utils - -NUM_FEATURES = 3 - -BUCKET_BOUNDARIES = [-2., .5, 12.] # Boundaries for all the features. 
-INPUT_FEATURES = np.array( - [ - [12.5, 1.0, -2.001, -2.0001, -1.999], # feature_0 quantized:[3,2,0,0,1] - [2.0, -3.0, 0.5, 0.0, 0.4995], # feature_1 quantized:[2,0,2,1,1] - [3.0, 20.0, 50.0, -100.0, 102.75], # feature_2 quantized:[2,3,3,0,3] - ], - dtype=np.float32) -CLASSIFICATION_LABELS = [[0.], [1.], [1.], [0.], [0.]] -REGRESSION_LABELS = [[1.5], [0.3], [0.2], [2.], [5.]] -FEATURES_DICT = {'f_%d' % i: INPUT_FEATURES[i] for i in range(NUM_FEATURES)} - - -def _make_train_input_fn(is_classification): - """Makes train input_fn for classification/regression.""" - - def _input_fn(): - features_dict = dict(FEATURES_DICT) - labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS - return features_dict, labels - - return _input_fn - - -def _make_train_input_fn_dataset(is_classification): - """Makes input_fn using Dataset.""" - - def _input_fn(): - features_dict = dict(FEATURES_DICT) - labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS - ds = dataset_ops.Dataset.zip( - (dataset_ops.Dataset.from_tensors(features_dict), - dataset_ops.Dataset.from_tensors(labels) - )) - return ds - - return _input_fn - - -class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): - - def setUp(self): - self._head = canned_boosted_trees._create_regression_head(label_dimension=1) - self._feature_columns = { - feature_column.bucketized_column( - feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32), - BUCKET_BOUNDARIES) - for i in range(NUM_FEATURES) - } - - def _assert_checkpoint(self, model_dir, global_step, finalized_trees, - attempted_layers): - reader = checkpoint_utils.load_checkpoint(model_dir) - self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) - serialized = reader.get_tensor('boosted_trees:0_serialized') - ensemble_proto = boosted_trees_pb2.TreeEnsemble() - ensemble_proto.ParseFromString(serialized) - self.assertEqual( - finalized_trees, - sum([1 for t in ensemble_proto.tree_metadata if 
t.is_finalized])) - self.assertEqual(attempted_layers, - ensemble_proto.growing_metadata.num_layers_attempted) - - def testTrainAndEvaluateEstimator(self): - input_fn = _make_train_input_fn(is_classification=False) - - est = boosted_trees._BoostedTreesEstimator( - feature_columns=self._feature_columns, - n_batches_per_layer=1, - n_trees=2, - head=self._head, - max_depth=5) - - # It will stop after 10 steps because of the max depth and num trees. - num_steps = 100 - # Train for a few steps, and validate final checkpoint. - est.train(input_fn, steps=num_steps) - self._assert_checkpoint( - est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10) - eval_res = est.evaluate(input_fn=input_fn, steps=1) - self.assertAllClose(eval_res['average_loss'], 1.008551) - - def testTrainAndEvaluateEstimatorWithCenterBias(self): - input_fn = _make_train_input_fn(is_classification=False) - - est = boosted_trees._BoostedTreesEstimator( - feature_columns=self._feature_columns, - n_batches_per_layer=1, - n_trees=2, - head=self._head, - max_depth=5, - center_bias=True) - - # It will stop after 11 steps because of the max depth and num trees. - num_steps = 100 - # Train for a few steps, and validate final checkpoint. - est.train(input_fn, steps=num_steps) - # 10 steps for training and 2 step for bias centering. - self._assert_checkpoint( - est.model_dir, global_step=12, finalized_trees=2, attempted_layers=10) - eval_res = est.evaluate(input_fn=input_fn, steps=1) - self.assertAllClose(eval_res['average_loss'], 0.614642) - - def testTrainAndEvaluateEstimatorWithPrePruning(self): - input_fn = _make_train_input_fn(is_classification=False) - - est = boosted_trees._BoostedTreesEstimator( - feature_columns=self._feature_columns, - n_batches_per_layer=1, - n_trees=2, - head=self._head, - max_depth=5, - tree_complexity=0.001, - pruning_mode='pre') - - num_steps = 100 - # Train for a few steps, and validate final checkpoint. 
- est.train(input_fn, steps=num_steps) - # We stop actually after 2*depth*n_trees steps (via a hook) because we still - # could not grow 2 trees of depth 5 (due to pre-pruning). - self._assert_checkpoint( - est.model_dir, global_step=21, finalized_trees=0, attempted_layers=21) - eval_res = est.evaluate(input_fn=input_fn, steps=1) - self.assertAllClose(eval_res['average_loss'], 3.83943) - - def testTrainAndEvaluateEstimatorWithPostPruning(self): - input_fn = _make_train_input_fn(is_classification=False) - - est = boosted_trees._BoostedTreesEstimator( - feature_columns=self._feature_columns, - n_batches_per_layer=1, - n_trees=2, - head=self._head, - max_depth=5, - tree_complexity=0.001, - pruning_mode='post') - - # It will stop after 10 steps because of the max depth and num trees. - num_steps = 100 - # Train for a few steps, and validate final checkpoint. - est.train(input_fn, steps=num_steps) - self._assert_checkpoint( - est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10) - eval_res = est.evaluate(input_fn=input_fn, steps=1) - self.assertAllClose(eval_res['average_loss'], 2.37652) - - def testInferEstimator(self): - train_input_fn = _make_train_input_fn(is_classification=False) - predict_input_fn = numpy_io.numpy_input_fn( - x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) - - est = boosted_trees._BoostedTreesEstimator( - feature_columns=self._feature_columns, - n_batches_per_layer=1, - n_trees=1, - max_depth=5, - head=self._head) - - # It will stop after 5 steps because of the max depth and num trees. - num_steps = 100 - # Train for a few steps, and validate final checkpoint. - est.train(train_input_fn, steps=num_steps) - self._assert_checkpoint( - est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) - # Validate predictions. 
- predictions = list(est.predict(input_fn=predict_input_fn)) - self.assertAllClose( - [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], - [pred['predictions'] for pred in predictions]) - - def testInferEstimatorWithCenterBias(self): - train_input_fn = _make_train_input_fn(is_classification=False) - predict_input_fn = numpy_io.numpy_input_fn( - x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) - - est = boosted_trees._BoostedTreesEstimator( - feature_columns=self._feature_columns, - n_batches_per_layer=1, - n_trees=1, - max_depth=5, - center_bias=True, - head=self._head) - - # It will stop after 6 steps because of the max depth and num trees (5 for - # training and 2 for bias centering). - num_steps = 100 - # Train for a few steps, and validate final checkpoint. - est.train(train_input_fn, steps=num_steps) - self._assert_checkpoint( - est.model_dir, global_step=7, finalized_trees=1, attempted_layers=5) - # Validate predictions. - predictions = list(est.predict(input_fn=predict_input_fn)) - - self.assertAllClose( - [[1.634501], [1.325703], [1.187431], [2.019683], [2.832683]], - [pred['predictions'] for pred in predictions]) - - def testBinaryClassifierTrainInMemoryAndEvalAndInfer(self): - train_input_fn = _make_train_input_fn(is_classification=True) - predict_input_fn = numpy_io.numpy_input_fn( - x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) - - est = boosted_trees.boosted_trees_classifier_train_in_memory( - train_input_fn=train_input_fn, feature_columns=self._feature_columns, - n_trees=1, max_depth=5) - # It will stop after 5 steps because of the max depth and num trees. - self._assert_checkpoint( - est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) - - # Check evaluate and predict. - eval_res = est.evaluate(input_fn=train_input_fn, steps=1) - self.assertAllClose(eval_res['accuracy'], 1.0) - # Validate predictions. 
- predictions = list(est.predict(input_fn=predict_input_fn)) - self.assertAllClose([[0], [1], [1], [0], [0]], - [pred['class_ids'] for pred in predictions]) - - def testBinaryClassifierTrainInMemoryAndEvalAndInferWithCenterBias(self): - train_input_fn = _make_train_input_fn(is_classification=True) - predict_input_fn = numpy_io.numpy_input_fn( - x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) - - est = boosted_trees.boosted_trees_classifier_train_in_memory( - train_input_fn=train_input_fn, - feature_columns=self._feature_columns, - n_trees=1, - max_depth=5, - center_bias=True) - # It will stop after 5 steps + 3 for bias, because of the max depth and num - # trees. - self._assert_checkpoint( - est.model_dir, global_step=8, finalized_trees=1, attempted_layers=5) - - # Check evaluate and predict. - eval_res = est.evaluate(input_fn=train_input_fn, steps=1) - self.assertAllClose(eval_res['accuracy'], 1.0) - # Validate predictions. - predictions = list(est.predict(input_fn=predict_input_fn)) - self.assertAllClose([[0], [1], [1], [0], [0]], - [pred['class_ids'] for pred in predictions]) - - def testBinaryClassifierTrainInMemoryAndEvalAndInferWithPrePruning(self): - train_input_fn = _make_train_input_fn(is_classification=True) - predict_input_fn = numpy_io.numpy_input_fn( - x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) - - est = boosted_trees.boosted_trees_classifier_train_in_memory( - train_input_fn=train_input_fn, - feature_columns=self._feature_columns, - n_trees=1, - max_depth=5, - pruning_mode='pre', - tree_complexity=0.01) - # We stop actually after 2*depth*n_trees steps (via a hook) because we still - # could not grow 1 trees of depth 5 (due to pre-pruning). - self._assert_checkpoint( - est.model_dir, global_step=11, finalized_trees=0, attempted_layers=11) - - # Check evaluate and predict. - eval_res = est.evaluate(input_fn=train_input_fn, steps=1) - self.assertAllClose(eval_res['accuracy'], 1.0) - # Validate predictions. 
- predictions = list(est.predict(input_fn=predict_input_fn)) - self.assertAllClose([[0], [1], [1], [0], [0]], - [pred['class_ids'] for pred in predictions]) - - def testBinaryClassifierTrainInMemoryWithDataset(self): - train_input_fn = _make_train_input_fn_dataset(is_classification=True) - predict_input_fn = numpy_io.numpy_input_fn( - x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) - - est = boosted_trees.boosted_trees_classifier_train_in_memory( - train_input_fn=train_input_fn, - feature_columns=self._feature_columns, - n_trees=1, - max_depth=5) - # It will stop after 5 steps because of the max depth and num trees. - self._assert_checkpoint( - est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) - - # Check evaluate and predict. - eval_res = est.evaluate(input_fn=train_input_fn, steps=1) - self.assertAllClose(eval_res['accuracy'], 1.0) - predictions = list(est.predict(input_fn=predict_input_fn)) - self.assertAllClose([[0], [1], [1], [0], [0]], - [pred['class_ids'] for pred in predictions]) - - def testRegressorTrainInMemoryAndEvalAndInfer(self): - train_input_fn = _make_train_input_fn(is_classification=False) - predict_input_fn = numpy_io.numpy_input_fn( - x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) - - est = boosted_trees.boosted_trees_regressor_train_in_memory( - train_input_fn=train_input_fn, feature_columns=self._feature_columns, - n_trees=1, max_depth=5) - # It will stop after 5 steps because of the max depth and num trees. - self._assert_checkpoint( - est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) - - # Check evaluate and predict. 
- eval_res = est.evaluate(input_fn=train_input_fn, steps=1) - self.assertAllClose(eval_res['average_loss'], 2.478283) - predictions = list(est.predict(input_fn=predict_input_fn)) - self.assertAllClose( - [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], - [pred['predictions'] for pred in predictions]) - - def testRegressorTrainInMemoryWithDataset(self): - train_input_fn = _make_train_input_fn_dataset(is_classification=False) - predict_input_fn = numpy_io.numpy_input_fn( - x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) - - est = boosted_trees.boosted_trees_regressor_train_in_memory( - train_input_fn=train_input_fn, feature_columns=self._feature_columns, - n_trees=1, max_depth=5) - # It will stop after 5 steps because of the max depth and num trees. - self._assert_checkpoint( - est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) - # Check evaluate and predict. - eval_res = est.evaluate(input_fn=train_input_fn, steps=1) - self.assertAllClose(eval_res['average_loss'], 2.478283) - predictions = list(est.predict(input_fn=predict_input_fn)) - self.assertAllClose( - [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], - [pred['predictions'] for pred in predictions]) - - -class BoostedTreesDebugOutputTest(test_util.TensorFlowTestCase): - - def setUp(self): - self._head = canned_boosted_trees._create_regression_head(label_dimension=1) - self._feature_columns = { - feature_column.bucketized_column( - feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32), - BUCKET_BOUNDARIES) for i in range(NUM_FEATURES) - } - - def testContribEstimatorThatDFCIsInPredictions(self): - # pylint:disable=protected-access - head = canned_boosted_trees._create_regression_head(label_dimension=1) - train_input_fn = _make_train_input_fn(is_classification=False) - predict_input_fn = numpy_io.numpy_input_fn( - x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) - - est = boosted_trees._BoostedTreesEstimator( - 
feature_columns=self._feature_columns, - n_batches_per_layer=1, - head=head, - n_trees=1, - max_depth=5, - center_bias=True) - # pylint:enable=protected-access - - num_steps = 100 - # Train for a few steps. Validate debug outputs in prediction dicts. - est.train(train_input_fn, steps=num_steps) - debug_predictions = est.experimental_predict_with_explanations( - predict_input_fn) - biases, dfcs = zip(*[(pred['bias'], pred['dfc']) - for pred in debug_predictions]) - self.assertAllClose([1.8] * 5, biases) - self.assertAllClose(({ - 0: -0.070499420166015625, - 1: -0.095000028610229492, - 2: 0.0 - }, { - 0: -0.53763031959533691, - 1: 0.063333392143249512, - 2: 0.0 - }, { - 0: -0.51756942272186279, - 1: -0.095000028610229492, - 2: 0.0 - }, { - 0: 0.1563495397567749, - 1: 0.063333392143249512, - 2: 0.0 - }, { - 0: 0.96934974193572998, - 1: 0.063333392143249512, - 2: 0.0 - }), dfcs) - - # Assert sum(dfcs) + bias == predictions. - expected_predictions = [[1.6345005], [1.32570302], [1.1874305], - [2.01968288], [2.83268309]] - predictions = [ - [sum(dfc.values()) + bias] for (dfc, bias) in zip(dfcs, biases) - ] - self.assertAllClose(expected_predictions, predictions) - - # Test when user doesn't include bias or dfc in predict_keys. 
- debug_predictions = est.experimental_predict_with_explanations( - predict_input_fn, predict_keys=['predictions']) - for prediction_dict in debug_predictions: - self.assertTrue('bias' in prediction_dict) - self.assertTrue('dfc' in prediction_dict) - self.assertTrue('predictions' in prediction_dict) - self.assertEqual(len(prediction_dict), 3) - - -if __name__ == '__main__': - googletest.main() diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py deleted file mode 100644 index 9efa8f474d865a36788cba40a15404bf0b30a17e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Deep Neural Network estimators.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.estimator import estimator -from tensorflow.python.estimator.canned import dnn as dnn_lib -from tensorflow.python.ops import nn - - -class DNNEstimator(estimator.Estimator): - """An estimator for TensorFlow DNN models with user-specified head. - - Example: - - ```python - sparse_feature_a = sparse_column_with_hash_bucket(...) - sparse_feature_b = sparse_column_with_hash_bucket(...) 
- - sparse_feature_a_emb = embedding_column(sparse_id_column=sparse_feature_a, - ...) - sparse_feature_b_emb = embedding_column(sparse_id_column=sparse_feature_b, - ...) - - estimator = DNNEstimator( - head=tf.contrib.estimator.multi_label_head(n_classes=3), - feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], - hidden_units=[1024, 512, 256]) - - # Or estimator using the ProximalAdagradOptimizer optimizer with - # regularization. - estimator = DNNEstimator( - head=tf.contrib.estimator.multi_label_head(n_classes=3), - feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], - hidden_units=[1024, 512, 256], - optimizer=tf.train.ProximalAdagradOptimizer( - learning_rate=0.1, - l1_regularization_strength=0.001 - )) - - # Or estimator using an optimizer with a learning rate decay. - estimator = DNNEstimator( - head=tf.contrib.estimator.multi_label_head(n_classes=3), - feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], - hidden_units=[1024, 512, 256], - optimizer=lambda: tf.AdamOptimizer( - learning_rate=tf.exponential_decay( - learning_rate=0.1, - global_step=tf.get_global_step(), - decay_steps=10000, - decay_rate=0.96)) - - # Or estimator with warm-starting from a previous checkpoint. 
- estimator = DNNEstimator( - head=tf.contrib.estimator.multi_label_head(n_classes=3), - feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], - hidden_units=[1024, 512, 256], - warm_start_from="/path/to/checkpoint/dir") - - # Input builders - def input_fn_train: # returns x, y - pass - estimator.train(input_fn=input_fn_train, steps=100) - - def input_fn_eval: # returns x, y - pass - metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) - def input_fn_predict: # returns x, None - pass - predictions = estimator.predict(input_fn=input_fn_predict) - ``` - - Input of `train` and `evaluate` should have following features, - otherwise there will be a `KeyError`: - - * if `weight_column` is not `None`, a feature with - `key=weight_column` whose value is a `Tensor`. - * for each `column` in `feature_columns`: - - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` - whose `value` is a `SparseTensor`. - - if `column` is a `_WeightedCategoricalColumn`, two features: the first - with `key` the id column name, the second with `key` the weight column - name. Both features' `value` must be a `SparseTensor`. - - if `column` is a `_DenseColumn`, a feature with `key=column.name` - whose `value` is a `Tensor`. - - Loss and predicted output are determined by the specified head. - """ - - def __init__(self, - head, - hidden_units, - feature_columns, - model_dir=None, - optimizer='Adagrad', - activation_fn=nn.relu, - dropout=None, - input_layer_partitioner=None, - config=None, - warm_start_from=None, - batch_norm=False): - """Initializes a `DNNEstimator` instance. - - Args: - head: A `_Head` instance constructed with a method such as - `tf.contrib.estimator.multi_label_head`. - hidden_units: Iterable of number hidden units per layer. All layers are - fully connected. Ex. `[64, 32]` means first layer has 64 nodes and - second one has 32. - feature_columns: An iterable containing all the feature columns used by - the model. 
All items in the set should be instances of classes derived - from `_FeatureColumn`. - model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator to - continue training a previously saved model. - optimizer: An instance of `tf.Optimizer` used to train the model. Can also - be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or - callable. Defaults to Adagrad optimizer. - activation_fn: Activation function applied to each layer. If `None`, will - use `tf.nn.relu`. - dropout: When not `None`, the probability we will drop out a given - coordinate. - input_layer_partitioner: Optional. Partitioner for input layer. Defaults - to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. - config: `RunConfig` object to configure the runtime settings. - warm_start_from: A string filepath to a checkpoint to warm-start from, or - a `WarmStartSettings` object to fully configure warm-starting. If the - string filepath is provided instead of a `WarmStartSettings`, then all - weights are warm-started, and it is assumed that vocabularies and Tensor - names are unchanged. - batch_norm: Whether to use batch normalization after each hidden layer. 
- """ - def _model_fn(features, labels, mode, config): - return dnn_lib._dnn_model_fn( # pylint: disable=protected-access - features=features, - labels=labels, - mode=mode, - head=head, - hidden_units=hidden_units, - feature_columns=tuple(feature_columns or []), - optimizer=optimizer, - activation_fn=activation_fn, - dropout=dropout, - input_layer_partitioner=input_layer_partitioner, - config=config, - batch_norm=batch_norm) - super(DNNEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config, - warm_start_from=warm_start_from) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py deleted file mode 100644 index 724bc2c82f8289bbaa19a1dbbc1dc81b6e158e02..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""TensorFlow estimator for Linear and DNN joined training models.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.estimator import estimator -from tensorflow.python.estimator.canned import dnn_linear_combined as dnn_linear_combined_lib -from tensorflow.python.ops import nn - - -class DNNLinearCombinedEstimator(estimator.Estimator): - """An estimator for TensorFlow Linear and DNN joined models with custom head. - - Note: This estimator is also known as wide-n-deep. - - Example: - - ```python - numeric_feature = numeric_column(...) - categorical_column_a = categorical_column_with_hash_bucket(...) - categorical_column_b = categorical_column_with_hash_bucket(...) - - categorical_feature_a_x_categorical_feature_b = crossed_column(...) - categorical_feature_a_emb = embedding_column( - categorical_column=categorical_feature_a, ...) - categorical_feature_b_emb = embedding_column( - categorical_column=categorical_feature_b, ...) 
- - estimator = DNNLinearCombinedEstimator( - head=tf.contrib.estimator.multi_label_head(n_classes=3), - # wide settings - linear_feature_columns=[categorical_feature_a_x_categorical_feature_b], - linear_optimizer=tf.train.FtrlOptimizer(...), - # deep settings - dnn_feature_columns=[ - categorical_feature_a_emb, categorical_feature_b_emb, - numeric_feature], - dnn_hidden_units=[1000, 500, 100], - dnn_optimizer=tf.train.ProximalAdagradOptimizer(...)) - - # To apply L1 and L2 regularization, you can set dnn_optimizer to: - tf.train.ProximalAdagradOptimizer( - learning_rate=0.1, - l1_regularization_strength=0.001, - l2_regularization_strength=0.001) - # To apply learning rate decay, you can set dnn_optimizer to a callable: - lambda: tf.AdamOptimizer( - learning_rate=tf.exponential_decay( - learning_rate=0.1, - global_step=tf.get_global_step(), - decay_steps=10000, - decay_rate=0.96) - # It is the same for linear_optimizer. - - # Input builders - def input_fn_train: # returns x, y - pass - estimator.train(input_fn=input_fn_train, steps=100) - - def input_fn_eval: # returns x, y - pass - metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) - def input_fn_predict: # returns x, None - pass - predictions = estimator.predict(input_fn=input_fn_predict) - ``` - - Input of `train` and `evaluate` should have following features, - otherwise there will be a `KeyError`: - - * for each `column` in `dnn_feature_columns` + `linear_feature_columns`: - - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` - whose `value` is a `SparseTensor`. - - if `column` is a `_WeightedCategoricalColumn`, two features: the first - with `key` the id column name, the second with `key` the weight column - name. Both features' `value` must be a `SparseTensor`. - - if `column` is a `_DenseColumn`, a feature with `key=column.name` - whose `value` is a `Tensor`. - - Loss is calculated by using mean squared error. 
- - @compatibility(eager) - Estimators are not compatible with eager execution. - @end_compatibility - """ - - def __init__(self, - head, - model_dir=None, - linear_feature_columns=None, - linear_optimizer='Ftrl', - dnn_feature_columns=None, - dnn_optimizer='Adagrad', - dnn_hidden_units=None, - dnn_activation_fn=nn.relu, - dnn_dropout=None, - input_layer_partitioner=None, - config=None, - linear_sparse_combiner='sum'): - """Initializes a DNNLinearCombinedEstimator instance. - - Args: - head: A `_Head` instance constructed with a method such as - `tf.contrib.estimator.multi_label_head`. - model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator - to continue training a previously saved model. - linear_feature_columns: An iterable containing all the feature columns - used by linear part of the model. All items in the set must be - instances of classes derived from `FeatureColumn`. - linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the linear part of the model. Can also be a string (one of 'Adagrad', - 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL - optimizer. - dnn_feature_columns: An iterable containing all the feature columns used - by deep part of the model. All items in the set must be instances of - classes derived from `FeatureColumn`. - dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to - the deep part of the model. Can also be a string (one of 'Adagrad', - 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad - optimizer. - dnn_hidden_units: List of hidden units per layer. All layers are fully - connected. - dnn_activation_fn: Activation function applied to each layer. If None, - will use `tf.nn.relu`. - dnn_dropout: When not None, the probability we will drop out - a given coordinate. - input_layer_partitioner: Partitioner for input layer. 
Defaults to - `min_max_variable_partitioner` with `min_slice_size` 64 << 20. - config: RunConfig object to configure the runtime settings. - linear_sparse_combiner: A string specifying how to reduce the linear model - if a categorical column is multivalent. One of "mean", "sqrtn", and - "sum" -- these are effectively different ways to do example-level - normalization, which can be useful for bag-of-words features. For more - details, see `tf.feature_column.linear_model`. - - Raises: - ValueError: If both linear_feature_columns and dnn_features_columns are - empty at the same time. - """ - linear_feature_columns = linear_feature_columns or [] - dnn_feature_columns = dnn_feature_columns or [] - self._feature_columns = ( - list(linear_feature_columns) + list(dnn_feature_columns)) - if not self._feature_columns: - raise ValueError('Either linear_feature_columns or dnn_feature_columns ' - 'must be defined.') - - def _model_fn(features, labels, mode, config): - return dnn_linear_combined_lib._dnn_linear_combined_model_fn( # pylint: disable=protected-access - features=features, - labels=labels, - mode=mode, - head=head, - linear_feature_columns=linear_feature_columns, - linear_optimizer=linear_optimizer, - dnn_feature_columns=dnn_feature_columns, - dnn_optimizer=dnn_optimizer, - dnn_hidden_units=dnn_hidden_units, - dnn_activation_fn=dnn_activation_fn, - dnn_dropout=dnn_dropout, - input_layer_partitioner=input_layer_partitioner, - config=config, - linear_sparse_combiner=linear_sparse_combiner) - - super(DNNLinearCombinedEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py deleted file mode 100644 index 51b9ce7005cec3910ba73db62a674e4628ca30a2..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py +++ /dev/null @@ 
-1,227 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for dnn_linear_combined.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import shutil -import tempfile - -import numpy as np -import six - -from tensorflow.contrib.estimator.python.estimator import dnn_linear_combined -from tensorflow.contrib.estimator.python.estimator import head as head_lib -from tensorflow.python.estimator.canned import dnn_testing_utils -from tensorflow.python.estimator.canned import linear_testing_utils -from tensorflow.python.estimator.canned import prediction_keys -from tensorflow.python.estimator.export import export -from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column -from tensorflow.python.framework import ops -from tensorflow.python.ops import nn -from tensorflow.python.ops.losses import losses -from tensorflow.python.platform import gfile -from tensorflow.python.platform import test -from tensorflow.python.summary.writer import writer_cache - - -def _dnn_only_estimator_fn( - hidden_units, - feature_columns, - model_dir=None, - label_dimension=1, - weight_column=None, - optimizer='Adagrad', - activation_fn=nn.relu, - dropout=None, - input_layer_partitioner=None, - config=None): - return 
dnn_linear_combined.DNNLinearCombinedEstimator( - head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension, - # Tests in core (from which this test inherits) test the sum loss. - loss_reduction=losses.Reduction.SUM), - model_dir=model_dir, - dnn_feature_columns=feature_columns, - dnn_optimizer=optimizer, - dnn_hidden_units=hidden_units, - dnn_activation_fn=activation_fn, - dnn_dropout=dropout, - input_layer_partitioner=input_layer_partitioner, - config=config) - - -class DNNOnlyEstimatorEvaluateTest( - dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - test.TestCase.__init__(self, methodName) - dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__( - self, _dnn_only_estimator_fn) - - -class DNNOnlyEstimatorPredictTest( - dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - test.TestCase.__init__(self, methodName) - dnn_testing_utils.BaseDNNRegressorPredictTest.__init__( - self, _dnn_only_estimator_fn) - - -class DNNOnlyEstimatorTrainTest( - dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - test.TestCase.__init__(self, methodName) - dnn_testing_utils.BaseDNNRegressorTrainTest.__init__( - self, _dnn_only_estimator_fn) - - -def _linear_only_estimator_fn( - feature_columns, - model_dir=None, - label_dimension=1, - weight_column=None, - optimizer='Ftrl', - config=None, - partitioner=None, - sparse_combiner='sum'): - return dnn_linear_combined.DNNLinearCombinedEstimator( - head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension, - # Tests in core (from which this test inherits) test the sum loss. 
- loss_reduction=losses.Reduction.SUM), - model_dir=model_dir, - linear_feature_columns=feature_columns, - linear_optimizer=optimizer, - input_layer_partitioner=partitioner, - config=config, - linear_sparse_combiner=sparse_combiner) - - -class LinearOnlyEstimatorEvaluateTest( - linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - test.TestCase.__init__(self, methodName) - linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__( - self, _linear_only_estimator_fn) - - -class LinearOnlyEstimatorPredictTest( - linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - test.TestCase.__init__(self, methodName) - linear_testing_utils.BaseLinearRegressorPredictTest.__init__( - self, _linear_only_estimator_fn) - - -class LinearOnlyEstimatorTrainTest( - linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - test.TestCase.__init__(self, methodName) - linear_testing_utils.BaseLinearRegressorTrainingTest.__init__( - self, _linear_only_estimator_fn) - - -class DNNLinearCombinedEstimatorIntegrationTest(test.TestCase): - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) - - def _test_complete_flow( - self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension, - label_dimension, batch_size): - linear_feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,))] - dnn_feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,))] - feature_columns = linear_feature_columns + dnn_feature_columns - est = dnn_linear_combined.DNNLinearCombinedEstimator( - head=head_lib.regression_head(label_dimension=label_dimension), - 
linear_feature_columns=linear_feature_columns, - dnn_feature_columns=dnn_feature_columns, - dnn_hidden_units=(2, 2), - model_dir=self._model_dir) - - # TRAIN - num_steps = 10 - est.train(train_input_fn, steps=num_steps) - - # EVALUTE - scores = est.evaluate(eval_input_fn) - self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', six.iterkeys(scores)) - - # PREDICT - predictions = np.array([ - x[prediction_keys.PredictionKeys.PREDICTIONS] - for x in est.predict(predict_input_fn) - ]) - self.assertAllEqual((batch_size, label_dimension), predictions.shape) - - # EXPORT - feature_spec = feature_column.make_parse_example_spec(feature_columns) - serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( - feature_spec) - export_dir = est.export_savedmodel(tempfile.mkdtemp(), - serving_input_receiver_fn) - self.assertTrue(gfile.Exists(export_dir)) - - def test_numpy_input_fn(self): - """Tests complete flow with numpy_input_fn.""" - label_dimension = 2 - batch_size = 10 - data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) - data = data.reshape(batch_size, label_dimension) - # learn y = x - train_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size, - num_epochs=None, - shuffle=True) - eval_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size, - shuffle=False) - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - batch_size=batch_size, - shuffle=False) - - self._test_complete_flow( - train_input_fn=train_input_fn, - eval_input_fn=eval_input_fn, - predict_input_fn=predict_input_fn, - input_dimension=label_dimension, - label_dimension=label_dimension, - batch_size=batch_size) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_test.py deleted file mode 100644 index 
050b0428bf7b685229e12561cfb0682d931299d2..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/estimator/python/estimator/dnn_test.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for dnn.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import shutil -import tempfile - -import numpy as np -import six - -from tensorflow.contrib.estimator.python.estimator import dnn -from tensorflow.contrib.estimator.python.estimator import head as head_lib -from tensorflow.python.estimator.canned import dnn_testing_utils -from tensorflow.python.estimator.canned import prediction_keys -from tensorflow.python.estimator.export import export -from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.feature_column import feature_column -from tensorflow.python.framework import ops -from tensorflow.python.ops.losses import losses -from tensorflow.python.platform import gfile -from tensorflow.python.platform import test -from tensorflow.python.summary.writer import writer_cache - - -def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg - """Returns a DNNEstimator that uses regression_head.""" - return dnn.DNNEstimator( - 
head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension, - # Tests in core (from which this test inherits) test the sum loss. - loss_reduction=losses.Reduction.SUM), - *args, **kwargs) - - -def _dnn_estimator_classifier_fn(n_classes=3, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg - """Returns a DNNEstimator that uses multi_class_head.""" - return dnn.DNNEstimator(head=head_lib.multi_class_head(n_classes=n_classes), - *args, **kwargs) - - -class DNNEstimatorEvaluateTest( - dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - test.TestCase.__init__(self, methodName) - dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__( - self, _dnn_estimator_fn) - - -class DNNEstimatorPredictTest( - dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - test.TestCase.__init__(self, methodName) - dnn_testing_utils.BaseDNNRegressorPredictTest.__init__( - self, _dnn_estimator_fn) - - -class DNNEstimatorTrainTest( - dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - test.TestCase.__init__(self, methodName) - dnn_testing_utils.BaseDNNRegressorTrainTest.__init__( - self, _dnn_estimator_fn) - - -class DNNEstimatorWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest, - test.TestCase): - - def __init__(self, methodName='runTest'): # pylint: disable=invalid-name - test.TestCase.__init__(self, methodName) - dnn_testing_utils.BaseDNNWarmStartingTest.__init__( - self, _dnn_estimator_classifier_fn, _dnn_estimator_fn) - - -class DNNEstimatorIntegrationTest(test.TestCase): - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) - - def 
_test_complete_flow( - self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension, - label_dimension, batch_size): - feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,))] - est = dnn.DNNEstimator( - head=head_lib.regression_head(label_dimension=label_dimension), - hidden_units=(2, 2), - feature_columns=feature_columns, - model_dir=self._model_dir) - - # TRAIN - num_steps = 10 - est.train(train_input_fn, steps=num_steps) - - # EVALUTE - scores = est.evaluate(eval_input_fn) - self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', six.iterkeys(scores)) - - # PREDICT - predictions = np.array([ - x[prediction_keys.PredictionKeys.PREDICTIONS] - for x in est.predict(predict_input_fn) - ]) - self.assertAllEqual((batch_size, label_dimension), predictions.shape) - - # EXPORT - feature_spec = feature_column.make_parse_example_spec(feature_columns) - serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( - feature_spec) - export_dir = est.export_savedmodel(tempfile.mkdtemp(), - serving_input_receiver_fn) - self.assertTrue(gfile.Exists(export_dir)) - - def test_numpy_input_fn(self): - """Tests complete flow with numpy_input_fn.""" - label_dimension = 2 - batch_size = 10 - data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) - data = data.reshape(batch_size, label_dimension) - # learn y = x - train_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size, - num_epochs=None, - shuffle=True) - eval_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size, - shuffle=False) - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - batch_size=batch_size, - shuffle=False) - - self._test_complete_flow( - train_input_fn=train_input_fn, - eval_input_fn=eval_input_fn, - predict_input_fn=predict_input_fn, - input_dimension=label_dimension, - label_dimension=label_dimension, - batch_size=batch_size) - - -if 
__name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py index 6ca7aaf98972c76c608c9c397a82ca94286a2656..854d2e4011b40428b8048e9d61411f66c1bb3840 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,425 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Deep Neural Network estimators with layer annotations.""" +"""dnn_with_layer_annotations python module. + +Importing from tensorflow.python.estimator is unsupported +and will soon break! 
+""" +# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib -import pickle - -from google.protobuf.any_pb2 import Any - -from tensorflow.python.estimator import estimator -from tensorflow.python.estimator.canned import dnn -from tensorflow.python.feature_column import feature_column as feature_column_lib -from tensorflow.python.framework import ops -from tensorflow.python.ops import nn -from tensorflow.python.ops.losses import losses -from tensorflow.python.saved_model import utils as saved_model_utils - - -class LayerAnnotationsCollectionNames(object): - """Names for the collections containing the annotations.""" - - UNPROCESSED_FEATURES = 'layer_annotations/unprocessed_features' - PROCESSED_FEATURES = 'layer_annotatons/processed_features' - FEATURE_COLUMNS = 'layer_annotations/feature_columns' - - @classmethod - def keys(cls, collection_name): - return '%s/keys' % collection_name - - @classmethod - def values(cls, collection_name): - return '%s/values' % collection_name - - -def serialize_feature_column(feature_column): - if isinstance(feature_column, feature_column_lib._EmbeddingColumn): # pylint: disable=protected-access - # We can't pickle nested functions, and we don't need the value of - # layer_creator in most cases anyway, so just discard its value. 
- args = feature_column._asdict() - args['layer_creator'] = None - temp = type(feature_column)(**args) - return pickle.dumps(temp) - return pickle.dumps(feature_column) - - -def _to_any_wrapped_tensor_info(tensor): - """Converts a `Tensor` to a `TensorInfo` wrapped in a proto `Any`.""" - any_buf = Any() - tensor_info = saved_model_utils.build_tensor_info(tensor) - any_buf.Pack(tensor_info) - return any_buf - - -def make_input_layer_with_layer_annotations(original_input_layer): - """Make an input_layer replacement function that adds layer annotations.""" - - def input_layer_with_layer_annotations(features, - feature_columns, - weight_collections=None, - trainable=True, - cols_to_vars=None, - scope=None, - cols_to_output_tensors=None, - from_template=False): - """Returns a dense `Tensor` as input layer based on given `feature_columns`. - - Generally a single example in training data is described with - FeatureColumns. - At the first layer of the model, this column oriented data should be - converted - to a single `Tensor`. - - This is like tf.feature_column.input_layer, except with added - Integrated-Gradient annotations. - - Args: - features: A mapping from key to tensors. `_FeatureColumn`s look up via - these keys. For example `numeric_column('price')` will look at 'price' - key in this dict. Values can be a `SparseTensor` or a `Tensor` depends - on corresponding `_FeatureColumn`. - feature_columns: An iterable containing the FeatureColumns to use as - inputs to your model. All items should be instances of classes derived - from `_DenseColumn` such as `numeric_column`, `embedding_column`, - `bucketized_column`, `indicator_column`. If you have categorical - features, you can wrap them with an `embedding_column` or - `indicator_column`. - weight_collections: A list of collection names to which the Variable will - be added. Note that variables will also be added to collections - `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. 
- trainable: If `True` also add the variable to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - cols_to_vars: If not `None`, must be a dictionary that will be filled with - a mapping from `_FeatureColumn` to list of `Variable`s. For example, - after the call, we might have cols_to_vars = {_EmbeddingColumn( - categorical_column=_HashedCategoricalColumn( key='sparse_feature', - hash_bucket_size=5, dtype=tf.string), dimension=10): [