diff --git a/.gitignore b/.gitignore index 900e5a53cbcf3bbb5e00389cca004c49f8600a66..bdcb067fc26d2a18ed88034ab616c08095794e17 100644 --- a/.gitignore +++ b/.gitignore @@ -4,12 +4,11 @@ node_modules /.bazelrc /.tf_configure.bazelrc /bazel-* -/third_party/py/numpy/numpy_include -/tools/bazel.rc +/bazel_pip +/third_party/eigen3/mkl_include +/third_party/mkl/* /tools/python_bin_path.sh /tools/git/gen -/util/python/python_include -/util/python/python_lib /pip_test /_python_build *.pyc diff --git a/.mention-bot b/.mention-bot deleted file mode 100644 index 9e4858977f5da2992ccc4053dfbbda3f5f86ee90..0000000000000000000000000000000000000000 --- a/.mention-bot +++ /dev/null @@ -1,11 +0,0 @@ -{ - "maxReviewers": 2, - "numFilesToCheck": 10, - "userBlacklist": ["tensorflower-gardener"], - "requiredOrgs": ["tensorflow"], - "skipAlreadyAssignedPR": true, - "skipAlreadyMentionedPR": true, - "skipTitle": "Branch", - "delayed": true, - "delayedUntil": "10m" -} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5ae5c0fbbcd5b8da7e3f3f98e01f455e0c82e588..c78b6b1a150c98fa379a87f935e77b5803837f11 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -27,3 +27,140 @@ contributions, often because we probably won't get to them right now. If you decide to start on an issue, leave a comment so that other people know that you're working on it. If you want to help out, but not alone, use the issue comment thread to coordinate. + +### Contribution guidelines and standards + +Before sending your pull request for +[review](https://github.com/tensorflow/tensorflow/pulls), +make sure your changes are consistent with the guidelines and follow the +TensorFlow coding style. + +#### General guidelines and philosophy for contribution + +* Include unit tests when you contribute new features, as they help to + a) prove that your code works correctly, b) guard against future breaking + changes to lower the maintenance cost. +* Bug fixes also generally require unit tests, because the presence of bugs + usually indicates insufficient test coverage. +* Keep API compatibility in mind when you change code in core TensorFlow, + e.g., code in [tensorflow/core](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core) and [tensorflow/python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python). + TensorFlow has reached version 1 and hence cannot make + non-backward-compatible API changes without a major release. Reviewers of your + pull request will comment on any API compatibility issues. +* When you contribute a new feature to TensorFlow, the maintenance burden is (by + default) transferred to the TensorFlow team. This means that benefit of + contribution must be compared against the cost of maintaining the feature. +* Full new features (e.g., a new op implementing a cutting-edge algorithm) + typically will live in + [tensorflow/contrib](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib) + to get some airtime before decision is made regarding whether they are to be + migrated to the core. + +#### License + +Include a license at the top of new files. + +* [C/C++ license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op.cc#L1) +* [Python license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn.py#L1) +* [Java license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/Graph.java#L1) +* [Go license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/operation.go#L1) +* [Bash license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/ci_build/ci_sanity.sh#L2) +* [HTML license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/dist/index.html#L2) +* [JavaScript/TypeScript license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/components/tf_backend/backend.ts#L1) + +Bazel BUILD files also need to include a license section, e.g., +[BUILD example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/BUILD#L61). + +#### C++ coding style + +Changes to TensorFlow C++ code should conform to +[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). + +Use `clang-tidy` to check your C/C++ changes. To install clang-tidy on ubuntu:16.04, do: + +```bash +apt-get install -y clang-tidy +``` + +You can check a C/C++ file by doing: + + +```bash +clang-format --style=google > /tmp/my_cc_file.cc +diff /tmp/my_cc_file.cc +``` + +#### Python coding style + +Changes to TensorFlow Python code should conform to +[Google Python Style Guide](https://google.github.io/styleguide/pyguide.html) + +Use `pylint` to check your Python changes. To install `pylint` and +retrieve TensorFlow's custom style definition: + +```bash +pip install pylint +wget -O /tmp/pylintrc https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/tools/ci_build/pylintrc +``` + +To check a file with `pylint`: + +```bash +pylint --rcfile=/tmp/pylintrc myfile.py +``` + +#### Coding style for other languages + +* [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html) +* [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html) +* [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml) + +#### Running sanity check + +If you have Docker installed on your system, you can perform a sanity check on +your changes by running the command: + +```bash +tensorflow/tools/ci_build/ci_build.sh CPU tensorflow/tools/ci_build/ci_sanity.sh +``` + +This will catch most license, Python coding style and BUILD file issues that +may exist in your changes. + +#### Running unit tests + +There are two ways to run TensorFlow unit tests. + +1. Using tools and libraries installed directly on your system. + + Refer to the + [CPU-only developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel) and + [GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/docker/Dockerfile.devel-gpu) + for the required packages. Alternatively, use the said + [Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g., + `tensorflow/tensorflow:nightly-devel` and `tensorflow/tensorflow:nightly-devel-gpu` + for development to avoid installing the packages directly on your system. + + Once you have the packages installed, you can run a specific unit test in + bazel by doing as follows: + + If the tests are to be run on GPU, add CUDA paths to LD_LIBRARY_PATH and add + the `cuda` option flag + + ```bash + export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" + + export flags="--config=opt --config=cuda -k" + ``` + + For example, to run all tests under tensorflow/python, do: + + ```bash + bazel test ${flags} //tensorflow/python/... + ``` + +2. Using Docker and TensorFlow's CI scripts. + + See + [TensorFlow Builds](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/ci_build) for details. + diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index af76188c2f4d2e1908f541918c8b680627a90cf9..6f4c048ce83fb47a611b5dfe08e0fde0779994c0 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -1,36 +1,36 @@ -NOTE: Only file GitHub issues for bugs and feature requests. All other topics will be closed. +Please go to Stack Overflow for help and support: -For general support from the community, see [StackOverflow](https://stackoverflow.com/questions/tagged/tensorflow). -To make bugs and feature requests more easy to find and organize, we close issues that are deemed -out of scope for GitHub Issues and point people to StackOverflow. +http://stackoverflow.com/questions/tagged/tensorflow -For bugs or installation issues, please provide the following information. -The more information you provide, the more easily we will be able to offer -help and advice. +If you open a GitHub issue, here is our policy: -### What related GitHub issues or StackOverflow threads have you found by searching the web for your problem? +1. It must be a bug or a feature request. +2. The form below must be filled out. -### Environment info -Operating System: +**Here's why we have that policy**: TensorFlow developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow. -Installed version of CUDA and cuDNN: -(please attach the output of `ls -l /path/to/cuda/lib/libcud*`): +------------------------ -If installed from binary pip package, provide: +### System information +- **Have I written custom code (as opposed to using a stock example script provided in TensorFlow)**: +- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**: +- **TensorFlow installed from (source or binary)**: +- **TensorFlow version (use command below)**: +- **Bazel version (if compiling from source)**: +- **CUDA/cuDNN version**: +- **GPU model and memory**: +- **Exact command to reproduce**: -1. A link to the pip package you installed: -2. The output from `python -c "import tensorflow; print(tensorflow.__version__)"`. +You can collect some of this information using our environment capture script: -If installed from source, provide +https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh -1. The commit hash (`git rev-parse HEAD`) -2. The output of `bazel version` +You can obtain the TensorFlow version with -### If possible, provide a minimal reproducible example (We usually don't have time to read hundreds of lines of your code) +python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)" +### Describe the problem +Describe the problem clearly here. Be sure to convey here why it's a bug in TensorFlow or a feature request. -### What other attempted solutions have you tried? - - -### Logs or other output that would be helpful -(If logs are large, please upload as attachment or provide link). +### Source code / logs +Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. Try to provide a reproducible test case that is the bare minimum necessary to generate the problem. diff --git a/README.md b/README.md index d9f05a67e0391fb5817cf1bd1ac492e3b3cce71d..2878dab2601351dabbfbcadfbe6a4ae94864ce56 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ guidelines](CONTRIBUTING.md).** **We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for tracking requests and bugs, but please see -[Community](tensorflow/docs_src/about/index.md#community) for general questions +[Community](https://www.tensorflow.org/community/) for general questions and discussion.** ## Installation @@ -34,12 +34,12 @@ and discussion.** People who are a little more adventurous can also try our nightly binaries: -* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/)) -* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) -* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) -* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/)) -* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.1.0rc0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/)) -* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.1.0rc0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/)) +* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/)) +* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) +* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) +* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/)) +* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.1.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/)) +* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.1.0-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/)) * Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)) @@ -52,7 +52,7 @@ $ python >>> hello = tf.constant('Hello, TensorFlow!') >>> sess = tf.Session() >>> sess.run(hello) -Hello, TensorFlow! +'Hello, TensorFlow!' >>> a = tf.constant(10) >>> b = tf.constant(32) >>> sess.run(a+b) @@ -62,7 +62,7 @@ Hello, TensorFlow! ## For more information -* [TensorFlow website](http://tensorflow.org) +* [TensorFlow website](https://tensorflow.org) * [TensorFlow whitepaper](http://download.tensorflow.org/paper/whitepaper2015.pdf) * [TensorFlow Model Zoo](https://github.com/tensorflow/models) * [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) diff --git a/RELEASE.md b/RELEASE.md index 156cc2e3af507ffa416a1a96b2d37caa4d87c2e5..02bdbd429772a79d2f8f9af6012b6ac3916c822f 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,35 @@ +# Changes since the last release + +## Major Features and Improvements +* Added `tf.layers.conv3d_transpose` layer for spatio temporal deconvolution. +* Added `tf.Session.make_callable()`, which provides a lower overhead means of running a similar step multiple times. +* Added ibverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo). +* `RNNCell` objects now subclass `tf.layers._Layer`. The strictness described + in the TensorFlow 1.1 release is gone: The first time an RNNCell is used, + it caches its scope. All future uses of the RNNCell will reuse variables from + that same scope. This is a breaking change from the behavior of RNNCells + in TensorFlow versions <= 1.0.1. TensorFlow 1.1 had checks in place to + ensure old code works correctly with the new semantics; this version + allows more flexible uses of RNNCell but can lead to subtle errors if + using code meant for TensorFlow <= 1.0.1. For example, writing: + `MultiRNNCell([lstm] * 5)` will now build a 5-layer LSTM stack where each + layer shares the **same** parameters. To get 5 layers each with their own + parameters, write: `MultiRNNCell([LSTMCell(...) for _ in range(5)])`. + If at all unsure, first test your code with TF 1.1; ensure it raises no + errors, and then upgrade to TF 1.2. + +## Bug Fixes and Other Changes +* In python, `Operation.get_attr` on type attributes returns the Python DType + version of the type to match expected get_attr documentation rather than the + protobuf enum. +* tensorflow/contrib/rnn undergoes RNN cell variable renaming for + consistency with Keras layers. Specifically, the previous variable names + "weights" and "biases" are changed to "kernel" and "bias", respectively. + This may cause backward incompatibility with regard to your old + checkpoints containing such RNN cells, in which case you can use the + [checkpoint_convert script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/tools/checkpoint_convert.py) + to convert the variable names in your old checkpoints. + # Release 1.1.0 ## Major Features and Improvements @@ -15,6 +47,8 @@ * Ability to inspect Python source file against TF ops and tensors (command `print_source` / `ps`) * New navigation bar in Curses-based UI * NodeStepper (command `invoke_stepper`) now uses intermediate tensor dumps. It also uses `TensorHandles` as direct feeds during successive `cont` calls for improved performance and reduced memory consumption. +* Initial release of installation guides for Java, C, and Go. +* Added Text Dashboard to TensorBoard. ## Deprecations @@ -68,38 +102,41 @@ * Multiple tfdbg bug fixes: * Fixed Windows compatibility issues. * Command history now persists across runs. + * Bug fix in graph validation related to `tf.while_loops`. +* Java Maven fixes for bugs with Windows installation. +* Backport fixes and improvements from external keras. +* Keras config file handling fix. ## Thanks to our Contributors This release contains contributions from many people at Google, as well as: A. Besir Kurtulmus, Adal Chiriliuc, @akash, Alec-Desouza, Alex Rothberg, Alex -Sergeev, Alexander Heinecke, Allen Guo, Andreas Madsen, Ankesh Anand, Anton +Sergeev, Alexander Heinecke, Allen Guo, Andreas Madsen, Ankesh Anand, Anton Loss, @Aravind, @Arie, Ashutosh Das, AuréLien Geron, Bairen Yi, @bakunyo, Ben -Visser, Brady Zhou, Calpa Liu, Changming Sun, Chi Zeng, Chih Cheng Liang, -Christopher Berner, Clark Zinzow, @Conchylicultor, Courtial Florian, Dan Ellis, -Dan J, Dan Jarvis, Daniel Ylitalo, Darren Garvey, David Norman, David Truong, -@DavidNorman, Dimitar Pavlov, Dmitry Persiyanov, @Eddie, @elirex, Erfan -Noury, Eron Wright, Evgeny Mazovetskiy, Fabrizio (Misto) Milo, @fanlu, Fisher -Coder, Franck Dernoncourt, Gagan Goel, Gao, Xiang, @Gautam, Gefu Tang, -@guilherme, @guschmue, Hannah Provenza, Hans Pabst, @hartb, Hsiao Yi, Huazuo -Gao, Igor ChorążEwicz, Ivan Smirnov, Jakub Kolodziejczyk, Jason Gavris, Jason -Morton, Jay Young, Jayaram Bobba, Jeremy Sawruk, Jiaming Liu, Jihun Choi, -@jiqiu, Joan Thibault, John C F, Jojy G Varghese, Jon Malmaud, Julian Berman, -Julian Niedermeier, Junpeng Lao, Kai Sasaki, @Kankroc, Karl Lessard, Kyle -Bostelmann, @Lezcano, Li Yi, Luo Yun, @lurker, Mahmoud-Abuzaina, Mandeep Singh, -Marek Kolodziej, Mark Szepieniec, Martial Hue, Medhat Omr, Memo Akten, Michael -Gharbi, MichaëL Defferrard, Milan Straka, @MircoT, @mlucool, Muammar Ibn Faisal, -Nayana Thorat, @nghiattran, Nicholas Connor, Nikolaas Steenbergen, Niraj Patel, -Niranjan Hasabnis, @Panmari, Pavel Bulanov, Philip Pries Henningsen, Philipp -Jund, @polonez, Prayag Verma, Rahul Kavi, Raphael Gontijo Lopes, @rasbt, Raven -Iqqe, Reid Pryzant, Richard Shin, Rizwan Asif, Russell Kaplan, Ryo Asakura, -RüDiger Busche, Saisai Shao, Sam Abrahams, @sanosay, Sean Papay, @seaotterman, -@selay01, Shaurya Sharma, Sriram Narayanamoorthy, Stefano Probst, @taknevski, -@tbonza, @teldridge11, Yuan (Terry) Tang, Tim Anglade, Tomas Reimers, Tomer Gafner, -Valentin Iovene, Vamsi Sripathi, Viktor Malyi, Vit Stepanovs, Vivek Rane, Vlad -Firoiu, @wangg12, @will, Xiaoyu Tao, Yaroslav Bulatov, Yuan (Terry) Tang, -@Yufeng, Yuming Wang, Yuxin Wu, Zafar Takhirov, Ziming Dong +Visser, Brady Zhou, Calpa Liu, Changming Sun, Chih Cheng Liang, Christopher +Berner, Clark Zinzow, @Conchylicultor, Dan Ellis, Dan J, Dan Jarvis, Daniel +Ylitalo, Darren Garvey, David Norman, David Truong, @DavidNorman, Dimitar +Pavlov, Dmitry Persiyanov, @Eddie, @elirex, Erfan Noury, Eron Wright, Evgeny +Mazovetskiy, Fabrizio (Misto) Milo, @fanlu, Fisher Coder, Florian Courtial, +Franck Dernoncourt, Gagan Goel, Gao, Xiang, @Gautam, Gefu Tang, @guilherme, +@guschmue, Hannah Provenza, Hans Pabst, @hartb, Hsiao Yi, Huazuo Gao, Igor +ChorążEwicz, Ivan Smirnov, Jakub Kolodziejczyk, Jason Gavris, Jason Morton, Jay +Young, Jayaram Bobba, Jeremy Sawruk, Jiaming Liu, Jihun Choi, @jiqiu, Joan Thibault, +John C F, Jojy George Varghese, Jon Malmaud, Julian Berman, Julian Niedermeier, +Junpeng Lao, Kai Sasaki, @Kankroc, Karl Lessard, Kyle Bostelmann, @Lezcano, Li +Yi, Luo Yun, @lurker, Mahmoud-Abuzaina, Mandeep Singh, Marek Kolodziej, Mark +Szepieniec, Martial Hue, Medhat Omr, Memo Akten, Michael Gharbi, MichaëL Defferrard, +Milan Straka, @MircoT, @mlucool, Muammar Ibn Faisal, Nayana Thorat, @nghiattran, +Nicholas Connor, Nikolaas Steenbergen, Niraj Patel, Niranjan Hasabnis, @Panmari, +Pavel Bulanov, Philip Pries Henningsen, Philipp Jund, @polonez, Prayag Verma, Rahul +Kavi, Raphael Gontijo Lopes, @rasbt, Raven Iqqe, Reid Pryzant, Richard Shin, Rizwan +Asif, Russell Kaplan, Ryo Asakura, RüDiger Busche, Saisai Shao, Sam Abrahams, @sanosay, +Sean Papay, @seaotterman, @selay01, Shaurya Sharma, Sriram Narayanamoorthy, Stefano +Probst, @taknevski, @tbonza, @teldridge11, Tim Anglade, Tomas Reimers, Tomer Gafner, +Valentin Iovene, Vamsi Sripathi, Viktor Malyi, Vit Stepanovs, Vivek Rane, Vlad Firoiu, +@wangg12, @will, Xiaoyu Tao, Yaroslav Bulatov, Yi Liu, Yuan (Terry) Tang, @Yufeng, +Yuming Wang, Yuxin Wu, Zafar Takhirov, Ziming Dong We are also grateful to all who filed issues or helped resolve them, asked and answered questions, and were part of inspiring discussions. diff --git a/WORKSPACE b/WORKSPACE index cab8389a55ccfeddb9dc077c9b999edbe775f25d..edf655f6a7b0ab2781cf2d349732a102aedff112 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,11 +2,11 @@ workspace(name = "org_tensorflow") http_archive( name = "io_bazel_rules_closure", - sha256 = "60fc6977908f999b23ca65698c2bb70213403824a84f7904310b6000d78be9ce", - strip_prefix = "rules_closure-5ca1dab6df9ad02050f7ba4e816407f88690cf7d", + sha256 = "4be8a887f6f38f883236e77bb25c2da10d506f2bf1a8e5d785c0f35574c74ca4", + strip_prefix = "rules_closure-aac19edc557aec9b603cd7ffe359401264ceff0d", urls = [ - "http://bazel-mirror.storage.googleapis.com/github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz", # 2017-02-03 - "https://github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz", + "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/aac19edc557aec9b603cd7ffe359401264ceff0d.tar.gz", # 2017-05-10 + "https://github.com/bazelbuild/rules_closure/archive/aac19edc557aec9b603cd7ffe359401264ceff0d.tar.gz", ], ) @@ -20,7 +20,7 @@ load("//tensorflow:workspace.bzl", "tf_workspace") #android_sdk_repository( # name = "androidsdk", # api_level = 23, -# # Ensure that you have the build_tools_version below installed in the +# # Ensure that you have the build_tools_version below installed in the # # SDK manager as it updates periodically. # build_tools_version = "25.0.2", # # Replace with path to Android SDK on your system @@ -31,7 +31,7 @@ load("//tensorflow:workspace.bzl", "tf_workspace") #android_ndk_repository( # name="androidndk", # path="", -# # This needs to be 14 or higher to compile TensorFlow. +# # This needs to be 14 or higher to compile TensorFlow. # # Note that the NDK version is not the API level. # api_level=14) @@ -39,485 +39,31 @@ load("//tensorflow:workspace.bzl", "tf_workspace") tf_workspace() new_http_archive( - name = "inception5h", - build_file = "models.BUILD", - url = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip", - sha256 = "d13569f6a98159de37e92e9c8ec4dae8f674fbf475f69fe6199b514f756d4364" -) - -new_http_archive( - name = "mobile_multibox", - build_file = "models.BUILD", - url = "https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip", - sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96" -) - -new_http_archive( - name = "stylize", - build_file = "models.BUILD", - url = "https://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip", - sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa" -) - -# TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT - -new_http_archive( - name = "d3", - build_file = "bower.BUILD", - url = "https://github.com/mbostock-bower/d3-bower/archive/v3.5.15.tar.gz", - strip_prefix = "d3-bower-3.5.15", -) - -new_http_archive( - name = "dagre", - build_file = "bower.BUILD", - url = "https://github.com/cpettitt/dagre/archive/v0.7.4.tar.gz", - strip_prefix = "dagre-0.7.4", -) - -new_http_archive( - name = "es6_promise", - build_file = "bower.BUILD", - url = "https://github.com/components/es6-promise/archive/v2.1.0.tar.gz", - strip_prefix = "es6-promise-2.1.0", -) - -new_http_archive( - name = "font_roboto", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/font-roboto/archive/v1.0.1.tar.gz", - strip_prefix = "font-roboto-1.0.1", -) - -new_http_archive( - name = "graphlib", - build_file = "bower.BUILD", - url = "https://github.com/cpettitt/graphlib/archive/v1.0.7.tar.gz", - strip_prefix = "graphlib-1.0.7", -) - -new_http_archive( - name = "iron_a11y_announcer", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-a11y-announcer/archive/v1.0.5.tar.gz", - strip_prefix = "iron-a11y-announcer-1.0.5", -) - -new_http_archive( - name = "iron_a11y_keys_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-a11y-keys-behavior/archive/v1.1.8.tar.gz", - strip_prefix = "iron-a11y-keys-behavior-1.1.8", -) - -new_http_archive( - name = "iron_ajax", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-ajax/archive/v1.2.0.tar.gz", - strip_prefix = "iron-ajax-1.2.0", -) - -new_http_archive( - name = "iron_autogrow_textarea", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-autogrow-textarea/archive/v1.0.12.tar.gz", - strip_prefix = "iron-autogrow-textarea-1.0.12", -) - -new_http_archive( - name = "iron_behaviors", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-behaviors/archive/v1.0.17.tar.gz", - strip_prefix = "iron-behaviors-1.0.17", -) - -new_http_archive( - name = "iron_checked_element_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-checked-element-behavior/archive/v1.0.4.tar.gz", - strip_prefix = "iron-checked-element-behavior-1.0.4", -) - -new_http_archive( - name = "iron_collapse", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-collapse/archive/v1.0.8.tar.gz", - strip_prefix = "iron-collapse-1.0.8", -) - -new_http_archive( - name = "iron_dropdown", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-dropdown/archive/v1.4.0.tar.gz", - strip_prefix = "iron-dropdown-1.4.0", -) - -new_http_archive( - name = "iron_fit_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-fit-behavior/archive/v1.2.5.tar.gz", - strip_prefix = "iron-fit-behavior-1.2.5", -) - -new_http_archive( - name = "iron_flex_layout", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-flex-layout/archive/v1.3.0.tar.gz", - strip_prefix = "iron-flex-layout-1.3.0", -) - -new_http_archive( - name = "iron_form_element_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-form-element-behavior/archive/v1.0.6.tar.gz", - strip_prefix = "iron-form-element-behavior-1.0.6", -) - -new_http_archive( - name = "iron_icon", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-icon/archive/v1.0.11.tar.gz", - strip_prefix = "iron-icon-1.0.11", -) - -new_http_archive( - name = "iron_icons", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-icons/archive/v1.1.3.tar.gz", - strip_prefix = "iron-icons-1.1.3", -) - -new_http_archive( - name = "iron_iconset_svg", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-iconset-svg/archive/v1.1.0.tar.gz", - strip_prefix = "iron-iconset-svg-1.1.0", -) - -new_http_archive( - name = "iron_input", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-input/archive/1.0.10.tar.gz", - strip_prefix = "iron-input-1.0.10", -) - -new_http_archive( - name = "iron_list", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-list/archive/v1.3.9.tar.gz", - strip_prefix = "iron-list-1.3.9", -) - -new_http_archive( - name = "iron_menu_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-menu-behavior/archive/v1.1.10.tar.gz", - strip_prefix = "iron-menu-behavior-1.1.10", -) - -new_http_archive( - name = "iron_meta", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-meta/archive/v1.1.1.tar.gz", - strip_prefix = "iron-meta-1.1.1", -) - -new_http_archive( - name = "iron_overlay_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-overlay-behavior/archive/v1.10.1.tar.gz", - strip_prefix = "iron-overlay-behavior-1.10.1", -) - -new_http_archive( - name = "iron_range_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-range-behavior/archive/v1.0.4.tar.gz", - strip_prefix = "iron-range-behavior-1.0.4", -) - -new_http_archive( - name = "iron_resizable_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-resizable-behavior/archive/v1.0.3.tar.gz", - strip_prefix = "iron-resizable-behavior-1.0.3", -) - -new_http_archive( - name = "iron_scroll_target_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-scroll-target-behavior/archive/v1.0.3.tar.gz", - strip_prefix = "iron-scroll-target-behavior-1.0.3", -) - -new_http_archive( - name = "iron_selector", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-selector/archive/v1.5.2.tar.gz", - strip_prefix = "iron-selector-1.5.2", -) - -new_http_archive( - name = "iron_validatable_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/iron-validatable-behavior/archive/v1.1.1.tar.gz", - strip_prefix = "iron-validatable-behavior-1.1.1", -) - -new_http_archive( - name = "lodash", - build_file = "bower.BUILD", - url = "https://github.com/lodash/lodash/archive/3.8.0.tar.gz", - strip_prefix = "lodash-3.8.0", -) - -new_http_archive( - name = "neon_animation", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/neon-animation/archive/v1.2.2.tar.gz", - strip_prefix = "neon-animation-1.2.2", -) - -http_file( - name = "numericjs_numeric_min_js", - url = "https://cdnjs.cloudflare.com/ajax/libs/numeric/1.2.6/numeric.min.js", -) - -new_http_archive( - name = "paper_behaviors", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-behaviors/archive/v1.0.12.tar.gz", - strip_prefix = "paper-behaviors-1.0.12", -) - -new_http_archive( - name = "paper_button", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-button/archive/v1.0.11.tar.gz", - strip_prefix = "paper-button-1.0.11", -) - -new_http_archive( - name = "paper_checkbox", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-checkbox/archive/v1.4.0.tar.gz", - strip_prefix = "paper-checkbox-1.4.0", -) - -new_http_archive( - name = "paper_dialog", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-dialog/archive/v1.0.4.tar.gz", - strip_prefix = "paper-dialog-1.0.4", -) - -new_http_archive( - name = "paper_dialog_behavior", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-dialog-behavior/archive/v1.2.5.tar.gz", - strip_prefix = "paper-dialog-behavior-1.2.5", -) - -new_http_archive( - name = "paper_dialog_scrollable", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-dialog-scrollable/archive/1.1.5.tar.gz", - strip_prefix = "paper-dialog-scrollable-1.1.5", -) - -new_http_archive( - name = "paper_dropdown_menu", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-dropdown-menu/archive/v1.4.0.tar.gz", - strip_prefix = "paper-dropdown-menu-1.4.0", -) - -new_http_archive( - name = "paper_header_panel", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-header-panel/archive/v1.1.4.tar.gz", - strip_prefix = "paper-header-panel-1.1.4", -) - -new_http_archive( - name = "paper_icon_button", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-icon-button/archive/v1.1.3.tar.gz", - strip_prefix = "paper-icon-button-1.1.3", -) - -new_http_archive( - name = "paper_input", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-input/archive/v1.1.18.tar.gz", - strip_prefix = "paper-input-1.1.18", -) - -new_http_archive( - name = "paper_item", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-item/archive/v1.1.4.tar.gz", - strip_prefix = "paper-item-1.1.4", -) - -new_http_archive( - name = "paper_listbox", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-listbox/archive/v1.1.2.tar.gz", - strip_prefix = "paper-listbox-1.1.2", -) - -new_http_archive( - name = "paper_material", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-material/archive/v1.0.6.tar.gz", - strip_prefix = "paper-material-1.0.6", -) - -new_http_archive( - name = "paper_menu", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-menu/archive/v1.2.2.tar.gz", - strip_prefix = "paper-menu-1.2.2", -) - -new_http_archive( - name = "paper_menu_button", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-menu-button/archive/v1.5.1.tar.gz", - strip_prefix = "paper-menu-button-1.5.1", -) - -new_http_archive( - name = "paper_progress", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-progress/archive/v1.0.9.tar.gz", - strip_prefix = "paper-progress-1.0.9", -) - -new_http_archive( - name = "paper_radio_button", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-radio-button/archive/v1.1.2.tar.gz", - strip_prefix = "paper-radio-button-1.1.2", -) - -new_http_archive( - name = "paper_radio_group", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-radio-group/archive/v1.0.9.tar.gz", - strip_prefix = "paper-radio-group-1.0.9", -) - -new_http_archive( - name = "paper_ripple", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-ripple/archive/v1.0.5.tar.gz", - strip_prefix = "paper-ripple-1.0.5", -) - -new_http_archive( - name = "paper_slider", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-slider/archive/v1.0.10.tar.gz", - strip_prefix = "paper-slider-1.0.10", -) - -new_http_archive( - name = "paper_spinner", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-spinner/archive/v1.1.1.tar.gz", - strip_prefix = "paper-spinner-1.1.1", -) - -new_http_archive( - name = "paper_styles", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-styles/archive/v1.1.4.tar.gz", - strip_prefix = "paper-styles-1.1.4", -) - -new_http_archive( - name = "paper_tabs", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-tabs/archive/v1.7.0.tar.gz", - strip_prefix = "paper-tabs-1.7.0", -) - -new_http_archive( - name = "paper_toast", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-toast/archive/v1.3.0.tar.gz", - strip_prefix = "paper-toast-1.3.0", -) - -new_http_archive( - name = "paper_toggle_button", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-toggle-button/archive/v1.2.0.tar.gz", - strip_prefix = "paper-toggle-button-1.2.0", -) - -new_http_archive( - name = "paper_toolbar", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-toolbar/archive/v1.1.4.tar.gz", - strip_prefix = "paper-toolbar-1.1.4", -) - -new_http_archive( - name = "paper_tooltip", - build_file = "bower.BUILD", - url = "https://github.com/polymerelements/paper-tooltip/archive/v1.1.2.tar.gz", - strip_prefix = "paper-tooltip-1.1.2", -) - -new_http_archive( - name = "plottable", - build_file = "bower.BUILD", - url = "https://github.com/palantir/plottable/archive/v1.16.1.tar.gz", - strip_prefix = "plottable-1.16.1", -) - -new_http_archive( - name = "polymer", - build_file = "bower.BUILD", - url = "https://github.com/polymer/polymer/archive/v1.7.0.tar.gz", - strip_prefix = "polymer-1.7.0", -) - -new_http_archive( - name = "promise_polyfill", - build_file = "bower.BUILD", - url = "https://github.com/polymerlabs/promise-polyfill/archive/v1.0.0.tar.gz", - strip_prefix = "promise-polyfill-1.0.0", -) - -http_file( - name = "three_js_three_min_js", - url = "https://raw.githubusercontent.com/mrdoob/three.js/r77/build/three.min.js", -) - -http_file( - name = "three_js_orbitcontrols_js", - url = "https://raw.githubusercontent.com/mrdoob/three.js/r77/examples/js/controls/OrbitControls.js", + name = "inception5h", + build_file = "models.BUILD", + sha256 = "d13569f6a98159de37e92e9c8ec4dae8f674fbf475f69fe6199b514f756d4364", + urls = [ + "http://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip", + "http://download.tensorflow.org/models/inception5h.zip", + ], ) new_http_archive( - name = "web_animations_js", - build_file = "bower.BUILD", - url = "https://github.com/web-animations/web-animations-js/archive/2.2.1.tar.gz", - strip_prefix = "web-animations-js-2.2.1", + name = "mobile_multibox", + build_file = "models.BUILD", + sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96", + urls = [ + "http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip", + "http://download.tensorflow.org/models/mobile_multibox_v1a.zip", + ], ) new_http_archive( - name = "webcomponentsjs", - build_file = "bower.BUILD", - url = "https://github.com/webcomponents/webcomponentsjs/archive/v0.7.22.tar.gz", - strip_prefix = "webcomponentsjs-0.7.22", -) - -http_file( - name = "weblas_weblas_js", - url = "https://raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/dist/weblas.js", + name = "stylize", + build_file = "models.BUILD", + sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa", + urls = [ + "http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip", + "http://download.tensorflow.org/models/stylize_v1.zip", + ], ) diff --git a/bower.BUILD b/bower.BUILD deleted file mode 100644 index eabd1d6450728aab37ebeca6366009d74c6984b6..0000000000000000000000000000000000000000 --- a/bower.BUILD +++ /dev/null @@ -1,645 +0,0 @@ -# AUTOGENERATED FILE by tensorboard_bower_dependency_sync.py - -package(default_visibility = ["//visibility:public"]) - -filegroup( - name = "d3", - srcs = [ - "d3.js", - "d3.min.js", - "package.js", - ], -) - -filegroup( - name = "dagre", - srcs = [ - "dist/dagre.core.js", - "dist/dagre.core.min.js", - ], -) - -filegroup( - name = "es6_promise", - srcs = [ - "promise.js", - "promise.min.js", - ], -) - -filegroup( - name = "font_roboto", - srcs = ["roboto.html"], -) - -filegroup( - name = "graphlib", - srcs = [ - "dist/graphlib.core.js", - "dist/graphlib.core.min.js", - ], -) - -filegroup( - name = "iron_a11y_announcer", - srcs = [ - "index.html", - "iron-a11y-announcer.html", - ], -) - -filegroup( - name = "iron_a11y_keys_behavior", - srcs = [ - "index.html", - "iron-a11y-keys-behavior.html", - ], -) - -filegroup( - name = "iron_ajax", - srcs = [ - "index.html", - "iron-ajax.html", - "iron-request.html", - ], -) - -filegroup( - name = "iron_autogrow_textarea", - srcs = [ - "index.html", - "iron-autogrow-textarea.html", - ], -) - -filegroup( - name = "iron_behaviors", - srcs = [ - "index.html", - "iron-button-state.html", - "iron-control-state.html", - ], -) - -filegroup( - name = "iron_checked_element_behavior", - srcs = [ - "index.html", - "iron-checked-element-behavior.html", - ], -) - -filegroup( - name = "iron_collapse", - srcs = [ - "index.html", - "iron-collapse.html", - ], -) - -filegroup( - name = "iron_dropdown", - srcs = [ - "index.html", - "iron-dropdown.html", - "iron-dropdown-scroll-manager.html", - ], -) - -filegroup( - name = "iron_fit_behavior", - srcs = [ - "index.html", - "iron-fit-behavior.html", - ], -) - -filegroup( - name = "iron_flex_layout", - srcs = [ - "classes/iron-flex-layout.html", - "classes/iron-shadow-flex-layout.html", - "index.html", - "iron-flex-layout.html", - "iron-flex-layout-classes.html", - ], -) - -filegroup( - name = "iron_form_element_behavior", - srcs = [ - "index.html", - "iron-form-element-behavior.html", - ], -) - -filegroup( - name = "iron_icon", - srcs = [ - "index.html", - "iron-icon.html", - ], -) - -filegroup( - name = "iron_icons", - srcs = [ - "av-icons.html", - "communication-icons.html", - "device-icons.html", - "editor-icons.html", - "hardware-icons.html", - "image-icons.html", - "index.html", - "iron-icons.html", - "maps-icons.html", - "notification-icons.html", - "places-icons.html", - "social-icons.html", - ], -) - -filegroup( - name = "iron_iconset_svg", - srcs = [ - "index.html", - "iron-iconset-svg.html", - ], -) - -filegroup( - name = "iron_input", - srcs = [ - "index.html", - "iron-input.html", - ], -) - -filegroup( - name = "iron_list", - srcs = [ - "index.html", - "iron-list.html", - "test/smoke/avg-worst-case.html", - "test/smoke/dummy-data.html", - "test/smoke/index.html", - "test/smoke/physical-count.html", - ], -) - -filegroup( - name = "iron_menu_behavior", - srcs = [ - "index.html", - "iron-menu-behavior.html", - "iron-menubar-behavior.html", - ], -) - -filegroup( - name = "iron_meta", - srcs = [ - "index.html", - "iron-meta.html", - ], -) - -filegroup( - name = "iron_overlay_behavior", - srcs = [ - "index.html", - "iron-focusables-helper.html", - "iron-overlay-backdrop.html", - "iron-overlay-behavior.html", - "iron-overlay-manager.html", - ], -) - -filegroup( - name = "iron_range_behavior", - srcs = [ - "index.html", - "iron-range-behavior.html", - ], -) - -filegroup( - name = "iron_resizable_behavior", - srcs = [ - "demo/src/x-app.html", - "index.html", - "iron-resizable-behavior.html", - ], -) - -filegroup( - name = "iron_scroll_target_behavior", - srcs = [ - "index.html", - "iron-scroll-target-behavior.html", - ], -) - -filegroup( - name = "iron_selector", - srcs = [ - "index.html", - "iron-multi-selectable.html", - "iron-selectable.html", - "iron-selection.html", - "iron-selector.html", - ], -) - -filegroup( - name = "iron_validatable_behavior", - srcs = [ - "index.html", - "iron-validatable-behavior.html", - ], -) - -filegroup( - name = "lodash", - srcs = [ - "lodash.js", - "lodash.min.js", - ], -) - -filegroup( - name = "neon_animation", - srcs = [ - "animations/cascaded-animation.html", - "animations/fade-in-animation.html", - "animations/fade-out-animation.html", - "animations/hero-animation.html", - "animations/opaque-animation.html", - "animations/reverse-ripple-animation.html", - "animations/ripple-animation.html", - "animations/scale-down-animation.html", - "animations/scale-up-animation.html", - "animations/slide-down-animation.html", - "animations/slide-from-bottom-animation.html", - "animations/slide-from-left-animation.html", - "animations/slide-from-right-animation.html", - "animations/slide-from-top-animation.html", - "animations/slide-left-animation.html", - "animations/slide-right-animation.html", - "animations/slide-up-animation.html", - "animations/transform-animation.html", - "demo/card/index.html", - "demo/card/x-card.html", - "demo/card/x-cards-list.html", - "demo/declarative/index.html", - "demo/doc/index.html", - "demo/doc/my-animatable.html", - "demo/doc/my-dialog.html", - "demo/dropdown/animated-dropdown.html", - "demo/dropdown/index.html", - "demo/grid/animated-grid.html", - "demo/grid/fullsize-page-with-card.html", - "demo/grid/index.html", - "demo/list/full-view.html", - "demo/list/index.html", - "demo/list/list-demo.html", - "demo/list/list-view.html", - "demo/load/animated-grid.html", - "demo/load/full-page.html", - "demo/load/index.html", - "demo/reprojection/animated-grid.html", - "demo/reprojection/fullsize-page-with-card.html", - "demo/reprojection/index.html", - "demo/reprojection/reprojected-pages.html", - "demo/tiles/circles-page.html", - "demo/tiles/index.html", - "demo/tiles/squares-page.html", - "index.html", - "neon-animatable.html", - "neon-animatable-behavior.html", - "neon-animated-pages.html", - "neon-animation.html", - "neon-animation-behavior.html", - "neon-animation-runner-behavior.html", - "neon-animations.html", - "neon-shared-element-animatable-behavior.html", - "neon-shared-element-animation-behavior.html", - "web-animations.html", - ], -) - -filegroup( - name = "paper_behaviors", - srcs = [ - "index.html", - "paper-button-behavior.html", - "paper-checked-element-behavior.html", - "paper-inky-focus-behavior.html", - "paper-ripple-behavior.html", - ], -) - -filegroup( - name = "paper_button", - srcs = [ - "index.html", - "paper-button.html", - ], -) - -filegroup( - name = "paper_checkbox", - srcs = [ - "index.html", - "paper-checkbox.html", - ], -) - -filegroup( - name = "paper_dialog", - srcs = [ - "index.html", - "paper-dialog.html", - ], -) - -filegroup( - name = "paper_dialog_behavior", - srcs = [ - "index.html", - "paper-dialog-behavior.html", - "paper-dialog-common.css", - "paper-dialog-shared-styles.html", - ], -) - -filegroup( - name = "paper_dialog_scrollable", - srcs = [ - "index.html", - "paper-dialog-scrollable.html", - ], -) - -filegroup( - name = "paper_dropdown_menu", - srcs = [ - "index.html", - "paper-dropdown-menu.html", - "paper-dropdown-menu-icons.html", - "paper-dropdown-menu-light.html", - "paper-dropdown-menu-shared-styles.html", - ], -) - -filegroup( - name = "paper_header_panel", - srcs = [ - "index.html", - "paper-header-panel.html", - ], -) - -filegroup( - name = "paper_icon_button", - srcs = [ - "index.html", - "paper-icon-button.html", - "paper-icon-button-light.html", - ], -) - -filegroup( - name = "paper_input", - srcs = [ - "all-imports.html", - "index.html", - "paper-input.html", - "paper-input-addon-behavior.html", - "paper-input-behavior.html", - "paper-input-char-counter.html", - "paper-input-container.html", - "paper-input-error.html", - "paper-textarea.html", - ], -) - -filegroup( - name = "paper_item", - srcs = [ - "all-imports.html", - "index.html", - "paper-icon-item.html", - "paper-item.html", - "paper-item-behavior.html", - "paper-item-body.html", - "paper-item-shared-styles.html", - ], -) - -filegroup( - name = "paper_listbox", - srcs = [ - "index.html", - "paper-listbox.html", - ], -) - -filegroup( - name = "paper_material", - srcs = [ - "index.html", - "paper-material.html", - "paper-material-shared-styles.html", - ], -) - -filegroup( - name = "paper_menu", - srcs = [ - "index.html", - "paper-menu.html", - "paper-menu-shared-styles.html", - "paper-submenu.html", - ], -) - -filegroup( - name = "paper_menu_button", - srcs = [ - "index.html", - "paper-menu-button.html", - "paper-menu-button-animations.html", - ], -) - -filegroup( - name = "paper_progress", - srcs = [ - "index.html", - "paper-progress.html", - ], -) - -filegroup( - name = "paper_radio_button", - srcs = [ - "index.html", - "paper-radio-button.html", - ], -) - -filegroup( - name = "paper_radio_group", - srcs = [ - "index.html", - "paper-radio-group.html", - ], -) - -filegroup( - name = "paper_ripple", - srcs = [ - "index.html", - "paper-ripple.html", - ], -) - -filegroup( - name = "paper_slider", - srcs = [ - "index.html", - "paper-slider.html", - ], -) - -filegroup( - name = "paper_spinner", - srcs = [ - "index.html", - "paper-spinner.html", - "paper-spinner-behavior.html", - "paper-spinner-lite.html", - "paper-spinner-styles.html", - ], -) - -filegroup( - name = "paper_styles", - srcs = [ - "classes/global.html", - "classes/shadow.html", - "classes/shadow-layout.html", - "classes/typography.html", - "color.html", - "default-theme.html", - "demo.css", - "demo-pages.html", - "index.html", - "paper-styles.html", - "paper-styles-classes.html", - "shadow.html", - "typography.html", - ], -) - -filegroup( - name = "paper_tabs", - srcs = [ - "index.html", - "paper-tab.html", - "paper-tabs.html", - "paper-tabs-icons.html", - ], -) - -filegroup( - name = "paper_toast", - srcs = [ - "index.html", - "paper-toast.html", - ], -) - -filegroup( - name = "paper_toggle_button", - srcs = [ - "index.html", - "paper-toggle-button.html", - ], -) - -filegroup( - name = "paper_toolbar", - srcs = [ - "index.html", - "paper-toolbar.html", - ], -) - -filegroup( - name = "paper_tooltip", - srcs = [ - "index.html", - "paper-tooltip.html", - ], -) - -filegroup( - name = "plottable", - srcs = [ - "plottable.css", - "plottable.js", - "plottable.min.js", - ], -) - -filegroup( - name = "polymer", - srcs = [ - "polymer.html", - "polymer-micro.html", - "polymer-mini.html", - ], -) - -filegroup( - name = "promise_polyfill", - srcs = [ - "Gruntfile.js", - "Promise.js", - "Promise.min.js", - "Promise-Statics.js", - "promise-polyfill.html", - "promise-polyfill-lite.html", - ], -) - -filegroup( - name = "web_animations_js", - srcs = [ - "web-animations.html", - "web-animations.min.js", - "web-animations-next.min.js", - "web-animations-next-lite.min.js", - ], -) - -filegroup( - name = "webcomponentsjs", - srcs = [ - "CustomElements.js", - "CustomElements.min.js", - "HTMLImports.js", - "HTMLImports.min.js", - "MutationObserver.js", - "MutationObserver.min.js", - "ShadowDOM.js", - "ShadowDOM.min.js", - "webcomponents.js", - "webcomponents.min.js", - "webcomponents-lite.js", - "webcomponents-lite.min.js", - ], -) diff --git a/configure b/configure index 6360641be2ca99c8c8cbe58c95fc2fd59f917744..e455893ffc8539b0c9175b6acc242427b0930ce4 100755 --- a/configure +++ b/configure @@ -28,19 +28,12 @@ function is_macos() { function is_windows() { # On windows, the shell script is actually running in msys - if [[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]]; then - true - else - false - fi + [[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]] } -function sed_hyphen_i() { - if is_macos; then - sed -i '' "$@" - else - sed -i "$@" - fi +function sed_in_place() { + sed -e $1 $2 > "$2.bak" + mv "$2.bak" $2 } function write_to_bazelrc() { @@ -51,12 +44,133 @@ function write_action_env_to_bazelrc() { write_to_bazelrc "build --action_env $1=\"$2\"" } +function python_path { + "$PYTHON_BIN_PATH" - <&2 + if [ -z "$fromuser" ]; then + exit 1 + fi + PYTHON_BIN_PATH="" + # Retry + done + + if [ -z "$PYTHON_LIB_PATH" ]; then + # Split python_path into an array of paths, this allows path containing spaces + IFS=',' + python_lib_path=($(python_path)) + unset IFS + + if [ 1 = "$USE_DEFAULT_PYTHON_LIB_PATH" ]; then + PYTHON_LIB_PATH=${python_lib_path[0]} + echo "Using python library path: $PYTHON_LIB_PATH" + + else + echo "Found possible Python library paths:" + for x in "${python_lib_path[@]}"; do + echo " $x" + done + set -- "${python_lib_path[@]}" + echo "Please input the desired Python library path to use. Default is ["$1"]" + read b || true + if [ "$b" == "" ]; then + PYTHON_LIB_PATH=${python_lib_path[0]} + echo "Using python library path: $PYTHON_LIB_PATH" + else + PYTHON_LIB_PATH="$b" + fi + fi + fi + + if [ ! -x "$PYTHON_BIN_PATH" ] || [ -d "$PYTHON_BIN_PATH" ]; then + echo "PYTHON_BIN_PATH is not executable. Is it the python binary?" + exit 1 + fi + + local python_major_version=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; import sys; print(sys.version_info[0]);' | head -c1) + if [ -z "$python_major_version" ]; then + echo -e "\n\nERROR: Problem getting python version. Is $PYTHON_BIN_PATH the correct python binary?" + exit 1 + fi + + # Convert python path to Windows style before writing into bazel.rc + if is_windows; then + PYTHON_BIN_PATH="$(cygpath -m "$PYTHON_BIN_PATH")" + fi + + # Set-up env variables used by python_configure.bzl + write_action_env_to_bazelrc "PYTHON_BIN_PATH" "$PYTHON_BIN_PATH" + write_action_env_to_bazelrc "PYTHON_LIB_PATH" "$PYTHON_LIB_PATH" + write_to_bazelrc "build --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" + write_to_bazelrc "build --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\"" + write_to_bazelrc "build --force_python=py$python_major_version" + write_to_bazelrc "build --host_force_python=py$python_major_version" + write_to_bazelrc "build --python${python_major_version}_path=\"$PYTHON_BIN_PATH\"" + write_to_bazelrc "test --force_python=py$python_major_version" + write_to_bazelrc "test --host_force_python=py$python_major_version" + write_to_bazelrc "test --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" + write_to_bazelrc "test --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\"" + write_to_bazelrc "run --define PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" + write_to_bazelrc "run --define PYTHON_LIB_PATH=\"$PYTHON_LIB_PATH\"" + + # Write tools/python_bin_path.sh + echo "export PYTHON_BIN_PATH=\"$PYTHON_BIN_PATH\"" > tools/python_bin_path.sh +} + # This file contains customized config settings. rm -f .tf_configure.bazelrc touch .tf_configure.bazelrc -touch .bazelrc -sed_hyphen_i "/tf_configure/d" .bazelrc -echo "import .tf_configure.bazelrc" >> .bazelrc +if [[ ! -e .bazelrc ]]; then + if [[ -e "${HOME}/.bazelrc" ]]; then + echo "import ${HOME}/.bazelrc" >.bazelrc + else + touch .bazelrc + fi +fi +sed_in_place "/tf_configure/d" .bazelrc +echo "import %workspace%/.tf_configure.bazelrc" >> .bazelrc # Delete any leftover BUILD files from the Makefile build, which would interfere # with Bazel parsing. @@ -65,58 +179,63 @@ if [ -d "${MAKEFILE_DOWNLOAD_DIR}" ]; then find ${MAKEFILE_DOWNLOAD_DIR} -type f -name '*BUILD' -delete fi -## Set up python-related environment settings -while true; do +setup_python + +## Set up MKL related environment settings +while [ "$TF_NEED_MKL" == "" ]; do fromuser="" - if [ -z "$PYTHON_BIN_PATH" ]; then - default_python_bin_path=$(which python || which python3 || true) - read -p "Please specify the location of python. [Default is $default_python_bin_path]: " PYTHON_BIN_PATH - fromuser="1" - if [ -z "$PYTHON_BIN_PATH" ]; then - PYTHON_BIN_PATH=$default_python_bin_path - fi - fi - if [ -e "$PYTHON_BIN_PATH" ]; then - break - fi - echo "Invalid python path. ${PYTHON_BIN_PATH} cannot be found" 1>&2 - if [ -z "$fromuser" ]; then - exit 1 - fi - PYTHON_BIN_PATH="" - # Retry + read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT + fromuser="1" + case $INPUT in + [Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;; + [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; + "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; + * ) echo "Invalid selection: " $INPUT;; + esac done -## Set up MKL related environment settings -if false; then # Disable building with MKL for now - while [ "$TF_NEED_MKL" == "" ]; do +OSNAME=`uname -s` + +if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL + while [ "$TF_DOWNLOAD_MKL" == "" ]; do fromuser="" - read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT + read -p "Do you wish to download MKL LIB from the web? [Y/n] " INPUT fromuser="1" case $INPUT in - [Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;; - [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; - "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;; - * ) echo "Invalid selection: " $INPUT;; + [Yy]* ) TF_DOWNLOAD_MKL=1;; + [Nn]* ) TF_DOWNLOAD_MKL=0;; + "" ) TF_DOWNLOAD_MKL=1;; + * ) echo "Invalid selection: " $INPUT; exit 1;; esac done - OSNAME=`uname -s` - - if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL + if [[ "$TF_DOWNLOAD_MKL" == "1" ]]; then DST=`dirname $0` - ARCHIVE_BASENAME=mklml_lnx_2017.0.2.20170209.tgz - GITHUB_RELEASE_TAG=v0.5 + ARCHIVE_BASENAME=mklml_lnx_2018.0.20170425.tgz + GITHUB_RELEASE_TAG=v0.7 MKLURL="https://github.com/01org/mkl-dnn/releases/download/$GITHUB_RELEASE_TAG/$ARCHIVE_BASENAME" - if ! [ -e "$DST/third_party/mkl/$ARCHIVE_BASENAME" ]; then - wget --no-check-certificate -P $DST/third_party/mkl/ $MKLURL + if ! [ -e "${DST}/third_party/mkl/${ARCHIVE_BASENAME}" ]; then + curl -fSsL -o "${DST}/third_party/mkl/${ARCHIVE_BASENAME}" "${MKLURL}" fi tar -xzf $DST/third_party/mkl/$ARCHIVE_BASENAME -C $DST/third_party/mkl/ extracted_dir_name="${ARCHIVE_BASENAME%.*}" MKL_INSTALL_PATH=$DST/third_party/mkl/$extracted_dir_name MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"` - if [ "$OSNAME" == "Linux" ]; then + else + default_mkl_path=/opt/intel/mklml + fromuser="" + read -p "Please specify the location where MKL is installed. [Default is $default_mkl_path]: " MKL_INSTALL_PATH + fromuser="1" + if [ -z "$MKL_INSTALL_PATH" ]; then + MKL_INSTALL_PATH=$default_mkl_path + fi + # Result returned from "read" will be used unexpanded. That make "~" unusable. + # Going through one more level of expansion to handle that. + MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"` + fi + + if [ "$OSNAME" == "Linux" ]; then # Full MKL configuration MKL_RT_LIB_PATH="lib/intel64/libmkl_rt.so" #${TF_MKL_EXT}#TODO version? MKL_RT_OMP_LIB_PATH="../compiler/lib/intel64/libiomp5.so" #TODO VERSION? @@ -124,24 +243,29 @@ if false; then # Disable building with MKL for now # MKL-ML configuration MKL_ML_LIB_PATH="lib/libmklml_intel.so" #${TF_MKL_EXT}#TODO version? MKL_ML_OMP_LIB_PATH="lib/libiomp5.so" #TODO VERSION? - elif [ "$OSNAME" == "Darwin" ]; then + elif [ "$OSNAME" == "Darwin" ]; then echo "Darwin is unsupported yet"; exit 1 - fi + fi - if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then + if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then ln -sf $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} third_party/mkl/ ln -sf $MKL_INSTALL_PATH/${MKL_ML_OMP_LIB_PATH} third_party/mkl/ ln -sf $MKL_INSTALL_PATH/include third_party/mkl/ ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include - else - echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} does not exist"; - exit 1 - fi - - if [ -z "$fromuser" ]; then + loc=$(locate -e libdl.so.2 | sed -n 1p) + ln -sf $loc third_party/mkl/libdl.so.2 + elif [ -e "$MKL_INSTALL_PATH/${MKL_RT_LIB_PATH}" ]; then + ln -sf $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} third_party/mkl/ + ln -sf $MKL_INSTALL_PATH/${MKL_RT_OMP_LIB_PATH} third_party/mkl/ + ln -sf $MKL_INSTALL_PATH/include third_party/mkl/ + ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include + loc=$(locate -e libdl.so.2 | sed -n 1p) + ln -sf $loc third_party/mkl/libdl.so.2 + else + echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} nor $MKL_INSTALL_PATH/${MKL_RT_LIB_PATH} exists"; exit 1 - fi + fi cat > third_party/mkl/mkl.config < third_party/mkl/mkl.config <> tools/bazel.rc for opt in $CC_OPT_FLAGS; do - echo "build:opt --cxxopt=$opt --copt=$opt" >> tools/bazel.rc + write_to_bazelrc "build:opt --cxxopt=$opt --copt=$opt" done # Run the gen_git_source to create links where bazel can track dependencies for @@ -284,6 +421,7 @@ export TF_NEED_CUDA write_action_env_to_bazelrc "TF_NEED_CUDA" "$TF_NEED_CUDA" export TF_NEED_OPENCL +write_action_env_to_bazelrc "TF_NEED_OPENCL" "$TF_NEED_OPENCL" if [ "$TF_NEED_CUDA" == "1" ]; then while [[ "$TF_CUDA_CLANG" == "" ]]; do @@ -299,31 +437,6 @@ done export TF_CUDA_CLANG write_action_env_to_bazelrc "TF_CUDA_CLANG" "$TF_CUDA_CLANG" -# Set up which gcc nvcc should use as the host compiler -# No need to set this on Windows -while [[ "$TF_CUDA_CLANG" != "1" ]] && ! is_windows && true; do - fromuser="" - if [ -z "$GCC_HOST_COMPILER_PATH" ]; then - default_gcc_host_compiler_path=$(which gcc || true) - read -p "Please specify which gcc should be used by nvcc as the host compiler. [Default is $default_gcc_host_compiler_path]: " GCC_HOST_COMPILER_PATH - fromuser="1" - if [ -z "$GCC_HOST_COMPILER_PATH" ]; then - GCC_HOST_COMPILER_PATH="$default_gcc_host_compiler_path" - fi - fi - if [ -e "$GCC_HOST_COMPILER_PATH" ]; then - export GCC_HOST_COMPILER_PATH - write_action_env_to_bazelrc "GCC_HOST_COMPILER_PATH" "$GCC_HOST_COMPILER_PATH" - break - fi - echo "Invalid gcc path. ${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2 - if [ -z "$fromuser" ]; then - exit 1 - fi - GCC_HOST_COMPILER_PATH="" - # Retry -done - # Set up which clang we should use as the cuda / host compiler. while [[ "$TF_CUDA_CLANG" == "1" ]] && true; do fromuser="" @@ -364,6 +477,11 @@ while true; do else default_cuda_path="$(cygpath -m "$CUDA_PATH")" fi + elif is_linux; then + # If the default doesn't exist, try an alternative default. + if [ ! -d $default_cuda_path ] && [ -d /opt/cuda ]; then + default_cuda_path=/opt/cuda + fi fi read -p "Please specify the location where CUDA $TF_CUDA_VERSION toolkit is installed. Refer to README.md for more details. [Default is $default_cuda_path]: " CUDA_TOOLKIT_PATH fromuser="1" @@ -403,6 +521,35 @@ while true; do CUDA_TOOLKIT_PATH="" done +# Set up which gcc nvcc should use as the host compiler +# No need to set this on Windows +while [[ "$TF_CUDA_CLANG" != "1" ]] && ! is_windows && true; do + fromuser="" + if [ -z "$GCC_HOST_COMPILER_PATH" ]; then + default_gcc_host_compiler_path=$(which gcc || true) + cuda_bin_symlink="$CUDA_TOOLKIT_PATH/bin/gcc" + if [ -L "$cuda_bin_symlink" ]; then + default_gcc_host_compiler_path=$(readlink $cuda_bin_symlink) + fi + read -p "Please specify which gcc should be used by nvcc as the host compiler. [Default is $default_gcc_host_compiler_path]: " GCC_HOST_COMPILER_PATH + fromuser="1" + if [ -z "$GCC_HOST_COMPILER_PATH" ]; then + GCC_HOST_COMPILER_PATH="$default_gcc_host_compiler_path" + fi + fi + if [ -e "$GCC_HOST_COMPILER_PATH" ]; then + export GCC_HOST_COMPILER_PATH + write_action_env_to_bazelrc "GCC_HOST_COMPILER_PATH" "$GCC_HOST_COMPILER_PATH" + break + fi + echo "Invalid gcc path. ${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2 + if [ -z "$fromuser" ]; then + exit 1 + fi + GCC_HOST_COMPILER_PATH="" + # Retry +done + # Find out where the cuDNN library is installed while true; do # Configure the cuDNN version to use. @@ -418,7 +565,7 @@ while true; do if [ -z "$CUDNN_INSTALL_PATH" ]; then CUDNN_INSTALL_PATH=$default_cudnn_path fi - # Result returned from "read" will be used unexpanded. That make "~" unuseable. + # Result returned from "read" will be used unexpanded. That make "~" unusable. # Going through one more level of expansion to handle that. CUDNN_INSTALL_PATH=`"${PYTHON_BIN_PATH}" -c "import os; print(os.path.realpath(os.path.expanduser('${CUDNN_INSTALL_PATH}')))"` fi @@ -547,6 +694,7 @@ while true; do fi if [ -e "$HOST_CXX_COMPILER" ]; then export HOST_CXX_COMPILER + write_action_env_to_bazelrc "HOST_CXX_COMPILER" "$HOST_CXX_COMPILER" break fi echo "Invalid C++ compiler path. ${HOST_CXX_COMPILER} cannot be found" 1>&2 @@ -570,6 +718,7 @@ while true; do fi if [ -e "$HOST_C_COMPILER" ]; then export HOST_C_COMPILER + write_action_env_to_bazelrc "HOST_C_COMPILER" "$HOST_C_COMPILER" break fi echo "Invalid C compiler path. ${HOST_C_COMPILER} cannot be found" 1>&2 @@ -600,6 +749,7 @@ while true; do if [ -e "${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH}" ]; then export COMPUTECPP_TOOLKIT_PATH + write_action_env_to_bazelrc "COMPUTECPP_TOOLKIT_PATH" "$COMPUTECPP_TOOLKIT_PATH" break fi echo "Invalid SYCL $TF_OPENCL_VERSION library path. ${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH} cannot be found" diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 6a70c0e4057ed35bcdb10157e4147de35546b6a9..54da5bf3fee8b03c1b0ed34890c193575e324617 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -14,9 +14,7 @@ exports_files([ # Config setting for determining if we are building for Android. config_setting( name = "android", - values = { - "crosstool_top": "//external:android/crosstool", - }, + values = {"crosstool_top": "//external:android/crosstool"}, visibility = ["//visibility:public"], ) @@ -76,9 +74,7 @@ config_setting( config_setting( name = "ios", - values = { - "crosstool_top": "//tools/osx/crosstool:crosstool", - }, + values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, visibility = ["//visibility:public"], ) @@ -88,6 +84,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "linux_ppc64le", + values = {"cpu": "ppc"}, + visibility = ["//visibility:public"], +) + config_setting( name = "debug", values = { @@ -112,7 +114,7 @@ config_setting( # TODO(jhseu): Enable on other platforms other than Linux. config_setting( - name = "with_jemalloc", + name = "with_jemalloc_linux_x86_64", values = { "cpu": "k8", "define": "with_jemalloc=true", @@ -120,6 +122,15 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_jemalloc_linux_ppc64le", + values = { + "cpu": "ppc", + "define": "with_jemalloc=true", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "with_gcp_support", values = {"define": "with_gcp_support=true"}, @@ -138,6 +149,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_verbs_support", + values = {"define": "with_verbs_support=true"}, + visibility = ["//visibility:public"], +) + package_group( name = "internal", packages = ["//tensorflow/..."], @@ -185,7 +202,6 @@ filegroup( "//tensorflow/compiler/xla/client:all_files", "//tensorflow/compiler/xla/client/lib:all_files", "//tensorflow/compiler/xla/legacy_flags:all_files", - "//tensorflow/compiler/xla/port:all_files", "//tensorflow/compiler/xla/service:all_files", "//tensorflow/compiler/xla/service/cpu:all_files", "//tensorflow/compiler/xla/service/gpu:all_files", @@ -196,18 +212,24 @@ filegroup( "//tensorflow/contrib:all_files", "//tensorflow/contrib/android:all_files", "//tensorflow/contrib/batching:all_files", + "//tensorflow/contrib/batching/kernels:all_files", "//tensorflow/contrib/batching/test_util:all_files", "//tensorflow/contrib/batching/util:all_files", "//tensorflow/contrib/bayesflow:all_files", "//tensorflow/contrib/boosted_trees:all_files", "//tensorflow/contrib/boosted_trees/lib:all_files", "//tensorflow/contrib/boosted_trees/proto:all_files", + "//tensorflow/contrib/boosted_trees/resources:all_files", "//tensorflow/contrib/cloud:all_files", "//tensorflow/contrib/cloud/kernels:all_files", "//tensorflow/contrib/compiler:all_files", "//tensorflow/contrib/copy_graph:all_files", "//tensorflow/contrib/crf:all_files", "//tensorflow/contrib/cudnn_rnn:all_files", + "//tensorflow/contrib/data:all_files", + "//tensorflow/contrib/data/python/framework:all_files", + "//tensorflow/contrib/data/python/kernel_tests:all_files", + "//tensorflow/contrib/data/python/ops:all_files", "//tensorflow/contrib/distributions:all_files", "//tensorflow/contrib/factorization:all_files", "//tensorflow/contrib/factorization/kernels:all_files", @@ -217,6 +239,7 @@ filegroup( "//tensorflow/contrib/graph_editor:all_files", "//tensorflow/contrib/grid_rnn:all_files", "//tensorflow/contrib/hooks:all_files", + "//tensorflow/contrib/hvx/hvx_ops_support_checker:all_files", "//tensorflow/contrib/image:all_files", "//tensorflow/contrib/imperative:all_files", "//tensorflow/contrib/input_pipeline:all_files", @@ -239,16 +262,20 @@ filegroup( "//tensorflow/contrib/opt:all_files", "//tensorflow/contrib/rnn:all_files", "//tensorflow/contrib/saved_model:all_files", + "//tensorflow/contrib/saved_model/cc/saved_model:all_files", "//tensorflow/contrib/seq2seq:all_files", "//tensorflow/contrib/session_bundle:all_files", "//tensorflow/contrib/session_bundle/example:all_files", + "//tensorflow/contrib/signal:all_files", "//tensorflow/contrib/slim:all_files", "//tensorflow/contrib/slim/python/slim/data:all_files", "//tensorflow/contrib/slim/python/slim/nets:all_files", "//tensorflow/contrib/solvers:all_files", "//tensorflow/contrib/sparsemax:all_files", "//tensorflow/contrib/specs:all_files", + "//tensorflow/contrib/staging:all_files", "//tensorflow/contrib/stat_summarizer:all_files", + "//tensorflow/contrib/stateless:all_files", "//tensorflow/contrib/tensor_forest:all_files", "//tensorflow/contrib/tensor_forest/hybrid:all_files", "//tensorflow/contrib/tensorboard:all_files", @@ -256,6 +283,8 @@ filegroup( "//tensorflow/contrib/tfprof/python/tools/tfprof:all_files", "//tensorflow/contrib/training:all_files", "//tensorflow/contrib/util:all_files", + "//tensorflow/contrib/verbs:all_files", + "//tensorflow/contrib/xla_tf_graph:all_files", "//tensorflow/core:all_files", "//tensorflow/core/debug:all_files", "//tensorflow/core/distributed_runtime:all_files", @@ -265,8 +294,10 @@ filegroup( "//tensorflow/core/grappler/costs:all_files", "//tensorflow/core/grappler/inputs:all_files", "//tensorflow/core/grappler/optimizers:all_files", + "//tensorflow/core/grappler/utils:all_files", "//tensorflow/core/kernels:all_files", "//tensorflow/core/kernels/hexagon:all_files", + "//tensorflow/core/kernels/neon:all_files", "//tensorflow/core/ops/compat:all_files", "//tensorflow/core/platform/cloud:all_files", "//tensorflow/core/platform/default/build_config:all_files", @@ -274,6 +305,7 @@ filegroup( "//tensorflow/core/util/ctc:all_files", "//tensorflow/core/util/tensor_bundle:all_files", "//tensorflow/examples/android:all_files", + "//tensorflow/examples/benchmark:all_files", "//tensorflow/examples/how_tos/reading_data:all_files", "//tensorflow/examples/image_retraining:all_files", "//tensorflow/examples/label_image:all_files", @@ -282,6 +314,7 @@ filegroup( "//tensorflow/examples/tutorials/estimators:all_files", "//tensorflow/examples/tutorials/mnist:all_files", "//tensorflow/examples/tutorials/word2vec:all_files", + "//tensorflow/examples/wav_to_spectrogram:all_files", "//tensorflow/go:all_files", "//tensorflow/java:all_files", "//tensorflow/java/src/main/java/org/tensorflow/examples:all_files", @@ -289,27 +322,67 @@ filegroup( "//tensorflow/python:all_files", "//tensorflow/python/debug:all_files", "//tensorflow/python/estimator:all_files", + "//tensorflow/python/feature_column:all_files", "//tensorflow/python/kernel_tests:all_files", + "//tensorflow/python/kernel_tests/distributions:all_files", + "//tensorflow/python/ops/distributions:all_files", "//tensorflow/python/saved_model:all_files", "//tensorflow/python/tools:all_files", "//tensorflow/tensorboard:all_files", - "//tensorflow/tensorboard/app:all_files", "//tensorflow/tensorboard/backend:all_files", "//tensorflow/tensorboard/backend/event_processing:all_files", "//tensorflow/tensorboard/components:all_files", - "//tensorflow/tensorboard/components/tf_text_dashboard:all_files", - "//tensorflow/tensorboard/components/vz_data_summary:all_files", - "//tensorflow/tensorboard/components/vz_line_chart:all_files", - "//tensorflow/tensorboard/components/vz_line_chart/demo:all_files", - "//tensorflow/tensorboard/components/vz_projector:all_files", - "//tensorflow/tensorboard/components/vz_sorting:all_files", - "//tensorflow/tensorboard/components/vz_sorting/test:all_files", - "//tensorflow/tensorboard/lib:all_files", + "//tensorflow/tensorboard/components/tf_audio_dashboard_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_backend_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_backend_d3v4/test:all_files", + "//tensorflow/tensorboard/components/tf_color_scale_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_color_scale_d3v4/test:all_files", + "//tensorflow/tensorboard/components/tf_dashboard_common_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_dashboard_common_d3v4/test:all_files", + "//tensorflow/tensorboard/components/tf_distribution_dashboard_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_globals_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_app_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_app_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_board_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_board_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_common_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_controls_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_controls_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_dashboard_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_dashboard_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_info_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_info_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_graph_loader_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_graph_loader_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_histogram_dashboard_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_image_dashboard_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_imports_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_option_selector_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_scalar_dashboard_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_scalar_dashboard_d3v4/demo:all_files", + "//tensorflow/tensorboard/components/tf_storage_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_storage_d3v4/test:all_files", + "//tensorflow/tensorboard/components/tf_tensorboard_d3v4:all_files", + "//tensorflow/tensorboard/components/tf_text_dashboard_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_distribution_chart_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_heatmap_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_histogram_timeseries_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_line_chart_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_projector_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_projector_d3v4/test:all_files", + "//tensorflow/tensorboard/components/vz_sorting_d3v4:all_files", + "//tensorflow/tensorboard/components/vz_sorting_d3v4/test:all_files", + "//tensorflow/tensorboard/demo:all_files", + "//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files", "//tensorflow/tensorboard/plugins:all_files", - "//tensorflow/tensorboard/plugins/debugger:all_files", "//tensorflow/tensorboard/plugins/projector:all_files", "//tensorflow/tensorboard/plugins/text:all_files", "//tensorflow/tensorboard/scripts:all_files", + "//tensorflow/tools/api/golden:all_files", + "//tensorflow/tools/api/lib:all_files", + "//tensorflow/tools/api/tests:all_files", "//tensorflow/tools/common:all_files", "//tensorflow/tools/compatibility:all_files", "//tensorflow/tools/dist_test/server:all_files", @@ -344,14 +417,34 @@ filegroup( ), ) +filegroup( + name = "docs_src", + data = glob(["docs_src/**/*.md"]), +) + # ------------------------------------------- # New rules should be added above this target. # ------------------------------------------- cc_binary( name = "libtensorflow.so", + linkopts = select({ + "//tensorflow:darwin": [ + "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file + "//tensorflow/c:exported_symbols.lds", + ], + "//tensorflow:windows": [], + "//conditions:default": [ + "-z defs", + "-s", + "-Wl,--version-script", # This line must be directly followed by the version_script.lds file + "//tensorflow/c:version_script.lds", + ], + }), linkshared = 1, deps = [ "//tensorflow/c:c_api", + "//tensorflow/c:exported_symbols.lds", + "//tensorflow/c:version_script.lds", "//tensorflow/core:tensorflow", ], ) diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py index 0bca6f8fb8051925908db5e86f30d97d534e60f4..083634bd7964b0c12e10a1f3c71be5eab597a6c4 100644 --- a/tensorflow/__init__.py +++ b/tensorflow/__init__.py @@ -24,19 +24,9 @@ from __future__ import print_function from tensorflow.python import * # pylint: enable=wildcard-import -# Lazily import the `tf.contrib` module. This avoids loading all of the -# dependencies of `tf.contrib` at `import tensorflow` time. -class _LazyContribLoader(object): - - def __getattr__(self, item): - global contrib - # Replace the lazy loader with the imported module itself. - import importlib # pylint: disable=g-import-not-at-top - contrib = importlib.import_module('tensorflow.contrib') - return getattr(contrib, item) - - -contrib = _LazyContribLoader() +from tensorflow.python.util.lazy_loader import LazyLoader +contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') +del LazyLoader del absolute_import del division diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 0019dfeeb13f5e591d44dd37d73a93ce64a92d95..3ab4e8efcdb5b05cf8922edd302e7cbf3a3597f1 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -26,6 +26,22 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +tf_cuda_library( + name = "c_api_internal", + srcs = ["c_api.h"], + hdrs = ["c_api_internal.h"], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + }), +) + tf_cuda_library( name = "c_api", srcs = ["c_api.cc"], @@ -34,10 +50,16 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ + ":c_api_internal", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ + ":c_api_internal", "//tensorflow/cc/saved_model:loader", + "//tensorflow/cc:gradients", + "//tensorflow/cc:ops", + "//tensorflow/cc:grad_ops", + "//tensorflow/cc:scope_internal", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -45,6 +67,14 @@ tf_cuda_library( }), ) +exports_files( + [ + "version_script.lds", + "exported_symbols.lds", + ], + visibility = ["//visibility:public"], +) + tf_cuda_library( name = "tf_status_helper", srcs = ["tf_status_helper.cc"], @@ -89,21 +119,22 @@ tf_cc_test( # linkstatic = tf_kernel_tests_linkstatic(), deps = [ ":c_api", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:grad_ops", "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/cc/saved_model:tag_constants", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:direct_session", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:testlib", "//tensorflow/core/kernels:array", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:math", - "//third_party/eigen3", ], ) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index d4bcc01b6b89329ad8149e2e98ac2df5d1c15882..f4775783f9f88c941445b62603c92cae00d34715 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -21,8 +21,12 @@ limitations under the License. #include #ifndef __ANDROID__ +#include "tensorflow/cc/framework/gradients.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/saved_model/loader.h" #endif +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def_util.h" @@ -93,9 +97,6 @@ size_t TF_DataTypeSize(TF_DataType dt) { } // -------------------------------------------------------------------------- -struct TF_Status { - Status status; -}; TF_Status* TF_NewStatus() { return new TF_Status; } @@ -179,12 +180,6 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, } // namespace -struct TF_Tensor { - TF_DataType dtype; - TensorShape shape; - TensorBuffer* buffer; -}; - TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims, int num_dims, size_t len) { void* data = allocate_tensor("TF_AllocateTensor", len); @@ -220,6 +215,18 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, return new TF_Tensor{dtype, TensorShape(dimvec), buf}; } +TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) { + // It is safe to move the Tensor if and only if we own the unique reference to + // it. In that case, we might as well not delete and reallocate, but a future + // implementation might need to do so. + if (tensor->buffer->RefCountIsOne() && + tensor->buffer->root_buffer()->RefCountIsOne() && + tensor->buffer->OwnsMemory()) { + return tensor; + } + return nullptr; +} + void TF_DeleteTensor(TF_Tensor* t) { t->buffer->Unref(); delete t; @@ -277,9 +284,6 @@ size_t TF_StringEncodedSize(size_t len) { } // -------------------------------------------------------------------------- -struct TF_SessionOptions { - SessionOptions options; -}; TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; } void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; } @@ -320,9 +324,6 @@ void TF_DeleteBuffer(TF_Buffer* buffer) { TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; } // -------------------------------------------------------------------------- -struct TF_DeprecatedSession { - Session* session; -}; TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt, TF_Status* status) { @@ -654,6 +655,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s, memcpy(buf, new_handle.c_str(), new_handle.size() + 1); *handle = buf; } else { + *handle = nullptr; status->status = result; } } @@ -685,11 +687,6 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle, c_outputs, target_oper_names, nullptr, status); } -struct TF_Library { - void* lib_handle; - TF_Buffer op_list; -}; - TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { TF_Library* lib_handle = new TF_Library; status->status = tensorflow::LoadLibrary( @@ -726,66 +723,6 @@ TF_Buffer* TF_GetAllOpList() { // -------------------------------------------------------------------------- // New Graph and Session API -// Structures ----------------------------------------------------------------- - -extern "C" { - -struct TF_Graph { - TF_Graph() - : graph(OpRegistry::Global()), - refiner(graph.versions().producer(), graph.op_registry()), - num_sessions(0), - delete_requested(false), - parent(nullptr), - parent_inputs(nullptr) {} - mutex mu; - Graph graph GUARDED_BY(mu); - - // Runs shape inference. - tensorflow::ShapeRefiner refiner GUARDED_BY(mu); - - // Maps from name of an operation to the Node* in 'graph'. - std::unordered_map name_map GUARDED_BY(mu); - - // TF_Graph may only / must be deleted when - // num_sessions == 0 && delete_requested == true - - // num_sessions incremented by TF_NewSession, and decremented by - // TF_DeleteSession. - int num_sessions GUARDED_BY(mu); - bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph - - // Used to link graphs contained in TF_WhileParams to the parent graph that - // will eventually contain the full while loop. - TF_Graph* parent; - TF_Output* parent_inputs; -}; - -struct TF_OperationDescription { - TF_OperationDescription(TF_Graph* g, const char* op_type, - const char* node_name) - : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {} - - NodeBuilder node_builder; - TF_Graph* graph; - std::vector colocation_constraints; -}; - -struct TF_Operation { - Node node; -}; - -struct TF_Session { - TF_Session(Session* s, TF_Graph* g) - : session(s), graph(g), last_num_graph_nodes(0) {} - Session* session; - TF_Graph* graph; - mutex mu; - int last_num_graph_nodes; -}; - -} // end extern "C" - // Helper functions ----------------------------------------------------------- namespace { @@ -801,8 +738,7 @@ tensorflow::string OutputName(const TF_Output& output) { const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, const char* attr_name, TF_Status* status) { - const tensorflow::AttrValue* attr = - tensorflow::AttrSlice(oper->node.def()).Find(attr_name); + const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name); if (attr == nullptr) { status->status = InvalidArgument("Operation has no attr named '", attr_name, "'."); @@ -1164,14 +1100,14 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, if (status->status.ok()) { // Run shape inference function for newly added node. - // - // TODO(b/28152992): Enable returning the result of this - // code-path once we have converted all python shape functions - // to call their C++ versions. - desc->graph->refiner.AddNode(ret).IgnoreError(); - + status->status = desc->graph->refiner.AddNode(ret); + } + if (status->status.ok()) { // Add the node to the name-to-node mapping. desc->graph->name_map[ret->name()] = ret; + } else if (ret != nullptr) { + desc->graph->graph.RemoveNode(ret); + ret = nullptr; } } @@ -1198,7 +1134,7 @@ const char* TF_OperationOpType(TF_Operation* oper) { } const char* TF_OperationDevice(TF_Operation* oper) { - return oper->node.def().device().c_str(); + return oper->node.requested_device().c_str(); } int TF_OperationNumOutputs(TF_Operation* oper) { @@ -1213,8 +1149,8 @@ TF_DataType TF_OperationOutputType(TF_Output oper_out) { int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name, TF_Status* status) { NameRangeMap name_ranges; - status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(), - nullptr, &name_ranges); + status->status = + NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges); if (!status->status.ok()) return -1; auto iter = name_ranges.find(arg_name); if (iter == name_ranges.end()) { @@ -1235,8 +1171,8 @@ TF_DataType TF_OperationInputType(TF_Input oper_in) { int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, TF_Status* status) { NameRangeMap name_ranges; - status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(), - &name_ranges, nullptr); + status->status = + NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr); if (!status->status.ok()) return -1; auto iter = name_ranges.find(arg_name); if (iter == name_ranges.end()) { @@ -1474,26 +1410,27 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, } } -#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \ - void func(TF_Operation* oper, const char* attr_name, c_type* value, \ - TF_Status* status) { \ - cpp_type v; \ - status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &v); \ - *value = static_cast(v); \ - } \ - void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ - int max_values, TF_Status* status) { \ - const auto* attr = GetAttrValue(oper, attr_name, status); \ - if (!status->status.ok()) return; \ - if (attr->value_case() != tensorflow::AttrValue::kList) { \ - status->status = \ - InvalidArgument("Value for '", attr_name, "' is not a list."); \ - return; \ - } \ - const auto len = std::min(max_values, attr->list().list_field##_size()); \ - for (int i = 0; i < len; ++i) { \ - values[i] = static_cast(attr->list().list_field(i)); \ - } \ +#define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \ + void func(TF_Operation* oper, const char* attr_name, c_type* value, \ + TF_Status* status) { \ + cpp_type v; \ + status->status = \ + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \ + *value = static_cast(v); \ + } \ + void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \ + int max_values, TF_Status* status) { \ + const auto* attr = GetAttrValue(oper, attr_name, status); \ + if (!status->status.ok()) return; \ + if (attr->value_case() != tensorflow::AttrValue::kList) { \ + status->status = \ + InvalidArgument("Value for '", attr_name, "' is not a list."); \ + return; \ + } \ + const auto len = std::min(max_values, attr->list().list_field##_size()); \ + for (int i = 0; i < len; ++i) { \ + values[i] = static_cast(attr->list().list_field(i)); \ + } \ } DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i); DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f); @@ -1504,7 +1441,8 @@ DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type); void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, int64_t* value, int num_dims, TF_Status* status) { PartialTensorShape shape; - status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shape); + status->status = + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape); if (!status->status.ok()) return; auto len = std::min(shape.dims(), num_dims); for (int i = 0; i < len; ++i) { @@ -1518,7 +1456,7 @@ void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name, int storage_size, TF_Status* status) { std::vector shapes; status->status = - tensorflow::GetNodeAttr(oper->node.def(), attr_name, &shapes); + tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes); if (!status->status.ok()) return; auto len = std::min(static_cast(shapes.size()), max_values); int64_t* p = storage; @@ -1585,7 +1523,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, TF_Tensor** value, TF_Status* status) { *value = nullptr; Tensor t; - status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &t); + status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t); if (!status->status.ok()) return; *value = new TF_Tensor{static_cast(t.dtype()), t.shape(), tensorflow::TensorCApi::Buffer(t)}; @@ -1596,7 +1534,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, TF_Tensor** values, int max_values, TF_Status* status) { std::vector ts; - status->status = tensorflow::GetNodeAttr(oper->node.def(), attr_name, &ts); + status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts); if (!status->status.ok()) return; const auto len = std::min(max_values, static_cast(ts.size())); for (int i = 0; i < len; ++i) { @@ -1675,10 +1613,6 @@ void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, status->status = MessageToBuffer(def, output_graph_def); } -struct TF_ImportGraphDefOptions { - tensorflow::ImportGraphDefOptions opts; -}; - TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() { return new TF_ImportGraphDefOptions; } @@ -2101,6 +2035,75 @@ void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status, void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); } +#ifndef __ANDROID__ +namespace { + +void OutputsFromTFOutputs(TF_Output* tf_outputs, int n, TF_Status* status, + std::vector* outputs) { + outputs->resize(n); + for (int i = 0; i < n; i++) { + const TF_Output& tf_output = tf_outputs[i]; + (*outputs)[i] = tensorflow::Output(&tf_output.oper->node, tf_output.index); + } +} + +void TFOutputsFromOutputs(const std::vector& outputs, + TF_Output* tf_outputs) { + for (int i = 0; i < outputs.size(); i++) { + tf_outputs[i].oper = ToOperation(outputs[i].node()); + tf_outputs[i].index = outputs[i].index(); + } +} + +} // namespace +#endif // __ANDROID__ + +void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, + TF_Output* dx, TF_Status* status, TF_Output* dy) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Adding gradients is not supported in Android. File a bug at " + "https://github.com/tensorflow/tensorflow/issues if this feature is " + "important to you"); +#else + std::vector y_arg; + std::vector x_arg; + std::vector dy_arg; + OutputsFromTFOutputs(y, ny, status, &y_arg); + OutputsFromTFOutputs(x, nx, status, &x_arg); + + { + // We need to hold on to the lock while we have a scope that uses TF_Graph. + mutex_lock graph_lock(g->mu); + + const int max_node_id_before = g->graph.num_node_ids(); + + tensorflow::Scope scope = + NewInternalScope(&g->graph, &status->status, &g->refiner); + + if (dx != nullptr) { + std::vector dx_arg; + OutputsFromTFOutputs(dx, ny, status, &dx_arg); + status->status = + AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg); + } else { + status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg); + } + + // Update g->name_map with the name_map from the scope, which will contain + // the new gradient ops. + for (int i = max_node_id_before; i < g->graph.num_node_ids(); ++i) { + Node* n = g->graph.FindNodeId(i); + if (n == nullptr) continue; + g->name_map[n->name()] = n; + } + } + + // Unpack the results from grad_outputs_arg. + TFOutputsFromOutputs(dy_arg, dy); +#endif // __ANDROID__ +} + // TF_Session functions ---------------------------------------------- TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index f837b68d76c34ba836720df820daaae5bc29c93c..ec9b01b388d1138644e28e3206e32726347b3d5e 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -64,6 +64,25 @@ limitations under the License. // and the API just provides high level controls over the number of // devices of each type. +// Macro to control visibility of exported symbols in the shared library (.so, +// .dylib, .dll). +// This duplicates the TF_EXPORT macro definition in +// tensorflow/core/platform/macros.h in order to keep this .h file independent +// of any other includes.$a +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(COMPILER_MSVC) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // COMPILER_MSVC +#endif // SWIG + #ifdef __cplusplus extern "C" { #endif @@ -71,12 +90,12 @@ extern "C" { // -------------------------------------------------------------------------- // TF_Version returns a string describing version information of the // TensorFlow library. TensorFlow using semantic versioning. -extern const char* TF_Version(); +TF_CAPI_EXPORT extern const char* TF_Version(); // -------------------------------------------------------------------------- // TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. // The enum values here are identical to corresponding values in types.proto. -typedef enum { +typedef enum TF_DataType { TF_FLOAT = 1, TF_DOUBLE = 2, TF_INT32 = 3, // Int32 tensors are always in 'host' memory. @@ -103,12 +122,12 @@ typedef enum { // TF_DataTypeSize returns the sizeof() for the underlying type corresponding // to the given TF_DataType enum value. Returns 0 for variable length types // (eg. TF_STRING) or on failure. -extern size_t TF_DataTypeSize(TF_DataType dt); +TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt); // -------------------------------------------------------------------------- // TF_Code holds an error code. The enum values here are identical to // corresponding values in error_codes.proto. -typedef enum { +typedef enum TF_Code { TF_OK = 0, TF_CANCELLED = 1, TF_UNKNOWN = 2, @@ -134,23 +153,24 @@ typedef enum { typedef struct TF_Status TF_Status; // Return a new status object. -extern TF_Status* TF_NewStatus(); +TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(); // Delete a previously created status object. -extern void TF_DeleteStatus(TF_Status*); +TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*); // Record in *s. Any previous information is lost. // A common use is to clear a status: TF_SetStatus(s, TF_OK, ""); -extern void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg); +TF_CAPI_EXPORT extern void TF_SetStatus(TF_Status* s, TF_Code code, + const char* msg); // Return the code record in *s. -extern TF_Code TF_GetCode(const TF_Status* s); +TF_CAPI_EXPORT extern TF_Code TF_GetCode(const TF_Status* s); // Return a pointer to the (null-terminated) error message in *s. The // return value points to memory that is only usable until the next // mutation to *s. Always returns an empty string if TF_GetCode(s) is // TF_OK. -extern const char* TF_Message(const TF_Status* s); +TF_CAPI_EXPORT extern const char* TF_Message(const TF_Status* s); // -------------------------------------------------------------------------- // TF_Buffer holds a pointer to a block of data and its associated length. @@ -168,14 +188,15 @@ typedef struct TF_Buffer { // Makes a copy of the input and sets an appropriate deallocator. Useful for // passing in read-only, input protobufs. -extern TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len); +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto, + size_t proto_len); // Useful for passing *out* a protobuf. -extern TF_Buffer* TF_NewBuffer(); +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(); -extern void TF_DeleteBuffer(TF_Buffer*); +TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*); -extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer); +TF_CAPI_EXPORT extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer); // -------------------------------------------------------------------------- // TF_Tensor holds a multi-dimensional array of elements of a single data type. @@ -202,11 +223,10 @@ typedef struct TF_Tensor TF_Tensor; // (*deallocator)(data, len, deallocator_arg) // Clients must provide a custom deallocator function so they can pass in // memory managed by something like numpy. -extern TF_Tensor* TF_NewTensor(TF_DataType, const int64_t* dims, int num_dims, - void* data, size_t len, - void (*deallocator)(void* data, size_t len, - void* arg), - void* deallocator_arg); +TF_CAPI_EXPORT extern TF_Tensor* TF_NewTensor( + TF_DataType, const int64_t* dims, int num_dims, void* data, size_t len, + void (*deallocator)(void* data, size_t len, void* arg), + void* deallocator_arg); // Allocate and return a new Tensor. // @@ -217,27 +237,32 @@ extern TF_Tensor* TF_NewTensor(TF_DataType, const int64_t* dims, int num_dims, // // The caller must set the Tensor values by writing them to the pointer returned // by TF_TensorData with length TF_TensorByteSize. -extern TF_Tensor* TF_AllocateTensor(TF_DataType, const int64_t* dims, - int num_dims, size_t len); +TF_CAPI_EXPORT extern TF_Tensor* TF_AllocateTensor(TF_DataType, + const int64_t* dims, + int num_dims, size_t len); + +// Deletes `tensor` and returns a new TF_Tensor with the same content if +// possible. Returns nullptr and leaves `tensor` untouched if not. +TF_CAPI_EXPORT extern TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor); // Destroy a tensor. -extern void TF_DeleteTensor(TF_Tensor*); +TF_CAPI_EXPORT extern void TF_DeleteTensor(TF_Tensor*); // Return the type of a tensor element. -extern TF_DataType TF_TensorType(const TF_Tensor*); +TF_CAPI_EXPORT extern TF_DataType TF_TensorType(const TF_Tensor*); // Return the number of dimensions that the tensor has. -extern int TF_NumDims(const TF_Tensor*); +TF_CAPI_EXPORT extern int TF_NumDims(const TF_Tensor*); // Return the length of the tensor in the "dim_index" dimension. // REQUIRES: 0 <= dim_index < TF_NumDims(tensor) -extern int64_t TF_Dim(const TF_Tensor* tensor, int dim_index); +TF_CAPI_EXPORT extern int64_t TF_Dim(const TF_Tensor* tensor, int dim_index); // Return the size of the underlying data in bytes. -extern size_t TF_TensorByteSize(const TF_Tensor*); +TF_CAPI_EXPORT extern size_t TF_TensorByteSize(const TF_Tensor*); // Return a pointer to the underlying data buffer. -extern void* TF_TensorData(const TF_Tensor*); +TF_CAPI_EXPORT extern void* TF_TensorData(const TF_Tensor*); // -------------------------------------------------------------------------- // Encode the string `src` (`src_len` bytes long) into `dst` in the format @@ -247,8 +272,9 @@ extern void* TF_TensorData(const TF_Tensor*); // // On success returns the size in bytes of the encoded string. // Returns an error into `status` otherwise. -extern size_t TF_StringEncode(const char* src, size_t src_len, char* dst, - size_t dst_len, TF_Status* status); +TF_CAPI_EXPORT extern size_t TF_StringEncode(const char* src, size_t src_len, + char* dst, size_t dst_len, + TF_Status* status); // Decode a string encoded using TF_StringEncode. // @@ -258,19 +284,20 @@ extern size_t TF_StringEncode(const char* src, size_t src_len, char* dst, // `*dst` and `*dst_len` are undefined and an error is set in `status`. // // Does not read memory more than `src_len` bytes beyond `src`. -extern size_t TF_StringDecode(const char* src, size_t src_len, const char** dst, - size_t* dst_len, TF_Status* status); +TF_CAPI_EXPORT extern size_t TF_StringDecode(const char* src, size_t src_len, + const char** dst, size_t* dst_len, + TF_Status* status); // Return the size in bytes required to encode a string `len` bytes long into a // TF_STRING tensor. -extern size_t TF_StringEncodedSize(size_t len); +TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len); // -------------------------------------------------------------------------- // TF_SessionOptions holds options that can be passed during session creation. typedef struct TF_SessionOptions TF_SessionOptions; // Return a new options object. -extern TF_SessionOptions* TF_NewSessionOptions(); +TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(); // Set the target in TF_SessionOptions.options. // target can be empty, a single entry, or a comma separated list of entries. @@ -278,17 +305,19 @@ extern TF_SessionOptions* TF_NewSessionOptions(); // "local" // ip:port // host:port -extern void TF_SetTarget(TF_SessionOptions* options, const char* target); +TF_CAPI_EXPORT extern void TF_SetTarget(TF_SessionOptions* options, + const char* target); // Set the config in TF_SessionOptions.options. // config should be a serialized tensorflow.ConfigProto proto. // If config was not parsed successfully as a ConfigProto, record the // error information in *status. -extern void TF_SetConfig(TF_SessionOptions* options, const void* proto, - size_t proto_len, TF_Status* status); +TF_CAPI_EXPORT extern void TF_SetConfig(TF_SessionOptions* options, + const void* proto, size_t proto_len, + TF_Status* status); // Destroy an options object. -extern void TF_DeleteSessionOptions(TF_SessionOptions*); +TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*); // TODO(jeff,sanjay): // - export functions to set Config fields @@ -301,11 +330,11 @@ extern void TF_DeleteSessionOptions(TF_SessionOptions*); typedef struct TF_Graph TF_Graph; // Return a new graph object. -extern TF_Graph* TF_NewGraph(); +TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(); // Destroy an options object. Graph will be deleted once no more // TFSession's are referencing it. -extern void TF_DeleteGraph(TF_Graph*); +TF_CAPI_EXPORT extern void TF_DeleteGraph(TF_Graph*); // Operation being built. The underlying graph must outlive this. typedef struct TF_OperationDescription TF_OperationDescription; @@ -343,9 +372,11 @@ typedef struct TF_Output { // * `output` is not in `graph`. // * An invalid shape is being set (e.g., the shape being set // is incompatible with the existing shape). -extern void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, - const int64_t* dims, const int num_dims, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_GraphSetTensorShape(TF_Graph* graph, + TF_Output output, + const int64_t* dims, + const int num_dims, + TF_Status* status); // Returns the number of dimensions of the Tensor referenced by `output` // in `graph`. @@ -354,8 +385,9 @@ extern void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, // // Returns an error into `status` if: // * `output` is not in `graph`. -extern int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output, - TF_Status* status); +TF_CAPI_EXPORT extern int TF_GraphGetTensorNumDims(TF_Graph* graph, + TF_Output output, + TF_Status* status); // Returns the shape of the Tensor referenced by `output` in `graph` // into `dims`. `dims` must be an array large enough to hold `num_dims` @@ -369,20 +401,21 @@ extern int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output, // Returns an error into `status` if: // * `output` is not in `graph`. // * `num_dims` does not match the actual number of dimensions. -extern void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, - int64_t* dims, int num_dims, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_GraphGetTensorShape(TF_Graph* graph, + TF_Output output, + int64_t* dims, int num_dims, + TF_Status* status); // Operation will only be added to *graph when TF_FinishOperation() is // called (assuming TF_FinishOperation() does not return an error). // *graph must not be deleted until after TF_FinishOperation() is // called. -extern TF_OperationDescription* TF_NewOperation(TF_Graph* graph, - const char* op_type, - const char* oper_name); +TF_CAPI_EXPORT extern TF_OperationDescription* TF_NewOperation( + TF_Graph* graph, const char* op_type, const char* oper_name); // Specify the device for `desc`. Defaults to empty, meaning unconstrained. -extern void TF_SetDevice(TF_OperationDescription* desc, const char* device); +TF_CAPI_EXPORT extern void TF_SetDevice(TF_OperationDescription* desc, + const char* device); // The calls to TF_AddInput and TF_AddInputList must match (in number, // order, and type) the op declaration. For example, the "Concat" op @@ -405,101 +438,115 @@ extern void TF_SetDevice(TF_OperationDescription* desc, const char* device); // TF_AddInputList(desc, values_inputs, 5); // For inputs that take a single tensor. -extern void TF_AddInput(TF_OperationDescription* desc, TF_Output input); +TF_CAPI_EXPORT extern void TF_AddInput(TF_OperationDescription* desc, + TF_Output input); // For inputs that take a list of tensors. // inputs must point to TF_Output[num_inputs]. -extern void TF_AddInputList(TF_OperationDescription* desc, - const TF_Output* inputs, int num_inputs); +TF_CAPI_EXPORT extern void TF_AddInputList(TF_OperationDescription* desc, + const TF_Output* inputs, + int num_inputs); // Call once per control input to `desc`. -extern void TF_AddControlInput(TF_OperationDescription* desc, - TF_Operation* input); +TF_CAPI_EXPORT extern void TF_AddControlInput(TF_OperationDescription* desc, + TF_Operation* input); // Request that `desc` be co-located on the device where `op` // is placed. // // Use of this is discouraged since the implementation of device placement is // subject to change. Primarily intended for internal libraries -extern void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op); +TF_CAPI_EXPORT extern void TF_ColocateWith(TF_OperationDescription* desc, + TF_Operation* op); // Call some TF_SetAttr*() function for every attr that is not // inferred from an input and doesn't have a default value you wish to // keep. // `value` must point to a string of length `length` bytes. -extern void TF_SetAttrString(TF_OperationDescription* desc, - const char* attr_name, const void* value, - size_t length); +TF_CAPI_EXPORT extern void TF_SetAttrString(TF_OperationDescription* desc, + const char* attr_name, + const void* value, size_t length); // `values` and `lengths` each must have lengths `num_values`. // `values[i]` must point to a string of length `lengths[i]` bytes. -extern void TF_SetAttrStringList(TF_OperationDescription* desc, - const char* attr_name, - const void* const* values, - const size_t* lengths, int num_values); -extern void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name, - int64_t value); -extern void TF_SetAttrIntList(TF_OperationDescription* desc, - const char* attr_name, const int64_t* values, - int num_values); -extern void TF_SetAttrFloat(TF_OperationDescription* desc, - const char* attr_name, float value); -extern void TF_SetAttrFloatList(TF_OperationDescription* desc, - const char* attr_name, const float* values, - int num_values); -extern void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name, - unsigned char value); -extern void TF_SetAttrBoolList(TF_OperationDescription* desc, - const char* attr_name, - const unsigned char* values, int num_values); -extern void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name, - TF_DataType value); -extern void TF_SetAttrTypeList(TF_OperationDescription* desc, - const char* attr_name, const TF_DataType* values, - int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrStringList(TF_OperationDescription* desc, + const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrInt(TF_OperationDescription* desc, + const char* attr_name, int64_t value); +TF_CAPI_EXPORT extern void TF_SetAttrIntList(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrFloat(TF_OperationDescription* desc, + const char* attr_name, float value); +TF_CAPI_EXPORT extern void TF_SetAttrFloatList(TF_OperationDescription* desc, + const char* attr_name, + const float* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrBool(TF_OperationDescription* desc, + const char* attr_name, + unsigned char value); +TF_CAPI_EXPORT extern void TF_SetAttrBoolList(TF_OperationDescription* desc, + const char* attr_name, + const unsigned char* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrType(TF_OperationDescription* desc, + const char* attr_name, + TF_DataType value); +TF_CAPI_EXPORT extern void TF_SetAttrTypeList(TF_OperationDescription* desc, + const char* attr_name, + const TF_DataType* values, + int num_values); // Set `num_dims` to -1 to represent "unknown rank". Otherwise, // `dims` points to an array of length `num_dims`. `dims[i]` must be // >= -1, with -1 meaning "unknown dimension". -extern void TF_SetAttrShape(TF_OperationDescription* desc, - const char* attr_name, const int64_t* dims, - int num_dims); +TF_CAPI_EXPORT extern void TF_SetAttrShape(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* dims, int num_dims); // `dims` and `num_dims` must point to arrays of length `num_shapes`. // Set `num_dims[i]` to -1 to represent "unknown rank". Otherwise, // `dims[i]` points to an array of length `num_dims[i]`. `dims[i][j]` // must be >= -1, with -1 meaning "unknown dimension". -extern void TF_SetAttrShapeList(TF_OperationDescription* desc, - const char* attr_name, - const int64_t* const* dims, const int* num_dims, - int num_shapes); +TF_CAPI_EXPORT extern void TF_SetAttrShapeList(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* const* dims, + const int* num_dims, + int num_shapes); // `proto` must point to an array of `proto_len` bytes representing a // binary-serialized TensorShapeProto. -extern void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc, - const char* attr_name, const void* proto, - size_t proto_len, TF_Status* status); +TF_CAPI_EXPORT extern void TF_SetAttrTensorShapeProto( + TF_OperationDescription* desc, const char* attr_name, const void* proto, + size_t proto_len, TF_Status* status); // `protos` and `proto_lens` must point to arrays of length `num_shapes`. // `protos[i]` must point to an array of `proto_lens[i]` bytes // representing a binary-serialized TensorShapeProto. -extern void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc, - const char* attr_name, - const void* const* protos, - const size_t* proto_lens, - int num_shapes, TF_Status* status); - -extern void TF_SetAttrTensor(TF_OperationDescription* desc, - const char* attr_name, TF_Tensor* value, - TF_Status* status); -extern void TF_SetAttrTensorList(TF_OperationDescription* desc, - const char* attr_name, - TF_Tensor* const* values, int num_values, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_SetAttrTensorShapeProtoList( + TF_OperationDescription* desc, const char* attr_name, + const void* const* protos, const size_t* proto_lens, int num_shapes, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_SetAttrTensor(TF_OperationDescription* desc, + const char* attr_name, + TF_Tensor* value, + TF_Status* status); +TF_CAPI_EXPORT extern void TF_SetAttrTensorList(TF_OperationDescription* desc, + const char* attr_name, + TF_Tensor* const* values, + int num_values, + TF_Status* status); // `proto` should point to a sequence of bytes of length `proto_len` // representing a binary serialization of an AttrValue protocol // buffer. -extern void TF_SetAttrValueProto(TF_OperationDescription* desc, - const char* attr_name, const void* proto, - size_t proto_len, TF_Status* status); +TF_CAPI_EXPORT extern void TF_SetAttrValueProto(TF_OperationDescription* desc, + const char* attr_name, + const void* proto, + size_t proto_len, + TF_Status* status); // If this function succeeds: // * *status is set to an OK value, @@ -511,37 +558,38 @@ extern void TF_SetAttrValueProto(TF_OperationDescription* desc, // * the graph is not modified, // * a null value is returned. // In either case, it deletes `desc`. -extern TF_Operation* TF_FinishOperation(TF_OperationDescription* desc, - TF_Status* status); +TF_CAPI_EXPORT extern TF_Operation* TF_FinishOperation( + TF_OperationDescription* desc, TF_Status* status); // TF_Operation functions. Operations are immutable once created, so // these are all query functions. -extern const char* TF_OperationName(TF_Operation* oper); -extern const char* TF_OperationOpType(TF_Operation* oper); -extern const char* TF_OperationDevice(TF_Operation* oper); +TF_CAPI_EXPORT extern const char* TF_OperationName(TF_Operation* oper); +TF_CAPI_EXPORT extern const char* TF_OperationOpType(TF_Operation* oper); +TF_CAPI_EXPORT extern const char* TF_OperationDevice(TF_Operation* oper); -extern int TF_OperationNumOutputs(TF_Operation* oper); -extern TF_DataType TF_OperationOutputType(TF_Output oper_out); -extern int TF_OperationOutputListLength(TF_Operation* oper, - const char* arg_name, - TF_Status* status); +TF_CAPI_EXPORT extern int TF_OperationNumOutputs(TF_Operation* oper); +TF_CAPI_EXPORT extern TF_DataType TF_OperationOutputType(TF_Output oper_out); +TF_CAPI_EXPORT extern int TF_OperationOutputListLength(TF_Operation* oper, + const char* arg_name, + TF_Status* status); -extern int TF_OperationNumInputs(TF_Operation* oper); -extern TF_DataType TF_OperationInputType(TF_Input oper_in); -extern int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name, - TF_Status* status); +TF_CAPI_EXPORT extern int TF_OperationNumInputs(TF_Operation* oper); +TF_CAPI_EXPORT extern TF_DataType TF_OperationInputType(TF_Input oper_in); +TF_CAPI_EXPORT extern int TF_OperationInputListLength(TF_Operation* oper, + const char* arg_name, + TF_Status* status); // In this code: // TF_Output producer = TF_OperationInput(consumer); // There is an edge from producer.oper's output (given by // producer.index) to consumer.oper's input (given by consumer.index). -extern TF_Output TF_OperationInput(TF_Input oper_in); +TF_CAPI_EXPORT extern TF_Output TF_OperationInput(TF_Input oper_in); // Get the number of current consumers of a specific output of an // operation. Note that this number can change when new operations // are added to the graph. -extern int TF_OperationOutputNumConsumers(TF_Output oper_out); +TF_CAPI_EXPORT extern int TF_OperationOutputNumConsumers(TF_Output oper_out); // Get list of all current consumers of a specific output of an // operation. `consumers` must point to an array of length at least @@ -550,24 +598,24 @@ extern int TF_OperationOutputNumConsumers(TF_Output oper_out); // modification of the graph can increase the number of consumers of // an operation. Returns the number of output consumers (should match // TF_OperationOutputNumConsumers(oper_out)). -extern int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, - int max_consumers); +TF_CAPI_EXPORT extern int TF_OperationOutputConsumers(TF_Output oper_out, + TF_Input* consumers, + int max_consumers); // Get the number of control inputs to an operation. -extern int TF_OperationNumControlInputs(TF_Operation* oper); +TF_CAPI_EXPORT extern int TF_OperationNumControlInputs(TF_Operation* oper); // Get list of all control inputs to an operation. `control_inputs` must // point to an array of length `max_control_inputs` (ideally set to // TF_OperationNumControlInputs(oper)). Returns the number of control // inputs (should match TF_OperationNumControlInputs(oper)). -extern int TF_OperationGetControlInputs(TF_Operation* oper, - TF_Operation** control_inputs, - int max_control_inputs); +TF_CAPI_EXPORT extern int TF_OperationGetControlInputs( + TF_Operation* oper, TF_Operation** control_inputs, int max_control_inputs); // Get the number of operations that have `*oper` as a control input. // Note that this number can change when new operations are added to // the graph. -extern int TF_OperationNumControlOutputs(TF_Operation* oper); +TF_CAPI_EXPORT extern int TF_OperationNumControlOutputs(TF_Operation* oper); // Get the list of operations that have `*oper` as a control input. // `control_outputs` must point to an array of length at least @@ -576,12 +624,12 @@ extern int TF_OperationNumControlOutputs(TF_Operation* oper); // modification of the graph can increase the number of control // outputs. Returns the number of control outputs (should match // TF_OperationNumControlOutputs(oper)). -extern int TF_OperationGetControlOutputs(TF_Operation* oper, - TF_Operation** control_outputs, - int max_control_outputs); +TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs( + TF_Operation* oper, TF_Operation** control_outputs, + int max_control_outputs); // TF_AttrType describes the type of the value of an attribute on an operation. -typedef enum { +typedef enum TF_AttrType { TF_ATTR_STRING = 0, TF_ATTR_INT = 1, TF_ATTR_FLOAT = 2, @@ -625,17 +673,18 @@ typedef struct TF_AttrMetadata { } TF_AttrMetadata; // Returns metadata about the value of the attribute `attr_name` of `oper`. -extern TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, - const char* attr_name, - TF_Status* status); +TF_CAPI_EXPORT extern TF_AttrMetadata TF_OperationGetAttrMetadata( + TF_Operation* oper, const char* attr_name, TF_Status* status); // Fills in `value` with the value of the attribute `attr_name`. `value` must // point to an array of length at least `max_length` (ideally set to // TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, // attr_name)). -extern void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name, - void* value, size_t max_length, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrString(TF_Operation* oper, + const char* attr_name, + void* value, + size_t max_length, + TF_Status* status); // Get the list of strings in the value of the attribute `attr_name`. Fills in // `values` and `lengths`, each of which must point to an array of length at @@ -648,64 +697,78 @@ extern void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name, // attr_name). // // Fails if storage_size is too small to hold the requested number of strings. -extern void TF_OperationGetAttrStringList(TF_Operation* oper, - const char* attr_name, void** values, - size_t* lengths, int max_values, - void* storage, size_t storage_size, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrStringList( + TF_Operation* oper, const char* attr_name, void** values, size_t* lengths, + int max_values, void* storage, size_t storage_size, TF_Status* status); -extern void TF_OperationGetAttrInt(TF_Operation* oper, const char* attr_name, - int64_t* value, TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrInt(TF_Operation* oper, + const char* attr_name, + int64_t* value, + TF_Status* status); // Fills in `values` with the value of the attribute `attr_name` of `oper`. // `values` must point to an array of length at least `max_values` (ideally set // TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, // attr_name)). -extern void TF_OperationGetAttrIntList(TF_Operation* oper, - const char* attr_name, int64_t* values, - int max_values, TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrIntList(TF_Operation* oper, + const char* attr_name, + int64_t* values, + int max_values, + TF_Status* status); -extern void TF_OperationGetAttrFloat(TF_Operation* oper, const char* attr_name, - float* value, TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrFloat(TF_Operation* oper, + const char* attr_name, + float* value, + TF_Status* status); // Fills in `values` with the value of the attribute `attr_name` of `oper`. // `values` must point to an array of length at least `max_values` (ideally set // to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, // attr_name)). -extern void TF_OperationGetAttrFloatList(TF_Operation* oper, - const char* attr_name, float* values, - int max_values, TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrFloatList(TF_Operation* oper, + const char* attr_name, + float* values, + int max_values, + TF_Status* status); -extern void TF_OperationGetAttrBool(TF_Operation* oper, const char* attr_name, - unsigned char* value, TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrBool(TF_Operation* oper, + const char* attr_name, + unsigned char* value, + TF_Status* status); // Fills in `values` with the value of the attribute `attr_name` of `oper`. // `values` must point to an array of length at least `max_values` (ideally set // to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, // attr_name)). -extern void TF_OperationGetAttrBoolList(TF_Operation* oper, - const char* attr_name, - unsigned char* values, int max_values, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrBoolList(TF_Operation* oper, + const char* attr_name, + unsigned char* values, + int max_values, + TF_Status* status); -extern void TF_OperationGetAttrType(TF_Operation* oper, const char* attr_name, - TF_DataType* value, TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrType(TF_Operation* oper, + const char* attr_name, + TF_DataType* value, + TF_Status* status); // Fills in `values` with the value of the attribute `attr_name` of `oper`. // `values` must point to an array of length at least `max_values` (ideally set // to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, // attr_name)). -extern void TF_OperationGetAttrTypeList(TF_Operation* oper, - const char* attr_name, - TF_DataType* values, int max_values, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrTypeList(TF_Operation* oper, + const char* attr_name, + TF_DataType* values, + int max_values, + TF_Status* status); // Fills in `value` with the value of the attribute `attr_name` of `oper`. // `values` must point to an array of length at least `num_dims` (ideally set to // TF_Attr_Meta.size from TF_OperationGetAttrMetadata(oper, attr_name)). -extern void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, - int64_t* value, int num_dims, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrShape(TF_Operation* oper, + const char* attr_name, + int64_t* value, + int num_dims, + TF_Status* status); // Fills in `dims` with the list of shapes in the attribute `attr_name` of // `oper` and `num_dims` with the corresponding number of dimensions. On return, @@ -720,35 +783,32 @@ extern void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name, // attr_name). // // Fails if storage_size is insufficient to hold the requested shapes. -extern void TF_OperationGetAttrShapeList(TF_Operation* oper, - const char* attr_name, int64_t** dims, - int* num_dims, int num_shapes, - int64_t* storage, int storage_size, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrShapeList( + TF_Operation* oper, const char* attr_name, int64_t** dims, int* num_dims, + int num_shapes, int64_t* storage, int storage_size, TF_Status* status); // Sets `value` to the binary-serialized TensorShapeProto of the value of // `attr_name` attribute of `oper`'. -extern void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper, - const char* attr_name, - TF_Buffer* value, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorShapeProto( + TF_Operation* oper, const char* attr_name, TF_Buffer* value, + TF_Status* status); // Fills in `values` with binary-serialized TensorShapeProto values of the // attribute `attr_name` of `oper`. `values` must point to an array of length at // least `num_values` (ideally set to TF_AttrMetadata.list_size from // TF_OperationGetAttrMetadata(oper, attr_name)). -extern void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper, - const char* attr_name, - TF_Buffer** values, - int max_values, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorShapeProtoList( + TF_Operation* oper, const char* attr_name, TF_Buffer** values, + int max_values, TF_Status* status); // Gets the TF_Tensor valued attribute of `attr_name` of `oper`. // // Allocates a new TF_Tensor which the caller is expected to take // ownership of (and can deallocate using TF_DeleteTensor). -extern void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, - TF_Tensor** value, TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensor(TF_Operation* oper, + const char* attr_name, + TF_Tensor** value, + TF_Status* status); // Fills in `values` with the TF_Tensor values of the attribute `attr_name` of // `oper`. `values` must point to an array of TF_Tensor* of length at least @@ -757,22 +817,22 @@ extern void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, // // The caller takes ownership of all the non-null TF_Tensor* entries in `values` // (which can be deleted using TF_DeleteTensor(values[i])). -extern void TF_OperationGetAttrTensorList(TF_Operation* oper, - const char* attr_name, - TF_Tensor** values, int max_values, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorList(TF_Operation* oper, + const char* attr_name, + TF_Tensor** values, + int max_values, + TF_Status* status); // Sets `output_attr_value` to the binary-serialized AttrValue proto // representation of the value of the `attr_name` attr of `oper`. -extern void TF_OperationGetAttrValueProto(TF_Operation* oper, - const char* attr_name, - TF_Buffer* output_attr_value, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationGetAttrValueProto( + TF_Operation* oper, const char* attr_name, TF_Buffer* output_attr_value, + TF_Status* status); // Returns the operation in the graph with `oper_name`. Returns nullptr if // no operation found. -extern TF_Operation* TF_GraphOperationByName(TF_Graph* graph, - const char* oper_name); +TF_CAPI_EXPORT extern TF_Operation* TF_GraphOperationByName( + TF_Graph* graph, const char* oper_name); // Iterate through the operations of a graph. To use: // size_t pos = 0; @@ -780,7 +840,8 @@ extern TF_Operation* TF_GraphOperationByName(TF_Graph* graph, // while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) { // DoSomethingWithOperation(oper); // } -extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos); +TF_CAPI_EXPORT extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, + size_t* pos); // Write out a serialized representation of `graph` (as a GraphDef protocol // message) to `output_graph_def` (allocated by TF_NewBuffer()). @@ -788,25 +849,27 @@ extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos); // is called. // // May fail on very large graphs in the future. -extern void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_GraphToGraphDef(TF_Graph* graph, + TF_Buffer* output_graph_def, + TF_Status* status); // TF_ImportGraphDefOptions holds options that can be passed to // TF_GraphImportGraphDef. typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; -extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions(); -extern void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts); +TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions(); +TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( + TF_ImportGraphDefOptions* opts); // Set the prefix to be prepended to the names of nodes in `graph_def` that will // be imported into `graph`. -extern void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, - const char* prefix); +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( + TF_ImportGraphDefOptions* opts, const char* prefix); // Set any imported nodes with input `src_name:src_index` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references a node already existing in the graph being imported into. -extern void TF_ImportGraphDefOptionsAddInputMapping( +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst); @@ -814,23 +877,23 @@ extern void TF_ImportGraphDefOptionsAddInputMapping( // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references an operation already existing in the graph being imported // into. -extern void TF_GraphImportGraphDefOptionsRemapControlDependency( +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsRemapControlDependency( TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst); // Cause the imported graph to have a control dependency on `oper`. `oper` // should exist in the graph being imported into. -extern void TF_ImportGraphDefOptionsAddControlDependency( +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddControlDependency( TF_ImportGraphDefOptions* opts, TF_Operation* oper); // Add an output in `graph_def` to be returned via the `return_outputs` output // parameter of TF_GraphImportGraphDef(). If the output is remapped via an input // mapping, the corresponding existing tensor in `graph` will be returned. -extern void TF_ImportGraphDefOptionsAddReturnOutput( +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput( TF_ImportGraphDefOptions* opts, const char* oper_name, int index); // Returns the number of return outputs added via // TF_ImportGraphDefOptionsAddReturnOutput(). -extern int TF_ImportGraphDefOptionsNumReturnOutputs( +TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs( const TF_ImportGraphDefOptions* opts); // Import the graph serialized in `graph_def` into `graph`. @@ -839,22 +902,22 @@ extern int TF_ImportGraphDefOptionsNumReturnOutputs( // result of TF_ImportGraphDefOptionsNumReturnOutputs()). If // `num_return_outputs` is non-zero, `return_outputs` must be of length // `num_return_outputs`. Otherwise it can be null. -extern void TF_GraphImportGraphDefWithReturnOutputs( +TF_CAPI_EXPORT extern void TF_GraphImportGraphDefWithReturnOutputs( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, int num_return_outputs, TF_Status* status); // Import the graph serialized in `graph_def` into `graph`. // Convenience function for when no return outputs have been added. -extern void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* options, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_GraphImportGraphDef( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Status* status); // Note: The following function may fail on very large protos in the future. -extern void TF_OperationToNodeDef(TF_Operation* oper, - TF_Buffer* output_node_def, - TF_Status* status); +TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper, + TF_Buffer* output_node_def, + TF_Status* status); typedef struct TF_WhileParams { // The number of inputs to the while loop, i.e. the number of loop variables. @@ -894,7 +957,7 @@ typedef struct TF_WhileParams { // TF_FinishWhile() or TF_AbortWhile(). // // Missing functionality (TODO): -// - Gradients (not yet implmented for any ops) +// - Gradients // - Reference-type inputs // - Directly referencing external tensors from the cond/body graphs (this is // possible in the Python API) @@ -917,7 +980,22 @@ void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status, // called after a successful TF_NewWhile() call. void TF_AbortWhile(const TF_WhileParams* params); -// TODO(andydavis): Function to add gradients to a graph. +// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, +// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... +// `dx` are used as initial gradients (which represent the symbolic partial +// derivatives of some loss function `L` w.r.t. `y`). +// `dx` must be nullptr or have size `ny`. +// If `dx` is nullptr, the implementation will use dx of `OnesLike` for all +// shapes in `y`. +// The partial derivatives are returned in `dy`. `dy` should be allocated to +// size `nx`. +// +// WARNING: This function does not yet support all the gradients that python +// supports. See +// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md +// for instructions on how to add C++ more gradients. +void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, + TF_Output* dx, TF_Status* status, TF_Output* dy); // TODO(josh11b): Register OpDef, available to all operations added // to this graph. @@ -936,8 +1014,9 @@ typedef struct TF_Session TF_Session; // *graph must be a valid graph (not deleted or nullptr). This function will // prevent the graph from being deleted until TF_DeleteSession() is called. // Does not take ownership of opts. -extern TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opts, - TF_Status* status); +TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph, + const TF_SessionOptions* opts, + TF_Status* status); // This function creates a new TF_Session (which is created on success) using // `session_options`, and then initializes state (restoring tensors and other @@ -962,7 +1041,7 @@ TF_Session* TF_LoadSessionFromSavedModel( // // Contacts any other processes associated with the session, if applicable. // May not be called after TF_DeleteSession(). -extern void TF_CloseSession(TF_Session*, TF_Status* status); +TF_CAPI_EXPORT extern void TF_CloseSession(TF_Session*, TF_Status* status); // Destroy a session object. // @@ -970,7 +1049,7 @@ extern void TF_CloseSession(TF_Session*, TF_Status* status); // local resources associated with the session. The session may not be used // during or after this call (and the session drops its reference to the // corresponding graph). -extern void TF_DeleteSession(TF_Session*, TF_Status* status); +TF_CAPI_EXPORT extern void TF_DeleteSession(TF_Session*, TF_Status* status); // Run the graph associated with the session starting with the supplied inputs // (inputs[0,ninputs-1] with corresponding values in input_values[0,ninputs-1]). @@ -996,21 +1075,20 @@ extern void TF_DeleteSession(TF_Session*, TF_Status* status); // to the caller, which must eventually call TF_DeleteTensor on them. // // On failure, output_values[] contains NULLs. -extern void TF_SessionRun(TF_Session* session, - // RunOptions - const TF_Buffer* run_options, - // Input tensors - const TF_Output* inputs, - TF_Tensor* const* input_values, int ninputs, - // Output tensors - const TF_Output* outputs, TF_Tensor** output_values, - int noutputs, - // Target operations - const TF_Operation* const* target_opers, int ntargets, - // RunMetadata - TF_Buffer* run_metadata, - // Output status - TF_Status*); +TF_CAPI_EXPORT extern void TF_SessionRun( + TF_Session* session, + // RunOptions + const TF_Buffer* run_options, + // Input tensors + const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, + // Output tensors + const TF_Output* outputs, TF_Tensor** output_values, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // RunMetadata + TF_Buffer* run_metadata, + // Output status + TF_Status*); // Set up the graph with the intended feeds (inputs) and fetches (outputs) for a // sequence of partial run calls. @@ -1022,38 +1100,36 @@ extern void TF_SessionRun(TF_Session* session, // On failure, out_status contains a tensorflow::Status with an error // message. // NOTE: This is EXPERIMENTAL and subject to change. -extern void TF_SessionPRunSetup(TF_Session*, - // Input names - const TF_Output* inputs, int ninputs, - // Output names - const TF_Output* outputs, int noutputs, - // Target operations - const TF_Operation* const* target_opers, - int ntargets, - // Output handle - const char** handle, - // Output status - TF_Status*); +TF_CAPI_EXPORT extern void TF_SessionPRunSetup( + TF_Session*, + // Input names + const TF_Output* inputs, int ninputs, + // Output names + const TF_Output* outputs, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // Output handle + const char** handle, + // Output status + TF_Status*); // Continue to run the graph with additional feeds and fetches. The // execution state is uniquely identified by the handle. // NOTE: This is EXPERIMENTAL and subject to change. -extern void TF_SessionPRun(TF_Session*, const char* handle, - // Input tensors - const TF_Output* inputs, - TF_Tensor* const* input_values, int ninputs, - // Output tensors - const TF_Output* outputs, TF_Tensor** output_values, - int noutputs, - // Target operations - const TF_Operation* const* target_opers, - int ntargets, - // Output status - TF_Status*); +TF_CAPI_EXPORT extern void TF_SessionPRun( + TF_Session*, const char* handle, + // Input tensors + const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, + // Output tensors + const TF_Output* outputs, TF_Tensor** output_values, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // Output status + TF_Status*); // Deletes a handle allocated by TF_SessionPRunSetup. // Once called, no more calls to TF_SessionPRun should be made. -extern void TF_DeletePRunHandle(const char* handle); +TF_CAPI_EXPORT extern void TF_DeletePRunHandle(const char* handle); // -------------------------------------------------------------------------- // The deprecated session API. Please switch to the above instead of @@ -1062,39 +1138,47 @@ extern void TF_DeletePRunHandle(const char* handle); typedef struct TF_DeprecatedSession TF_DeprecatedSession; -extern TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions*, +TF_CAPI_EXPORT extern TF_DeprecatedSession* TF_NewDeprecatedSession( + const TF_SessionOptions*, TF_Status* status); +TF_CAPI_EXPORT extern void TF_CloseDeprecatedSession(TF_DeprecatedSession*, TF_Status* status); -extern void TF_CloseDeprecatedSession(TF_DeprecatedSession*, TF_Status* status); -extern void TF_DeleteDeprecatedSession(TF_DeprecatedSession*, - TF_Status* status); -extern void TF_Reset(const TF_SessionOptions* opt, const char** containers, - int ncontainers, TF_Status* status); +TF_CAPI_EXPORT extern void TF_DeleteDeprecatedSession(TF_DeprecatedSession*, + TF_Status* status); +TF_CAPI_EXPORT extern void TF_Reset(const TF_SessionOptions* opt, + const char** containers, int ncontainers, + TF_Status* status); // Treat the bytes proto[0,proto_len-1] as a serialized GraphDef and // add the nodes in that GraphDef to the graph for the session. // // Prefer use of TF_Session and TF_GraphImportGraphDef over this. -extern void TF_ExtendGraph(TF_DeprecatedSession*, const void* proto, - size_t proto_len, TF_Status*); +TF_CAPI_EXPORT extern void TF_ExtendGraph(TF_DeprecatedSession*, + const void* proto, size_t proto_len, + TF_Status*); // See TF_SessionRun() above. -extern void TF_Run(TF_DeprecatedSession*, const TF_Buffer* run_options, - const char** input_names, TF_Tensor** inputs, int ninputs, - const char** output_names, TF_Tensor** outputs, int noutputs, - const char** target_oper_names, int ntargets, - TF_Buffer* run_metadata, TF_Status*); +TF_CAPI_EXPORT extern void TF_Run(TF_DeprecatedSession*, + const TF_Buffer* run_options, + const char** input_names, TF_Tensor** inputs, + int ninputs, const char** output_names, + TF_Tensor** outputs, int noutputs, + const char** target_oper_names, int ntargets, + TF_Buffer* run_metadata, TF_Status*); // See TF_SessionPRunSetup() above. -extern void TF_PRunSetup(TF_DeprecatedSession*, const char** input_names, - int ninputs, const char** output_names, int noutputs, - const char** target_oper_names, int ntargets, - const char** handle, TF_Status*); +TF_CAPI_EXPORT extern void TF_PRunSetup(TF_DeprecatedSession*, + const char** input_names, int ninputs, + const char** output_names, int noutputs, + const char** target_oper_names, + int ntargets, const char** handle, + TF_Status*); // See TF_SessionPRun above. -extern void TF_PRun(TF_DeprecatedSession*, const char* handle, - const char** input_names, TF_Tensor** inputs, int ninputs, - const char** output_names, TF_Tensor** outputs, - int noutputs, const char** target_oper_names, int ntargets, - TF_Status*); +TF_CAPI_EXPORT extern void TF_PRun(TF_DeprecatedSession*, const char* handle, + const char** input_names, TF_Tensor** inputs, + int ninputs, const char** output_names, + TF_Tensor** outputs, int noutputs, + const char** target_oper_names, int ntargets, + TF_Status*); // -------------------------------------------------------------------------- // Load plugins containing custom ops and kernels @@ -1113,19 +1197,19 @@ typedef struct TF_Library TF_Library; // The caller owns the library handle. // // On failure, place an error status in status and return NULL. -extern TF_Library* TF_LoadLibrary(const char* library_filename, - TF_Status* status); +TF_CAPI_EXPORT extern TF_Library* TF_LoadLibrary(const char* library_filename, + TF_Status* status); // Get the OpList of OpDefs defined in the library pointed by lib_handle. // // Returns a TF_Buffer. The memory pointed to by the result is owned by // lib_handle. The data in the buffer will be the serialized OpList proto for // ops defined in the library. -extern TF_Buffer TF_GetOpList(TF_Library* lib_handle); +TF_CAPI_EXPORT extern TF_Buffer TF_GetOpList(TF_Library* lib_handle); // Frees the memory associated with the library handle. // Does NOT unload the library. -extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); +TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); // Get the OpList of all OpDefs defined in this address space. // Returns a TF_Buffer, ownership of which is transferred to the caller @@ -1133,7 +1217,7 @@ extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); // // The data in the buffer will be the serialized OpList proto for ops registered // in this address space. -extern TF_Buffer* TF_GetAllOpList(); +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(); #ifdef __cplusplus } /* end extern "C" */ diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..b5320d20dadb0f466b8b29b8ba5eda1693e0faba --- /dev/null +++ b/tensorflow/c/c_api_internal.h @@ -0,0 +1,116 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api.h" + +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" + + +// Internal structures used by the C API. These are likely to change and should +// not be depended on. + +struct TF_Status { + tensorflow::Status status; +}; + +struct TF_Tensor { + TF_DataType dtype; + tensorflow::TensorShape shape; + tensorflow::TensorBuffer* buffer; +}; + +struct TF_SessionOptions { + tensorflow::SessionOptions options; +}; + +struct TF_DeprecatedSession { + tensorflow::Session* session; +}; + +struct TF_Library { + void* lib_handle; + TF_Buffer op_list; +}; + +struct TF_Graph { + TF_Graph() + : graph(tensorflow::OpRegistry::Global()), + refiner(graph.versions().producer(), graph.op_registry()), + num_sessions(0), + delete_requested(false), + parent(nullptr), + parent_inputs(nullptr) {} + tensorflow::mutex mu; + tensorflow::Graph graph GUARDED_BY(mu); + + // Runs shape inference. + tensorflow::ShapeRefiner refiner GUARDED_BY(mu); + + // Maps from name of an operation to the Node* in 'graph'. + std::unordered_map name_map + GUARDED_BY(mu); + + // TF_Graph may only / must be deleted when + // num_sessions == 0 && delete_requested == true + + // num_sessions incremented by TF_NewSession, and decremented by + // TF_DeleteSession. + int num_sessions GUARDED_BY(mu); + bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph + + // Used to link graphs contained in TF_WhileParams to the parent graph that + // will eventually contain the full while loop. + TF_Graph* parent; + TF_Output* parent_inputs; +}; + +struct TF_OperationDescription { + TF_OperationDescription(TF_Graph* g, const char* op_type, + const char* node_name) + : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {} + + tensorflow::NodeBuilder node_builder; + TF_Graph* graph; + std::vector colocation_constraints; +}; + +struct TF_Operation { + tensorflow::Node node; +}; + +struct TF_Session { + TF_Session(tensorflow::Session* s, TF_Graph* g) + : session(s), graph(g), last_num_graph_nodes(0) {} + tensorflow::Session* session; + TF_Graph* graph; + tensorflow::mutex mu; + int last_num_graph_nodes; +}; + +struct TF_ImportGraphDefOptions { + tensorflow::ImportGraphDefOptions opts; +}; diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 5673f657d3c5b77618c481da614573b9e4a63aba..cdb7406c86e8b10d24c303615d13089272bcab5d 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/tag_constants.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb_text.h" #include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/node_def_util.h" @@ -38,6 +39,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/util/equal_graph_def.h" using tensorflow::int32; using tensorflow::string; @@ -105,6 +107,22 @@ TEST(CAPI, AllocateTensor) { TF_DeleteTensor(t); } +TEST(CAPI, MaybeMove) { + const int num_bytes = 6 * sizeof(float); + float* values = + reinterpret_cast(tensorflow::cpu_allocator()->AllocateRaw( + EIGEN_MAX_ALIGN_BYTES, num_bytes)); + int64_t dims[] = {2, 3}; + bool deallocator_called = false; + TF_Tensor* t = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes, + &Deallocator, &deallocator_called); + + TF_Tensor* o = TF_TensorMaybeMove(t); + ASSERT_TRUE(o == nullptr); // It is unsafe to move memory TF might not own. + TF_DeleteTensor(t); + EXPECT_TRUE(deallocator_called); +} + TEST(CAPI, LibraryLoadFunctions) { // Load the library. TF_Status* status = TF_NewStatus(); @@ -261,6 +279,19 @@ static void Int32Deallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } +// Create a tensor with values of type TF_INT8 provided by `values`. +static TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, + const char* values) { + int64_t num_values = 1; + for (int i = 0; i < num_dims; ++i) { + num_values *= dims[i]; + } + TF_Tensor* t = + TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values); + memcpy(TF_TensorData(t), values, sizeof(char) * num_values); + return t; +} + static TF_Tensor* Int32Tensor(int32 v) { const int num_bytes = sizeof(int32); int32* values = new int32[1]; @@ -276,16 +307,21 @@ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, return TF_FinishOperation(desc, s); } -TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s, - const char* name = "scalar") { - unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); +TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, + const char* name = "const") { TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name); - TF_SetAttrTensor(desc, "value", tensor.get(), s); + TF_SetAttrTensor(desc, "value", t, s); if (TF_GetCode(s) != TF_OK) return nullptr; - TF_SetAttrType(desc, "dtype", TF_INT32); + TF_SetAttrType(desc, "dtype", TF_TensorType(t)); return TF_FinishOperation(desc, s); } +TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar") { + unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name = "add") { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); @@ -805,6 +841,33 @@ TEST(CAPI, ImportGraphDef) { EXPECT_EQ(feed, control_inputs[0]); EXPECT_EQ(feed2, control_inputs[1]); + // Export to a graph def so we can import a graph with control dependencies + TF_DeleteBuffer(graph_def); + graph_def = TF_NewBuffer(); + TF_GraphToGraphDef(graph, graph_def, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Import again, with remapped control dependency, into the same graph + TF_DeleteImportGraphDefOptions(opts); + opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); + TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); + TF_GraphImportGraphDef(graph, graph_def, opts, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* scalar4 = + TF_GraphOperationByName(graph, "imported4/imported3/scalar"); + TF_Operation* feed4 = + TF_GraphOperationByName(graph, "imported4/imported2/feed"); + + // Check that imported `imported3/scalar` has remapped control dep from + // original graph and imported control dep + num_control_inputs = TF_OperationGetControlInputs( + scalar4, control_inputs, TF_OperationNumControlInputs(scalar4)); + ASSERT_EQ(2, num_control_inputs); + EXPECT_EQ(feed, control_inputs[0]); + EXPECT_EQ(feed4, control_inputs[1]); + TF_DeleteImportGraphDefOptions(opts); TF_DeleteBuffer(graph_def); @@ -1049,6 +1112,35 @@ TEST(CAPI, SessionPRun) { TF_DeleteStatus(s); } +TEST(CAPI, ShapeInferenceError) { + // TF_FinishOperation should fail if the shape of the added operation cannot + // be inferred. + TF_Status* status = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Create this failure by trying to add two nodes with incompatible shapes + // (A tensor with shape [2] and a tensor with shape [3] cannot be added). + const char data[] = {1, 2, 3}; + const int64_t vec2_dims[] = {2}; + unique_tensor_ptr vec2_tensor( + Int8Tensor(vec2_dims, TF_ARRAYSIZE(vec2_dims), data), TF_DeleteTensor); + TF_Operation* vec2 = Const(vec2_tensor.get(), graph, status, "vec2"); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const int64_t vec3_dims[] = {3}; + unique_tensor_ptr vec3_tensor( + Int8Tensor(vec3_dims, TF_ARRAYSIZE(vec3_dims), data), TF_DeleteTensor); + TF_Operation* vec3 = Const(vec3_tensor.get(), graph, status, "vec3"); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Operation* add = Add(vec2, vec3, graph, status); + ASSERT_NE(TF_OK, TF_GetCode(status)); + ASSERT_TRUE(add == nullptr); + + TF_DeleteGraph(graph); + TF_DeleteStatus(status); +} + TEST(CAPI, ColocateWith) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -1482,16 +1574,280 @@ TEST_F(CApiWhileLoopTest, BadTypes) { TF_AbortWhile(params_.get()); } -// Create a tensor with values of type TF_INT8 provided by `values`. -TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { - int64_t num_values = 1; - for (int i = 0; i < num_dims; ++i) { - num_values *= dims[i]; +REGISTER_OP("TestOpWithNoGradient") + .Input("x: T") + .Output("y: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Test op with no grad registered. + +x: input +y: output +)doc") + .SetShapeFn(tensorflow::shape_inference::UnknownShape); + +class CApiGradientsTest : public ::testing::Test { + protected: + CApiGradientsTest() + : s_(TF_NewStatus()), + graph_(TF_NewGraph()), + expected_graph_(TF_NewGraph()) {} + + ~CApiGradientsTest() override { + TF_DeleteGraph(graph_); + TF_DeleteGraph(expected_graph_); + TF_DeleteStatus(s_); } - TF_Tensor* t = - TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values); - memcpy(TF_TensorData(t), values, sizeof(char) * num_values); - return t; + + void TestGradientsSuccess(bool grad_inputs_provided) { + TF_Output inputs[2]; + TF_Output outputs[1]; + TF_Output grad_outputs[2]; + TF_Output expected_grad_outputs[2]; + + BuildSuccessGraph(inputs, outputs); + BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs); + + AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs); + + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Compare that the graphs match. + GraphDef expected_gdef; + GraphDef gdef; + EXPECT_TRUE(GetGraphDef(expected_graph_, &expected_gdef)); + EXPECT_TRUE(GetGraphDef(graph_, &gdef)); + TF_EXPECT_GRAPH_EQ(expected_gdef, gdef); + + // Compare that the output of the gradients of both graphs match. + RunGraphsAndCompareOutputs(grad_outputs, expected_grad_outputs); + } + + void TestGradientsError(bool grad_inputs_provided) { + TF_Output inputs[1]; + TF_Output outputs[1]; + TF_Output grad_outputs[1]; + + BuildErrorGraph(inputs, outputs); + + AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs); + + string expected_msg = + "No gradient defined for op: TestOpWithNoGradient. Please see " + "https://www.tensorflow.org/code/" + "tensorflow/cc/gradients/README.md" + " for instructions on how to add C++ gradients."; + EXPECT_EQ(expected_msg, TF_Message(s_)); + } + + // Run the graph and ensure that the gradient values are as expected. + void RunGraphsAndCompareOutputs(TF_Output* grad_outputs, + TF_Output* expected_grad_outputs) { + std::unique_ptr csession(new CSession(graph_, s_)); + std::unique_ptr expected_csession( + new CSession(expected_graph_, s_)); + + std::vector grad_outputs_vec; + grad_outputs_vec.assign(grad_outputs, grad_outputs + 2); + csession->SetOutputs(grad_outputs_vec); + csession->Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Tensor* out0 = csession->output_tensor(0); + TF_Tensor* out1 = csession->output_tensor(1); + + std::vector expected_grad_outputs_vec; + expected_grad_outputs_vec.assign(expected_grad_outputs, + expected_grad_outputs + 2); + expected_csession->SetOutputs(expected_grad_outputs_vec); + expected_csession->Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Tensor* expected_out0 = expected_csession->output_tensor(0); + TF_Tensor* expected_out1 = expected_csession->output_tensor(1); + + CompareTensors(out0, expected_out0); + CompareTensors(out1, expected_out1); + } + + void CompareTensors(TF_Tensor* a, TF_Tensor* b) { + float* a_data = static_cast(TF_TensorData(a)); + float* b_data = static_cast(TF_TensorData(b)); + EXPECT_EQ(*a_data, *b_data); + } + + void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs, + TF_Output* outputs, int noutputs, TF_Output* grad_outputs) { + if (grad_inputs_provided) { + TF_Output grad_inputs[1]; + const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0}; + TF_Operation* grad_inputs_op = + FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs"); + grad_inputs[0] = TF_Output{grad_inputs_op, 0}; + TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs, + s_, grad_outputs); + } else { + TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_, + grad_outputs); + } + } + + void BuildErrorGraph(TF_Output* inputs, TF_Output* outputs) { + const float const0_val[] = {1.0, 2.0, 3.0, 4.0}; + TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0"); + TF_Operation* nograd = NoGradientOp(graph_, s_, const0, "NoGrad"); + inputs[0] = TF_Output{const0, 0}; + outputs[0] = TF_Output{nograd, 0}; + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + void BuildSuccessGraph(TF_Output* inputs, TF_Output* outputs) { + // Construct the following graph: + // | + // z| + // | + // MatMul + // / \ + // ^ ^ + // | | + // x| y| + // | | + // | | + // Const_0 Const_1 + // + const float const0_val[] = {1.0, 2.0, 3.0, 4.0}; + const float const1_val[] = {1.0, 0.0, 0.0, 1.0}; + TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0"); + TF_Operation* const1 = FloatConst2x2(graph_, s_, const1_val, "Const_1"); + TF_Operation* matmul = MatMul(graph_, s_, const0, const1, "MatMul"); + inputs[0] = TF_Output{const0, 0}; + inputs[1] = TF_Output{const1, 0}; + outputs[0] = TF_Output{matmul, 0}; + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + void BuildExpectedGraph(bool grad_inputs_provided, + TF_Output* expected_grad_outputs) { + // The expected graph looks like this if grad_inputs_provided. + // If grad_inputs_provided is false, Const_0 will be a OnesLike op. + // ^ ^ + // dy| dx| // MatMul Gradient Graph + // | | + // MatMul_2 MatMul_1 + // ^ ^ ^ ^ + // | |----------| | + // | ^ | + // | dz| | + // | | | + // | Const_3 | + // | | + // | ^ | + // | z| | // MatMul Forward Graph + // | | | + // | MatMul | + // | / \ | + // | ^ ^ | + // | | | | + // |---x| y|----| + // | | + // | | + // Const_0 Const_1 + // + const float const0_val[] = {1.0, 2.0, 3.0, 4.0}; + const float const1_val[] = {1.0, 0.0, 0.0, 1.0}; + TF_Operation* const0 = + FloatConst2x2(expected_graph_, s_, const0_val, "Const_0"); + TF_Operation* const1 = + FloatConst2x2(expected_graph_, s_, const1_val, "Const_1"); + TF_Operation* matmul = + MatMul(expected_graph_, s_, const0, const1, "MatMul"); + + TF_Operation* const3; + if (grad_inputs_provided) { + const float const3_val[] = {1.0, 1.0, 1.0, 1.0}; + const3 = FloatConst2x2(expected_graph_, s_, const3_val, "GradInputs"); + } else { + const3 = OnesLike(expected_graph_, s_, matmul, "OnesLike"); + } + + TF_Operation* matmul1 = + MatMul(expected_graph_, s_, const3, const1, "MatMul_1", false, true); + TF_Operation* matmul2 = + MatMul(expected_graph_, s_, const0, const3, "MatMul_2", true, false); + expected_grad_outputs[0] = {matmul1, 0}; + expected_grad_outputs[1] = {matmul2, 0}; + } + + TF_Tensor* FloatTensor2x2(const float* values) { + const int64_t dims[2] = {2, 2}; + TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * 4); + memcpy(TF_TensorData(t), values, sizeof(float) * 4); + return t; + } + + TF_Operation* FloatConst2x2(TF_Graph* graph, TF_Status* s, + const float* values, const char* name) { + unique_tensor_ptr tensor(FloatTensor2x2(values), TF_DeleteTensor); + TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name); + TF_SetAttrTensor(desc, "value", tensor.get(), s); + if (TF_GetCode(s) != TF_OK) return nullptr; + TF_SetAttrType(desc, "dtype", TF_FLOAT); + TF_Operation* op = TF_FinishOperation(desc, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + return op; + } + + TF_Operation* MatMul(TF_Graph* graph, TF_Status* s, TF_Operation* l, + TF_Operation* r, const char* name, + bool transpose_a = false, bool transpose_b = false) { + TF_OperationDescription* desc = TF_NewOperation(graph, "MatMul", name); + if (transpose_a) { + TF_SetAttrBool(desc, "transpose_a", 1); + } + if (transpose_b) { + TF_SetAttrBool(desc, "transpose_b", 1); + } + TF_AddInput(desc, {l, 0}); + TF_AddInput(desc, {r, 0}); + TF_Operation* op = TF_FinishOperation(desc, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + return op; + } + + TF_Operation* OnesLike(TF_Graph* graph, TF_Status* s, TF_Operation* in, + const char* name) { + TF_OperationDescription* desc = TF_NewOperation(graph, "OnesLike", name); + TF_AddInput(desc, {in, 0}); + TF_Operation* op = TF_FinishOperation(desc, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + return op; + } + + TF_Operation* NoGradientOp(TF_Graph* graph, TF_Status* s, TF_Operation* in, + const char* name) { + TF_OperationDescription* desc = + TF_NewOperation(graph, "TestOpWithNoGradient", name); + TF_AddInput(desc, {in, 0}); + TF_Operation* op = TF_FinishOperation(desc, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + return op; + } + + TF_Status* s_; + TF_Graph* graph_; + TF_Graph* expected_graph_; +}; + +TEST_F(CApiGradientsTest, Gradients_GradInputs) { TestGradientsSuccess(true); } + +TEST_F(CApiGradientsTest, Gradients_NoGradInputs) { + TestGradientsSuccess(false); +} + +TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_GradInputs) { + TestGradientsError(true); +} + +TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) { + TestGradientsError(false); } void StringVectorToArrays(const std::vector& v, @@ -1509,9 +1865,13 @@ void StringVectorToArrays(const std::vector& v, // Registers two ops, each with a single attribute called 'v'. // The attribute in one op will have a type 'type', the other // will have list(type). -#define ATTR_TEST_REGISTER_OP(type) \ - REGISTER_OP("CApiAttributesTestOp" #type).Attr("v: " #type); \ - REGISTER_OP("CApiAttributesTestOpList" #type).Attr("v: list(" #type ")") +#define ATTR_TEST_REGISTER_OP(type) \ + REGISTER_OP("CApiAttributesTestOp" #type) \ + .Attr("v: " #type) \ + .SetShapeFn(tensorflow::shape_inference::UnknownShape); \ + REGISTER_OP("CApiAttributesTestOpList" #type) \ + .Attr("v: list(" #type ")") \ + .SetShapeFn(tensorflow::shape_inference::UnknownShape) ATTR_TEST_REGISTER_OP(string); ATTR_TEST_REGISTER_OP(int); ATTR_TEST_REGISTER_OP(float); diff --git a/tensorflow/c/exported_symbols.lds b/tensorflow/c/exported_symbols.lds new file mode 100644 index 0000000000000000000000000000000000000000..a14bdaa48be55641a652795e2677b16e86918c11 --- /dev/null +++ b/tensorflow/c/exported_symbols.lds @@ -0,0 +1 @@ +_TF_* diff --git a/tensorflow/c/generate-pc.sh b/tensorflow/c/generate-pc.sh new file mode 100755 index 0000000000000000000000000000000000000000..73d427d9b2280123f9d54cdd7e4f9a76a7dddad1 --- /dev/null +++ b/tensorflow/c/generate-pc.sh @@ -0,0 +1,67 @@ +#!/usr/bin/env bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +TF_PREFIX='/usr/local' + +usage() { + echo "Usage: $0 OPTIONS" + echo -e "-p, --prefix\tset installation prefix (default: /usr/local)" + echo -e "-v, --version\tset TensorFlow version" + echo -e "-h, --help\tdisplay this message" +} + +[ $# == 0 ] && usage && exit 0 + +# read the options +ARGS=`getopt -o p:v:h --long prefix:,version:,help -n $0 -- "$@"` +eval set -- "$ARGS" + +# extract options and their arguments into variables. +while true ; do + case "$1" in + -h|--help) usage ; exit ;; + -p|--prefix) + case "$2" in + "") shift 2 ;; + *) TF_PREFIX=$2 ; shift 2 ;; + esac ;; + -v|--version) + case "$2" in + "") shift 2 ;; + *) TF_VERSION=$2 ; shift 2 ;; + esac ;; + --) shift ; break ;; + *) echo "Internal error! Try '$0 --help' for more information." ; exit 1 ;; + esac +done + +[ -z $TF_VERSION ] && echo "Specify a version using -v or --version" && exit 1 + +echo "Generating pkgconfig file for TensorFlow $TF_VERSION in $TF_PREFIX" + +cat << EOF > tensorflow.pc +prefix=${TF_PREFIX} +exec_prefix=\${prefix} +libdir=\${exec_prefix}/lib +includedir=\${prefix}/include + +Name: TensorFlow +Version: ${TF_VERSION} +Description: Library for computation using data flow graphs for scalable machine learning +Requires: +Libs: -L\${libdir} -ltensorflow +Cflags: -I\${includedir} +EOF diff --git a/tensorflow/c/version_script.lds b/tensorflow/c/version_script.lds new file mode 100644 index 0000000000000000000000000000000000000000..455bd7362bb36d30af421a17f0e2f8e9ba66e02b --- /dev/null +++ b/tensorflow/c/version_script.lds @@ -0,0 +1,9 @@ +VERS_1.0 { + # Export symbols in c_api.h. + global: + TF_*; + + # Hide everything else. + local: + *; +}; diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index aaebdded9a5b3232bec824b0768a536e36349204..8d4260a0b9ca38593a912398e8460d826fb31ccf 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -91,6 +91,7 @@ cc_library( deps = [ ":array_grad", ":math_grad", + ":nn_grad", ], ) @@ -122,7 +123,10 @@ cc_library_with_android_deps( cc_library_with_android_deps( name = "scope", - srcs = ["framework/scope.cc"], + srcs = [ + "framework/scope.cc", + "framework/scope_internal.h", + ], hdrs = ["framework/scope.h"], android_deps = ["//tensorflow/core:android_tensorflow_lib"], common_deps = [ @@ -136,6 +140,15 @@ cc_library_with_android_deps( ], ) +cc_library_with_android_deps( + name = "scope_internal", + hdrs = ["framework/scope_internal.h"], + common_deps = [ + ":scope", + ], + deps = [], +) + tf_cc_test( name = "framework_scope_test", srcs = ["framework/scope_test.cc"], @@ -376,6 +389,16 @@ tf_gen_op_wrappers_cc( visibility = ["//tensorflow:internal"], ) +tf_gen_op_wrappers_cc( + name = "functional_ops", + include_internal_ops = 1, + op_lib_names = [ + "functional_ops", + ], + pkg = "//tensorflow/core", + visibility = ["//tensorflow:internal"], +) + tf_gen_op_wrappers_cc( name = "resource_variable_ops", include_internal_ops = 1, diff --git a/tensorflow/cc/client/client_session.cc b/tensorflow/cc/client/client_session.cc index 2732f3f5010d7522a1cf8631183e9b4df7ac86d8..2879445441d0a80c1320a30976412b416feaecc9 100644 --- a/tensorflow/cc/client/client_session.cc +++ b/tensorflow/cc/client/client_session.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/cc/client/client_session.h" #include +#include #include #include "tensorflow/core/platform/env.h" @@ -31,7 +32,7 @@ class ClientSession::Impl { friend class ClientSession; Impl(Session* session, std::shared_ptr graph) - : session_(session), graph_(graph) {} + : session_(session), graph_(std::move(graph)) {} static SessionOptions MakeDefaultSessionOptions(const string& target); Status MaybeExtendGraph() const; diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc index 9c0f00f2b128c7d06bc7c0ca8579d4ca8e530fe8..dfbac9788e16e9c7c65abcd1ea213b51d5d5d060 100644 --- a/tensorflow/cc/client/client_session_test.cc +++ b/tensorflow/cc/client/client_session_test.cc @@ -49,7 +49,7 @@ TEST(ClientSessionTest, Feed) { TEST(ClientSessionTest, Extend) { Scope root = Scope::NewRootScope(); - auto a = Placeholder(root, DT_INT32); + auto a = Placeholder(root, DT_INT32, Placeholder::Shape({2})); auto c = Add(root, a, {2, 2}); ClientSession session(root); std::vector outputs; diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 22cd7fb0d438db9d9f7f29f5386c7a9722afe43d..71aa986f918de68822d457422f6c7a73d6253819 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -126,7 +126,11 @@ string PrintString(const string& str) { return strings::StrCat("\"", str_util::CEscape(str), "\""); } -string PrintTensorShape(const TensorShape& shape) { +string PrintTensorShape(const TensorShapeProto& shape_proto) { + PartialTensorShape shape(shape_proto); + if (shape.IsIdenticalTo(PartialTensorShape())) { + return "::tensorflow::PartialTensorShape() /* unknown */"; + } string ret = "{"; for (int d = 0; d < shape.dims(); ++d) { if (d > 0) strings::StrAppend(&ret, ", "); @@ -188,7 +192,13 @@ string PrintTensor(const TensorProto& tensor_proto) { } } -string PrintAttrValue(string op, const AttrValue& attr_value) { +string PrintTensorProto(const TensorProto& proto) { + return strings::StrCat("Input::Initializer(", "{", PrintTensor(proto), "}, ", + PrintTensorShape(proto.tensor_shape()), + ").AsTensorProto()"); +} + +string PrintAttrValue(const string& op, const AttrValue& attr_value) { switch (attr_value.value_case()) { case AttrValue::kS: return PrintString(attr_value.s()); @@ -203,12 +213,9 @@ string PrintAttrValue(string op, const AttrValue& attr_value) { case AttrValue::kType: return EnumName_DataType(attr_value.type()); case AttrValue::kShape: - return PrintTensorShape(TensorShape(attr_value.shape())); + return PrintTensorShape(attr_value.shape()); case AttrValue::kTensor: - return strings::StrCat( - "Input::Initializer(", "{", PrintTensor(attr_value.tensor()), "}, ", - PrintTensorShape(TensorShape(attr_value.tensor().tensor_shape())), - ").AsTensorProto()"); + return PrintTensorProto(attr_value.tensor()); case AttrValue::kList: { string ret = "{"; if (attr_value.list().s_size() > 0) { @@ -241,8 +248,14 @@ string PrintAttrValue(string op, const AttrValue& attr_value) { } else if (attr_value.list().shape_size() > 0) { for (int i = 0; i < attr_value.list().shape_size(); ++i) { if (i > 0) strings::StrAppend(&ret, ", "); - strings::StrAppend( - &ret, PrintTensorShape(TensorShape(attr_value.list().shape(i)))); + strings::StrAppend(&ret, + PrintTensorShape(attr_value.list().shape(i))); + } + } else if (attr_value.list().tensor_size() > 0) { + for (int i = 0; i < attr_value.list().tensor_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, + PrintTensorProto(attr_value.list().tensor(i))); } } strings::StrAppend(&ret, "}"); @@ -292,8 +305,8 @@ std::pair AttrTypeName(StringPiece attr_type) { {"list(bool)", {"gtl::ArraySlice", true}}, {"type", {"DataType", false}}, {"list(type)", {"DataTypeSlice", true}}, - {"shape", {"TensorShape", false}}, - {"list(shape)", {"gtl::ArraySlice", true}}, + {"shape", {"PartialTensorShape", false}}, + {"list(shape)", {"gtl::ArraySlice", true}}, {"tensor", {"TensorProto", true}}, {"list(tensor)", {"gtl::ArraySlice", true}}, {"func", {"NameAttrList", true}}, @@ -717,7 +730,7 @@ void OpInfo::GetOutput(string* out) const { // One output, no need for NameRangeMap if (is_list_output[0]) { strings::StrAppend(out, - " for (int64 i = 0; i < ret->num_outputs(); ++i)\n"); + " for (int32 i = 0; i < ret->num_outputs(); ++i)\n"); strings::StrAppend(out, " this->", output_names[0], ".push_back(Output(ret, i));\n"); } else { @@ -727,11 +740,10 @@ void OpInfo::GetOutput(string* out) const { return; } strings::StrAppend(out, " ::tensorflow::NameRangeMap _outputs_range;\n"); - strings::StrAppend( - out, - " ::tensorflow::Status _status_ = " - "::tensorflow::NameRangesForNode(ret->def(), ret->op_def(), " - "nullptr, &_outputs_range);\n"); + strings::StrAppend(out, + " ::tensorflow::Status _status_ = " + "::tensorflow::NameRangesForNode(*ret, ret->op_def(), " + "nullptr, &_outputs_range);\n"); strings::StrAppend(out, " if (!_status_.ok()) {\n", " ", scope_str, ".UpdateStatus(_status_);\n", " return;\n"); strings::StrAppend(out, " }\n\n"); @@ -740,7 +752,7 @@ void OpInfo::GetOutput(string* out) const { const string arg_range = strings::StrCat( "_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]"); if (is_list_output[i]) { - strings::StrAppend(out, " for (int64 i = ", arg_range, ".first; i < ", + strings::StrAppend(out, " for (int32 i = ", arg_range, ".first; i < ", arg_range, ".second; ++i)\n"); strings::StrAppend(out, " this->", output_names[i], ".push_back(Output(ret, i));\n"); diff --git a/tensorflow/cc/framework/cc_ops_test.cc b/tensorflow/cc/framework/cc_ops_test.cc index 6dc0d84c16d5b534341575b384997cc398c80bec..5da23036eaadbef270ba839357dc4613bf3bf490 100644 --- a/tensorflow/cc/framework/cc_ops_test.cc +++ b/tensorflow/cc/framework/cc_ops_test.cc @@ -32,10 +32,11 @@ Output Linear(const Scope& scope, Input x, Input w, Input b) { return BiasAdd(cop_scopes.last, m, b); } -void GetColocationConstraints(Output tensor, std::vector* constraints) { +void GetColocationConstraints(const Output& tensor, + std::vector* constraints) { constraints->clear(); - TF_EXPECT_OK( - GetNodeAttr(tensor.op().node()->def(), kColocationAttrName, constraints)); + TF_EXPECT_OK(GetNodeAttr(tensor.op().node()->attrs(), kColocationAttrName, + constraints)); } } // namespace @@ -158,11 +159,11 @@ TEST(CCOpTest, KernelLabel) { Scope root = Scope::NewRootScope(); auto add = Add(root.WithKernelLabel("AddWithKernelLabel"), 1.0f, 2.0f); TF_EXPECT_OK(root.status()); - const auto& attrs = add.z.op().node()->def().attr(); - ASSERT_TRUE(attrs.find("_kernel") != attrs.end()); - auto kernel_attr = attrs.find("_kernel")->second; - TF_EXPECT_OK(AttrValueHasType(kernel_attr, "string")); - EXPECT_EQ(kernel_attr.s(), "AddWithKernelLabel"); + AttrSlice attrs = add.z.op().node()->attrs(); + const auto* kernel_attr = attrs.Find("_kernel"); + ASSERT_TRUE(kernel_attr); + TF_EXPECT_OK(AttrValueHasType(*kernel_attr, "string")); + EXPECT_EQ(kernel_attr->s(), "AddWithKernelLabel"); } TEST(CCOpTest, ColocateWith) { @@ -189,8 +190,7 @@ TEST(CCOpTest, ColocateWith) { Scope with_colocate = root.ColocateWith(c3).ColocateWith(c4); auto c6 = Const(with_colocate.WithOpName("c6").ClearColocation(), 7); - const auto& attrs = c6.op().node()->def().attr(); - EXPECT_TRUE(attrs.find("_class") == attrs.end()); + EXPECT_FALSE(c6.op().node()->attrs().Find("_class")); } TEST(CCOpTest, TemplatedConst) { diff --git a/tensorflow/cc/framework/grad_op_registry.cc b/tensorflow/cc/framework/grad_op_registry.cc index 0d6a377b507161c4420a6076b9ee71e799e0223b..254705736e7711e58aa87054f36c8a19eebd4f0d 100644 --- a/tensorflow/cc/framework/grad_op_registry.cc +++ b/tensorflow/cc/framework/grad_op_registry.cc @@ -32,7 +32,13 @@ bool GradOpRegistry::Register(const string& op, GradFunc func) { Status GradOpRegistry::Lookup(const string& op, GradFunc* func) const { auto iter = registry_.find(op); if (iter == registry_.end()) { - return errors::NotFound("No gradient defined for op: ", op); + const string error_msg = + "No gradient defined for op: " + op + + ". Please see " + "https://www.tensorflow.org/code/" + "tensorflow/cc/gradients/README.md" + " for instructions on how to add C++ gradients."; + return errors::NotFound(error_msg); } *func = iter->second; return Status::OK(); diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc index 849a8eed6f23fb8dd1290d1bfa9db9c47d5d9f9d..8f20ff1457b219da3f11d9ffdafdd470875b25b0 100644 --- a/tensorflow/cc/framework/gradient_checker.cc +++ b/tensorflow/cc/framework/gradient_checker.cc @@ -40,8 +40,8 @@ Status ComputeTheoreticalJacobianTranspose( const std::vector& x_datas, const OutputList& ys, const std::vector& y_shapes, std::vector& jacobian_ts) { - int y_num = y_shapes.size(); - int x_num = x_shapes.size(); + size_t y_num = y_shapes.size(); + size_t x_num = x_shapes.size(); // Call AddSymbolicGradients to get 'dxs' (we will feed 'dys'). OutputList dys; for (const auto& y_shape : y_shapes) { @@ -130,8 +130,8 @@ Status ComputeNumericJacobianTranspose(const Scope& scope, const OutputList& xs, const T delta, std::vector& x_datas, std::vector& jacobian_ts) { - int y_num = y_shapes.size(); - int x_num = x_shapes.size(); + size_t y_num = y_shapes.size(); + size_t x_num = x_shapes.size(); ClientSession session(scope); for (int x_idx = 0; x_idx < x_num; x_idx++) { @@ -176,8 +176,8 @@ void InitJacobians(const OutputList& xs, const std::vector& x_shapes, const std::vector& y_shapes, std::vector& jacobians) { - int y_num = y_shapes.size(); - int x_num = x_shapes.size(); + size_t y_num = y_shapes.size(); + size_t x_num = x_shapes.size(); jacobians.resize(y_num * x_num); for (int x_idx = 0; x_idx < x_num; x_idx++) { diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index 2c60f947a55479e27937b98de91d80b559d32576..8c00a6f70497df2c70f266a747197e50c98375bb 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -210,8 +210,8 @@ Status SymbolicGradientBuilder::Initialize() { { // Initialize backprop with `grad_inputs_`. - const int num_dy = grad_inputs_.size(); - for (int i = 0; i < num_dy; ++i) { + const size_t num_dy = grad_inputs_.size(); + for (size_t i = 0; i < num_dy; ++i) { TF_RETURN_IF_ERROR(BackpropAlongEdge(grad_inputs_[i], outputs_[i])); } } @@ -308,7 +308,7 @@ Status SymbolicGradientBuilder::AddGradients() { continue; } - const int num_no_grad = no_grad_dy_indices.size(); + const size_t num_no_grad = no_grad_dy_indices.size(); if (IsPrimitiveOpWithNoGrad(n->type_string()) || num_no_grad == num_y) { // No grad defined for this op, or all outputs returned 'NoGradient': // Backprop 'NoGradient' along the in edges. @@ -367,6 +367,19 @@ Status AddSymbolicGradients(const Scope& scope, return builder.AddGradients(); } +Status AddSymbolicGradients(const Scope& scope, + const std::vector& outputs, + const std::vector& inputs, + std::vector* grad_outputs) { + std::vector grad_inputs; + grad_inputs.reserve(outputs.size()); + for (const Output& output : outputs) { + grad_inputs.emplace_back(ops::OnesLike(scope, output)); + } + return AddSymbolicGradients(scope, outputs, inputs, grad_inputs, + grad_outputs); +} + Output NoGradient() { return SymbolicGradientBuilder::NoGradient(); } } // end namespace tensorflow diff --git a/tensorflow/cc/framework/gradients.h b/tensorflow/cc/framework/gradients.h index d076bc43b4fbb1c8911b52c5ab258b7e9837113b..717f6f0636d3dd1a546ef7477b100bbfc86ba13d 100644 --- a/tensorflow/cc/framework/gradients.h +++ b/tensorflow/cc/framework/gradients.h @@ -27,16 +27,19 @@ namespace tensorflow { /// derivatives of some loss function 'L' w.r.t 'outputs'), adds gradient nodes /// to the graph associated with 'scope', which compute (and return in /// 'grad_outputs') the symbolic partial derivatives of 'L' w.r.t 'inputs'. -/// - -// TODO(andydavis) Add overload of this function with no 'grad_inputs' arg. -// Implementation will fill in 'OnesLike' for all shapes in 'outputs'. Status AddSymbolicGradients(const Scope& scope, const std::vector& outputs, const std::vector& inputs, const std::vector& grad_inputs, std::vector* grad_outputs); +// Same as above, but uses 'OnesLike' for all shapes in +// 'outputs' as grad_inputs. +Status AddSymbolicGradients(const Scope& scope, + const std::vector& outputs, + const std::vector& inputs, + std::vector* grad_outputs); + /// Returns a sentinel Output that represents 'no gradient' (i.e. no gradient /// flows along some graph edge during backpropagation). /// Can be returned in 'grad_outputs' by an invocation of 'AddSymbolicGradients' diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index 6c2c2fcd1e2c5941dadebfbc78fb5bae9122e7c3..6a249825812b4d39b55f7170a35436b6ae88c020 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -40,7 +40,7 @@ class GradientsTest : public ::testing::Test { TF_ASSERT_OK(scope_test_.ToGraphDef(&gdef_test)); GraphDef gdef_exp; TF_ASSERT_OK(scope_expected_.ToGraphDef(&gdef_exp)); - TF_EXPECT_GRAPH_EQ(gdef_test, gdef_exp); + TF_EXPECT_GRAPH_EQ(gdef_exp, gdef_test); } Scope scope_expected_; @@ -98,6 +98,32 @@ TEST_F(GradientsTest, OneMatMul) { CompareTestAndExpectedGraphs(); } +TEST_F(GradientsTest, OneMatMul_InferGradInputs) { + for (const bool expected : {false, true}) { + const Scope& scope = expected ? scope_expected_ : scope_test_; + // Construct forward graph. + auto x = Const(scope, {{1.0, 2.0}, {3.0, 4.0}}); + auto y = Const(scope, {{1.0, 0.0}, {0.0, 1.0}}); + auto z = MatMul(scope, x, y); + TF_ASSERT_OK(scope.status()); + CHECK_NOTNULL(z.node()); + + if (expected) { + // Construct backward graph. + // The gradients function adds a OnesLike to create a dz of ones with the + // shape of z. + auto dz = OnesLike(scope, z); + auto dx = MatMul(scope, dz, y, MatMul::TransposeB(true)); + auto dy = MatMul(scope, x, dz, MatMul::TransposeA(true)); + } else { + // Call AddSymbolicGradients. + std::vector grad_outputs; + TF_ASSERT_OK(AddSymbolicGradients(scope, {z}, {x, y}, &grad_outputs)); + } + } + CompareTestAndExpectedGraphs(); +} + TEST_F(GradientsTest, TwoMatMuls_Chained) { for (const bool expected : {false, true}) { const Scope& scope = expected ? scope_expected_ : scope_test_; @@ -234,7 +260,7 @@ TEST_F(GradientsTest, StackUnstack_StopBackprop) { } TEST_F(GradientsTest, DependentGradOutputs) { - // Tests that dependant gradients (in this case the gradients w.r.t to the + // Tests that dependent gradients (in this case the gradients w.r.t to the // output and one input of MatMul) are computed properly. // Create two chained MatMul ops. diff --git a/tensorflow/cc/framework/ops.cc b/tensorflow/cc/framework/ops.cc index 50df891a4c434ad58e962d7a31599df08cedaeb7..920a8e7955631ba0d33d2d36506703e107420a69 100644 --- a/tensorflow/cc/framework/ops.cc +++ b/tensorflow/cc/framework/ops.cc @@ -20,7 +20,7 @@ namespace tensorflow { Operation::Operation(Node* n) : inputs_(GetInputs(n)), node_(n) {} -Output Operation::input(int i) const { +Output Operation::input(int32 i) const { CHECK_NOTNULL(node_); CHECK_GE(i, 0); CHECK_LT(i, node_->num_inputs()); @@ -37,14 +37,14 @@ Output Operation::input(int i) const { return Output(inputs_[i].first, inputs_[i].second); } -Output Operation::output(int i) const { +Output Operation::output(int32 i) const { CHECK_NOTNULL(node_); CHECK_GE(i, 0); CHECK_LT(i, node_->num_outputs()); return Output(node_, i); } -uint64 Operation::hash(int64 index) const { +uint64 Operation::hash(int32 index) const { return ::tensorflow::Hash64(reinterpret_cast(&node_), sizeof(Node*), index); } diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index 889d5db31dd06fd25b7a72e209a8d7f37b8429ca..8d4154220c4b18f9286094b10c1b1e96eb4e31e7 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -39,22 +39,22 @@ class Operation { Operation() : node_(nullptr) {} explicit Operation(Node* n); - int num_inputs() const { return node_->num_inputs(); } - DataType input_type(int o) const { return node_->input_type(o); } - Output input(int i) const; + int32 num_inputs() const { return node_->num_inputs(); } + DataType input_type(int32 o) const { return node_->input_type(o); } + Output input(int32 i) const; - int num_outputs() const { return node_->num_outputs(); } - DataType output_type(int o) const { return node_->output_type(o); } - Output output(int i) const; + int32 num_outputs() const { return node_->num_outputs(); } + DataType output_type(int32 o) const { return node_->output_type(o); } + Output output(int32 i) const; Node* node() const { return node_; } - uint64 hash(int64 index) const; + uint64 hash(int32 index) const; bool operator==(const Operation& other) const { return node_ == other.node_; } private: - typedef std::vector> Inputs; + typedef std::vector> Inputs; static Inputs GetInputs(Node* node); Inputs inputs_; @@ -66,12 +66,12 @@ class Output { public: Output() = default; explicit Output(Node* n) : op_(n) {} - Output(Node* n, int64 index) : op_(n), index_(index) {} - Output(const Operation& op, int64 index) : op_(op), index_(index) {} + Output(Node* n, int32 index) : op_(n), index_(index) {} + Output(const Operation& op, int32 index) : op_(op), index_(index) {} Operation op() const { return op_; } Node* node() const { return op().node(); } - int64 index() const { return index_; } + int32 index() const { return index_; } DataType type() const { return op_.output_type(index_); } string name() const { return strings::StrCat(node()->name(), ":", index()); } bool operator==(const Output& other) const { @@ -82,14 +82,14 @@ class Output { private: Operation op_ = Operation(nullptr); - int64 index_ = 0; + int32 index_ = 0; }; /// Hash class that can be used for e.g. storing Outputs in an unordered_map struct OutputHash { std::size_t operator()(const Output& output) const { return Hash64Combine(std::hash()(output.node()), - std::hash()(output.index())); + std::hash()(output.index())); } }; @@ -230,12 +230,12 @@ class Input { /// Constructor specifying a node name, index and datatype. This should only /// be used for specifying a backward edge, needed by control flow. - Input(const string& name, int i, DataType dt) + Input(const string& name, int32 i, DataType dt) : node_name_(name), index_(i), data_type_(dt) {} Node* node() const { return output_.node(); } string node_name() const { return node_name_; } - int index() const { return node_name_.empty() ? output_.index() : index_; } + int32 index() const { return node_name_.empty() ? output_.index() : index_; } DataType data_type() const { return data_type_; } Status status() const { return status_; } const Tensor& tensor() const { return tensor_; } @@ -245,7 +245,7 @@ class Input { Output output_ = Output(Operation(nullptr), 0); Tensor tensor_; const string node_name_ = ""; - int index_ = 0; + int32 index_ = 0; DataType data_type_ = DT_INVALID; }; diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 571c6e1e579f630db473ffc1312d1a1f3162f475..32c0822de69da7989ceaa4028539db928b6fcea3 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -25,6 +25,20 @@ limitations under the License. namespace tensorflow { class Scope::Impl { + public: + // A NameMap is used to keep track of suffixes for names used in a scope. A + // name that has not been used so far in a scope will get no suffix. Later + // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes + // can share the same NameMap. For instance, a new scope created using + // WithControlDependencies() should would share the same NameMap with the + // parent. + typedef std::unordered_map NameMap; + + Impl(const std::shared_ptr& graph, + const std::shared_ptr& status, + const std::shared_ptr& name_map, + const std::shared_ptr& refiner); + private: friend class Scope; @@ -40,14 +54,6 @@ class Scope::Impl { enum class Colocate; }; - // A NameMap is used to keep track of suffixes for names used in a scope. A - // name that has not been used so far in a scope will get no suffix. Later - // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes - // can share the same NameMap. For instance, a new scope created using - // WithControlDependencies() should would share the same NameMap with the - // parent. - typedef std::unordered_map NameMap; - Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner); Impl(const Scope& other, Tags::ScopeName, const string& name, bool copy_names); @@ -116,6 +122,17 @@ Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map, scope_used_(nullptr), colocation_constraints_() {} +Scope::Impl::Impl(const std::shared_ptr& graph, + const std::shared_ptr& status, + const std::shared_ptr& name_map, + const std::shared_ptr& refiner) + : graph_(graph), + status_(status), + name_map_(name_map), + refiner_(refiner), + scope_used_(nullptr), + colocation_constraints_() {} + Scope Scope::NewRootScope() { Graph* graph = new Graph(OpRegistry::Global()); ShapeRefiner* refiner = @@ -254,9 +271,9 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate, std::unordered_set Scope::Impl::GetColocationConstraints( const Operation& colocate_with_op) const { std::unordered_set current_constraints(colocation_constraints_); - const NodeDef& node_def = colocate_with_op.node()->def(); + const AttrSlice attrs = colocate_with_op.node()->attrs(); std::vector node_constraints; - if (GetNodeAttr(node_def, kColocationAttrName, &node_constraints).ok()) { + if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) { for (const string& entry : node_constraints) { StringPiece s(entry); if (s.Consume(kColocationGroupPrefix)) { @@ -277,7 +294,7 @@ std::shared_ptr Scope::graph_as_shared_ptr() const { return impl()->graph_; } -Status Scope::status() const { return *impl()->status_; }; +Status Scope::status() const { return *impl()->status_; } const std::vector& Scope::control_deps() const { return impl()->control_deps_; @@ -464,4 +481,26 @@ CompositeOpScopes Scope::GetCompositeOpScopes( } } +class InternalScope { + public: + // NewScope doesn't take ownership of the inputs. + static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) { + Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap; + for (const Node* node : graph->nodes()) { + (*name_map)[node->name()] = 0; + } + // We provide null destructors for these shared ptrs (except for name_map) + // since the caller owns them and doesn't want the scope to destroy them. + return Scope(new Scope::Impl( + std::shared_ptr(graph, [](Graph*) {}), + std::shared_ptr(status, [](Status*) {}), + std::shared_ptr(name_map), + std::shared_ptr(refiner, [](ShapeRefiner*) {}))); + } +}; + +Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner) { + return InternalScope::NewScope(graph, status, refiner); +} + } // namespace tensorflow diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index ce70da709630bd402be9c75b3f6a5d638cd4a588..ec3543772d8febfb35488311886f1a4e9586a53e 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -204,6 +204,7 @@ class Scope { const std::vector& control_deps() const; private: + friend class InternalScope; class Impl; std::unique_ptr impl_; Impl* impl() { return impl_.get(); } diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..f2a911877f0b036080876b348b6a82f2a45df13a --- /dev/null +++ b/tensorflow/cc/framework/scope_internal.h @@ -0,0 +1,33 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ +#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ + +#include "tensorflow/cc/framework/scope.h" + +namespace tensorflow { + +class ShapeRefiner; + +// NewInternalScope returns a new scope which doesn't take ownership of +// graph, status, name_map, and refiner. +// This is intended to enable the C API (which are used by other language +// bindings) to create a Scope and access C++ functionality (i.e. gradients). +Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index 26abd2438e652f29a1d25caf689ab0606a12b00a..37f07e71a0dff9144f193679bbcfcf581c1538cf 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -43,9 +43,9 @@ Status PackGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int N; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "N", &N)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N)); int axis; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); grad_outputs->reserve(N); auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis)); @@ -60,7 +60,7 @@ Status UnpackGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int axis; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "axis", &axis)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis)); grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis))); return scope.status(); } @@ -162,7 +162,7 @@ Status CheckNumericsGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { string message; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "message", &message)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message)); string err_msg = strings::StrCat( "Not a number (NaN) or infinity (Inf) values detected in gradient. ", message); @@ -215,9 +215,9 @@ Status ReverseSequenceGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { auto seq_lengths = op.input(1); int batch_dim; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "batch_dim", &batch_dim)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim)); int seq_dim; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "seq_dim", &seq_dim)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim)); grad_outputs->push_back( ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim, ReverseSequence::BatchDim(batch_dim))); @@ -267,7 +267,8 @@ Status SpaceToBatchGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back( BatchToSpace(scope, grad_inputs[0], op.input(1), block_size)); grad_outputs->push_back(NoGradient()); @@ -290,7 +291,8 @@ Status BatchToSpaceGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back( SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size)); grad_outputs->push_back(NoGradient()); @@ -313,7 +315,8 @@ Status SpaceToDepthGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size)); return scope.status(); } @@ -323,7 +326,8 @@ Status DepthToSpaceGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { int block_size; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "block_size", &block_size)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "block_size", &block_size)); grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size)); return scope.status(); } @@ -333,7 +337,7 @@ Status MirrorPadGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { string mode; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad( scope, grad_inputs[0], op.input(1), mode)); grad_outputs->push_back(NoGradient()); @@ -346,7 +350,7 @@ Status MirrorPadGradGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { string mode; - TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->def(), "mode", &mode)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode)); grad_outputs->push_back(NoGradient()); return scope.status(); diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index aff0653139538820a705371ee9446a3d38ca69b5..8c1a01f518f9ad3a4571c2f36c01d4eae712e813 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -21,6 +21,17 @@ namespace tensorflow { namespace ops { namespace { +// Conjugate helper function returns the conjugate of an Output if it +// is complex valued. +Output ConjugateHelper(const Scope& scope, const Output& out) { + DataType dtype = out.type(); + if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) { + return Conj(scope, out); + } else { + return out; + } +} + // TODO(andydavis) Add control dependencies to gradient functions (as needed). Status AbsGrad(const Scope& scope, const Operation& op, @@ -44,9 +55,11 @@ REGISTER_GRADIENT_OP("Neg", NegGrad); Status InvGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - // dx = dy * (-1 * (y * y)) + // dy/dx = -1/x^2 = -y^2 + auto dydx = Neg(scope, Square(scope, op.output(0))); + // grad(x) = grad(y) * conj(dy/dx) grad_outputs->push_back( - Mul(scope, grad_inputs[0], Neg(scope, Square(scope, op.output(0))))); + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Inv", InvGrad); @@ -55,10 +68,12 @@ REGISTER_GRADIENT_OP("Reciprocal", InvGrad); Status SquareGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - // dx = dy * (2 * x) + // dy/dx = (2 * x) auto two = Cast(scope, Const(scope, 2), op.input(0).type()); + auto dydx = Mul(scope, two, op.input(0)); + // grad(x) = grad(y) * conj(dy/dx) grad_outputs->push_back( - Mul(scope, grad_inputs[0], Mul(scope, two, op.input(0)))); + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Square", SquareGrad); @@ -68,11 +83,12 @@ Status SqrtGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { // y = sqrt(x) // dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y) - // dx = dy * (0.5 * (1 / y)) auto y_inv = Reciprocal(scope, op.output(0)); auto half = Cast(scope, Const(scope, 0.5), op.input(0).type()); - auto dx = Mul(scope, grad_inputs[0], Mul(scope, half, y_inv)); - grad_outputs->push_back(dx); + auto dydx = Mul(scope, half, y_inv); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Sqrt", SqrtGrad); @@ -82,14 +98,14 @@ Status RsqrtGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { // y = 1/x^1/2 = x^-1/2 // dy/dx = -1/2 * x^-3/2 = -1/2 * x^-1/2 * x^-1 = -1/2 * y * x^-1 - // dx = dy * (-1/2 * y * x^-1) auto x_inv = Reciprocal(scope, op.input(0)); auto y = op.output(0); auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type()); auto a = Mul(scope, neghalf, x_inv); - auto b = Mul(scope, a, y); - auto dx = Mul(scope, grad_inputs[0], b); - grad_outputs->push_back(dx); + auto dydx = Mul(scope, a, y); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad); @@ -97,10 +113,11 @@ REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad); Status ExpGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - // y = exp(x) - // dy/dx = exp(x) - // dx = dy * y - grad_outputs->push_back(Mul(scope, grad_inputs[0], op.output(0))); + // dy/dx = exp(x) = y + // grad(x) = grad(y) * conj(dy/dx) + // = grad(y) * conj(y) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0)))); return scope.status(); } REGISTER_GRADIENT_OP("Exp", ExpGrad); @@ -108,10 +125,12 @@ REGISTER_GRADIENT_OP("Exp", ExpGrad); Status Expm1Grad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - // f(x) = expm1(x) - // df/dx = exp(x) - // dx = dy * exp(x) - grad_outputs->push_back(Mul(scope, grad_inputs[0], Exp(scope, op.input(0)))); + // y = expm1(x) + // dy/dx = exp(x) + auto dydx = Exp(scope, op.input(0)); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Expm1", Expm1Grad); @@ -119,11 +138,12 @@ REGISTER_GRADIENT_OP("Expm1", Expm1Grad); Status LogGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - // f(x) = log(x) = y - // df/dx = 1 / x - // dx = dy * (1 / x) + // y = log(x) + // dy/dx = 1 / x + auto dydx = Reciprocal(scope, op.input(0)); + // grad(x) = grad(y) * conj(dy/dx) grad_outputs->push_back( - Mul(scope, grad_inputs[0], Reciprocal(scope, op.input(0)))); + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Log", LogGrad); @@ -131,12 +151,13 @@ REGISTER_GRADIENT_OP("Log", LogGrad); Status Log1pGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - // f(x) = log1p(x) = y - // df/dx = 1 / (1 + x) - // dx = dy * (1 / (1 + x)) + // y = log1p(x) + // dy/dx = 1 / (1 + x) auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); + auto dydx = Reciprocal(scope, Add(scope, one, op.input(0))); + // grad(x) = grad(y) * conj(dy/dx) grad_outputs->push_back( - Div(scope, grad_inputs[0], Add(scope, one, op.input(0)))); + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Log1p", Log1pGrad); @@ -146,11 +167,12 @@ Status TanhGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { // y = tanh(x) // dy/dx = 1 - (tanh(x))^2 = 1 - y^2 - // dx = dy * (1 - y^2) auto y2 = Square(scope, op.output(0)); auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); - auto dx = Mul(scope, grad_inputs[0], Sub(scope, one, y2)); - grad_outputs->push_back(dx); + auto dydx = Sub(scope, one, y2); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Tanh", TanhGrad); @@ -160,11 +182,13 @@ Status SigmoidGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { // y = 1 / (1 + exp(-x)) // dy/dx = y * (1 - y) - // dx = dy * y * (1 - y) auto y = op.output(0); auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); - auto dx = Mul(scope, grad_inputs[0], Mul(scope, y, Sub(scope, one, y))); - grad_outputs->push_back(dx); + auto dydx = Mul(scope, y, Sub(scope, one, y)); + // dx = dy * y * (1 - y) + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad); @@ -185,9 +209,10 @@ Status SinGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { // y = sin(x) // dy/dx = cos(x) - // dx = dy * cos(x) - auto dx = Mul(scope, grad_inputs[0], Cos(scope, op.input(0))); - grad_outputs->push_back(dx); + auto dydx = Cos(scope, op.input(0)); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Sin", SinGrad); @@ -197,9 +222,10 @@ Status CosGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { // y = cos(x) // dy/dx = -sin(x) - // dx = dy * -sin(x) - auto dx = Mul(scope, grad_inputs[0], Neg(scope, Sin(scope, op.input(0)))); - grad_outputs->push_back(dx); + auto dydx = Neg(scope, Sin(scope, op.input(0))); + // grad(x) = grad(y) * conj(dy/dx) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx))); return scope.status(); } REGISTER_GRADIENT_OP("Cos", CosGrad); @@ -208,12 +234,12 @@ Status AsinGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { // y = asin(x) - // dy/dx = 1 / (1 - x * x)^1/2 - // dx = dy * (1 / (1 - x * x)^1/2) + // dy/dx = 1 / sqrt(1 - x^2) auto x2 = Square(scope, op.input(0)); auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))); - auto dx = Mul(scope, grad_inputs[0], dydx); + // grad(x) = grad(y) * conj(dy/dx) + auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)); grad_outputs->push_back(dx); return scope.status(); } @@ -239,9 +265,9 @@ Status TanGrad(const Scope& scope, const Operation& op, std::vector* grad_outputs) { // y = tan(x) // dy/dx = sec(x)^2 = 1 / cos(x)^2 - // dx = dy * (1 / cos(x)^2) auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0)))); - auto dx = Mul(scope, grad_inputs[0], dydx); + // grad(x) = grad(y) * conj(dy/dx) + auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)); grad_outputs->push_back(dx); return scope.status(); } @@ -324,7 +350,7 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op, const string& attr_adj_x, const string& attr_adj_y, std::vector* grad_outputs) { DataType dtype; - TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->attrs(), "T", &dtype)); if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) { return errors::Unimplemented( "MatMul gradient for complex data type is not supported yet."); @@ -332,8 +358,10 @@ Status MatMulGradCommon(const Scope& scope, const Operation& op, bool ta; bool tb; - TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta)); - TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), attr_adj_x, &ta)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), attr_adj_y, &tb)); if (!ta && !tb) { return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, op.input(1), diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index d7278929d4651f17d25670934b15e6da33d6a960..de6baa176936bcda7d0899c3795e1fbd37627058 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -56,23 +56,25 @@ class CWiseUnaryGradTest : public ::testing::Test { ATAN }; - void TestCWiseGrad(UnaryOpType op_type, std::function x_fn, - std::function dy_fn, - std::function dx_fn) { - Tensor x(DT_FLOAT, {2, 3, 2}); - auto x_flat = x.flat(); + template + void TestCWiseGrad(UnaryOpType op_type, const std::function& x_fn, + const std::function& dy_fn, + const std::function& dx_fn) { + DataType dtype = DataTypeToEnum::v(); + Tensor x(dtype, {2, 3, 2}); + auto x_flat = x.flat(); for (int i = 0; i < x_flat.size(); ++i) { x_flat(i) = x_fn(i); } - Tensor dy(DT_FLOAT, {2, 3, 2}); - auto dy_flat = dy.flat(); + Tensor dy(dtype, {2, 3, 2}); + auto dy_flat = dy.flat(); for (int i = 0; i < dy_flat.size(); ++i) { dy_flat(i) = dy_fn(x_flat(i)); } - Tensor dx(DT_FLOAT, {2, 3, 2}); - auto dx_flat = dx.flat(); + Tensor dx(dtype, {2, 3, 2}); + auto dx_flat = dx.flat(); for (int i = 0; i < dx_flat.size(); ++i) { dx_flat(i) = dx_fn(x_flat(i), dy_flat(i)); } @@ -146,7 +148,19 @@ class CWiseUnaryGradTest : public ::testing::Test { test::ExpectClose(output, dx); } - float RV(std::vector v) { return v[random::New64() % v.size()]; } + float RV(const std::vector& v) { + return v[random::New64() % v.size()]; + } + + complex64 CRV(const std::vector& v) { + return v[random::New64() % v.size()]; + } + + complex64 conjugate(const complex64& val) { + return complex64(val.real(), -val.imag()); + } + + const complex64 one_{1.0, 0}; Scope scope_; }; @@ -155,14 +169,14 @@ TEST_F(CWiseUnaryGradTest, Abs) { auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; auto dx_fn = [this](const float x, const float dy) { return x * dy; }; - TestCWiseGrad(ABS, x_fn, dy_fn, dx_fn); + TestCWiseGrad(ABS, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Neg) { auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; auto dx_fn = [this](const float x, const float dy) { return -dy; }; - TestCWiseGrad(NEG, x_fn, dy_fn, dx_fn); + TestCWiseGrad(NEG, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Reciprocal) { @@ -171,14 +185,36 @@ TEST_F(CWiseUnaryGradTest, Reciprocal) { auto dx_fn = [this](const float x, const float dy) { return -(1 / (x * x)) * dy; }; - TestCWiseGrad(INV, x_fn, dy_fn, dx_fn); + TestCWiseGrad(INV, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) { + auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; + auto dy_fn = [this](const complex64 x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64 x, const complex64 dy) { + return -conjugate(one_ / (x * x)) * dy; + }; + TestCWiseGrad(INV, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Square) { auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); }; auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; }; - TestCWiseGrad(SQUARE, x_fn, dy_fn, dx_fn); + TestCWiseGrad(SQUARE, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Square_Complex) { + auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return conjugate(complex64(2, 0) * x) * dy; + }; + TestCWiseGrad(SQUARE, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Sqrt) { @@ -187,7 +223,18 @@ TEST_F(CWiseUnaryGradTest, Sqrt) { auto dx_fn = [this](const float x, const float dy) { return dy * 0.5 * (1.0 / std::sqrt(x)); }; - TestCWiseGrad(SQRT, x_fn, dy_fn, dx_fn); + TestCWiseGrad(SQRT, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Sqrt_Complex) { + auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return conjugate(complex64(0.5, 0) / std::sqrt(x)) * dy; + }; + TestCWiseGrad(SQRT, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Rsqrt) { @@ -196,7 +243,18 @@ TEST_F(CWiseUnaryGradTest, Rsqrt) { auto dx_fn = [this](const float x, const float dy) { return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x); }; - TestCWiseGrad(RSQRT, x_fn, dy_fn, dx_fn); + TestCWiseGrad(RSQRT, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) { + auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return conjugate(complex64(-0.5, 0) / std::sqrt(x) / x) * dy; + }; + TestCWiseGrad(RSQRT, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Exp) { @@ -205,7 +263,18 @@ TEST_F(CWiseUnaryGradTest, Exp) { auto dx_fn = [this](const float x, const float dy) { return dy * std::exp(x); }; - TestCWiseGrad(EXP, x_fn, dy_fn, dx_fn); + TestCWiseGrad(EXP, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Exp_Complex) { + auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy * conjugate(std::exp(x)); + }; + TestCWiseGrad(EXP, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Expm1) { @@ -214,14 +283,36 @@ TEST_F(CWiseUnaryGradTest, Expm1) { auto dx_fn = [this](const float x, const float dy) { return dy * std::exp(x); }; - TestCWiseGrad(EXPM1, x_fn, dy_fn, dx_fn); + TestCWiseGrad(EXPM1, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Expm1_Complex) { + auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy * conjugate(std::exp(x)); + }; + TestCWiseGrad(EXPM1, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Log) { auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); }; auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); }; - TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn); + TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Log_Complex) { + auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy * conjugate(one_ / x); + }; + TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Log1p) { @@ -230,7 +321,20 @@ TEST_F(CWiseUnaryGradTest, Log1p) { auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / (1.0 + x)); }; - TestCWiseGrad(LOG1P, x_fn, dy_fn, dx_fn); + TestCWiseGrad(LOG1P, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Log1p_Complex) { + auto x_fn = [this](const int i) { + return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy / (one_ + conjugate(x)); + }; + TestCWiseGrad(LOG1P, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Tanh) { @@ -240,7 +344,21 @@ TEST_F(CWiseUnaryGradTest, Tanh) { const float y = std::tanh(x); return dy * (1.0 - y * y); }; - TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn); + TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Tanh_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + const complex64 y = std::tanh(x); + return dy * conjugate((one_ - y * y)); + }; + TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Sigmoid) { @@ -250,14 +368,28 @@ TEST_F(CWiseUnaryGradTest, Sigmoid) { const float y = 1.0 / (1.0 + std::exp(-x)); return dy * y * (1.0 - y); }; - TestCWiseGrad(SIGMOID, x_fn, dy_fn, dx_fn); + TestCWiseGrad(SIGMOID, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + const complex64 y = one_ / (one_ + std::exp(-x)); + return dy * conjugate(y * (one_ - y)); + }; + TestCWiseGrad(SIGMOID, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Sign) { auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; auto dx_fn = [this](const float x, const float dy) { return 0.0; }; - TestCWiseGrad(SIGN, x_fn, dy_fn, dx_fn); + TestCWiseGrad(SIGN, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Sin) { @@ -266,7 +398,20 @@ TEST_F(CWiseUnaryGradTest, Sin) { auto dx_fn = [this](const float x, const float dy) { return dy * std::cos(x); }; - TestCWiseGrad(SIN, x_fn, dy_fn, dx_fn); + TestCWiseGrad(SIN, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Sin_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy * conjugate(std::cos(x)); + }; + TestCWiseGrad(SIN, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Cos) { @@ -275,7 +420,20 @@ TEST_F(CWiseUnaryGradTest, Cos) { auto dx_fn = [this](const float x, const float dy) { return dy * -1.0 * std::sin(x); }; - TestCWiseGrad(COS, x_fn, dy_fn, dx_fn); + TestCWiseGrad(COS, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Cos_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy * conjugate(-std::sin(x)); + }; + TestCWiseGrad(COS, x_fn, dy_fn, dx_fn); } TEST_F(CWiseUnaryGradTest, Asin) { @@ -284,7 +442,24 @@ TEST_F(CWiseUnaryGradTest, Asin) { auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / std::sqrt(1.0 - x * x)); }; - TestCWiseGrad(ASIN, x_fn, dy_fn, dx_fn); + TestCWiseGrad(ASIN, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Asin_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy / conjugate(std::sqrt(one_ - x * x)); + }; + // TODO(kbsriram) + // Enable test when the asin kernel supports complex numbers + if (false) { + TestCWiseGrad(ASIN, x_fn, dy_fn, dx_fn); + } } TEST_F(CWiseUnaryGradTest, Acos) { @@ -293,7 +468,24 @@ TEST_F(CWiseUnaryGradTest, Acos) { auto dx_fn = [this](const float x, const float dy) { return dy * (-1.0 / std::sqrt(1.0 - x * x)); }; - TestCWiseGrad(ACOS, x_fn, dy_fn, dx_fn); + TestCWiseGrad(ACOS, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Acos_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy / -conjugate(std::sqrt(one_ - x * x)); + }; + // TODO(kbsriram) + // Add test when the acos kernel supports complex numbers + if (false) { + TestCWiseGrad(ACOS, x_fn, dy_fn, dx_fn); + } } TEST_F(CWiseUnaryGradTest, Tan) { @@ -303,7 +495,25 @@ TEST_F(CWiseUnaryGradTest, Tan) { const float cosx = std::cos(x); return dy * (1 / (cosx * cosx)); }; - TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn); + TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Tan_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + const complex64 cosx = std::cos(x); + return dy / conjugate(cosx * cosx); + }; + // TODO(kbsriram) + // Enable when tan kernel supports complex inputs + if (false) { + TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn); + } } TEST_F(CWiseUnaryGradTest, Atan) { @@ -312,7 +522,24 @@ TEST_F(CWiseUnaryGradTest, Atan) { auto dx_fn = [this](const float x, const float dy) { return dy * (1 / (1 + x * x)); }; - TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn); + TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Atan_Complex) { + auto x_fn = [this](const int i) { + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + }; + auto dy_fn = [this](const complex64& x) { + return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); + }; + auto dx_fn = [this](const complex64& x, const complex64& dy) { + return dy / (one_ + x * x); + }; + // TODO(kbsriram) + // Add test when the atan kernel supports complex numbers + if (false) { + TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn); + } } class CWiseUnaryComplexGradTest : public ::testing::Test { diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc index 5a4770f879ff9a1422a63a88bd2b67ba201a0567..3184edeb3307cafcbfbc41c6477fd092ab613b46 100644 --- a/tensorflow/cc/ops/const_op_test.cc +++ b/tensorflow/cc/ops/const_op_test.cc @@ -28,9 +28,9 @@ void ExpectNodeEqual(const Node* n, gtl::ArraySlice values, TensorShape shape) { EXPECT_TRUE(n->IsConstant()); Tensor tensor; - TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor)); DataType dtype; - TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype)); EXPECT_EQ(tensor.dtype(), dtype); test::ExpectTensorEqual(tensor, test::AsTensor(values, shape)); } @@ -39,9 +39,9 @@ void ExpectTypeAndShape(const Node* n, DataType expected_dtype, TensorShape expected_shape) { EXPECT_TRUE(n->IsConstant()); Tensor tensor; - TF_EXPECT_OK(GetNodeAttr(n->def(), "value", &tensor)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "value", &tensor)); DataType dtype; - TF_EXPECT_OK(GetNodeAttr(n->def(), "dtype", &dtype)); + TF_EXPECT_OK(GetNodeAttr(n->attrs(), "dtype", &dtype)); EXPECT_EQ(dtype, expected_dtype); EXPECT_EQ(expected_shape, TensorShape(tensor.shape())); } diff --git a/tensorflow/cc/ops/op_gen_overrides.pbtxt b/tensorflow/cc/ops/op_gen_overrides.pbtxt index cd94ddf4a1b67d3b98da7769db95bbda294e76db..1dffb10c03379571907e921c1add98d1f11625c3 100644 --- a/tensorflow/cc/ops/op_gen_overrides.pbtxt +++ b/tensorflow/cc/ops/op_gen_overrides.pbtxt @@ -22,7 +22,7 @@ op { name: "Where" input_rename: { from: "input" to: "condition" } } op { name: "ThreadUnsafeUnigramCandidateSampler", skip: true } # control_flow_ops -# TODO(josh11b): Hide Switch and Merge once we write and migrate users to +# TODO(joshl): Hide Switch and Merge once we write and migrate users to # a Cond() API. #op { name: "Switch" hide: true } #op { name: "Merge" hide: true } diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index b144bfc33e46c3db192cfb1e3ef8a0633e9fa519..908aa01a3470b67233c61d150ea955c1c13a8cd3 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -36,7 +36,7 @@ auto* load_attempt_count = monitoring::Counter<2>::New( "status"); auto* load_latency = monitoring::Counter<1>::New( "/tensorflow/cc/saved_model/load_latency", - "Latency in microseconds for SavedModels that were succesfully loaded.", + "Latency in microseconds for SavedModels that were successfully loaded.", "model_path"); constexpr char kLoadAttemptFail[] = "fail"; constexpr char kLoadAttemptSuccess[] = "success"; diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc index 4618c932c310eefe775ccf9d8c38fbe1eea702ca..fe45931f7f802bf483d39ea02ee280b38b8d894c 100644 --- a/tensorflow/cc/training/coordinator.cc +++ b/tensorflow/cc/training/coordinator.cc @@ -116,17 +116,13 @@ void Coordinator::WaitForStop() { } Status Coordinator::ExportCostGraph(CostGraphDef* cost_graph) const { - RunMetadata tmp_metadata; - { - mutex_lock l(runners_lock_); - for (auto& t : runners_) { - Status s = t->ExportRunMetadata(&tmp_metadata); - if (!s.ok()) { - return s; - } + mutex_lock l(runners_lock_); + for (auto& t : runners_) { + Status s = t->ExportCostGraph(cost_graph); + if (!s.ok()) { + return s; } } - cost_graph->MergeFrom(tmp_metadata.cost_graph()); return Status::OK(); } diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h index 632418c5ca5f523defe781a780ca0987202f59e4..0e01b19cd98bc797b7bb25da55c05d96f3eb93c7 100644 --- a/tensorflow/cc/training/coordinator.h +++ b/tensorflow/cc/training/coordinator.h @@ -36,8 +36,8 @@ class RunnerInterface { public: virtual ~RunnerInterface() {} virtual Status Join() = 0; - virtual Status ExportRunMetadata(RunMetadata* metadata) const { - return Status(error::INVALID_ARGUMENT, "No RunMetadata to export."); + virtual Status ExportCostGraph(CostGraphDef* cost_graph) const { + return Status(error::INVALID_ARGUMENT, "No cost model to export."); } /// Returns true iff the runner is running, i.e. if it is trying to populate /// its queue. diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 6b615916813519d7eaa94e69e846dcbfb87623bc..5aaaa116cf00dac6c1de3056c6121913a23acd77 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -49,7 +49,12 @@ Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { enqueue_op_names_.insert(enqueue_op_names_.end(), queue_runner_def.enqueue_op_name().begin(), queue_runner_def.enqueue_op_name().end()); - runs_ = enqueue_op_names_.size(); + size_t op_names_size = enqueue_op_names_.size(); + if (op_names_size > kint32max) { + return Status(error::INVALID_ARGUMENT, + "Enqueue ops to run cannot exceed kint32max"); + } + runs_ = static_cast(op_names_size); if (runs_ == 0) { return Status(error::INVALID_ARGUMENT, "Empty enqueue ops to run."); } @@ -82,9 +87,9 @@ QueueRunner::~QueueRunner() { Status QueueRunner::Start(Session* sess) { return Start(sess, 0); } -Status QueueRunner::StartAndCollectRunMetadata(Session* sess, - const RunOptions* run_options) { - SetRunArgumentsAndRunMetadata(run_options); +Status QueueRunner::StartAndCollectCostGraph(Session* sess, + const RunOptions* run_options) { + SetRunArgumentsAndCostGraph(run_options); return Start(sess, 0); } @@ -115,10 +120,9 @@ Status QueueRunner::Start(Session* sess, int wait_for) { return Status::OK(); } -Status QueueRunner::StartAndCollectRunMetadata(Session* session, - int wait_for_ms, - const RunOptions* run_options) { - SetRunArgumentsAndRunMetadata(run_options); +Status QueueRunner::StartAndCollectCostGraph(Session* session, int wait_for_ms, + const RunOptions* run_options) { + SetRunArgumentsAndCostGraph(run_options); return Start(session, wait_for_ms); } @@ -127,7 +131,7 @@ void QueueRunner::Stop(Session* sess) { coord_->WaitForStop(); } if (!cancel_op_name_.empty()) { - UpdateStatus(RealRun(sess, cancel_op_name_)); + UpdateStatus(RealRun(sess, cancel_op_name_, false)); } stopped_ = true; } @@ -162,7 +166,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) { if (coord_ && coord_->ShouldStop()) { break; } - status = RealRun(sess, enqueue_op); + status = RealRun(sess, enqueue_op, true); if (first_iteration) { if (!status.ok()) { mutex_lock l(mu_); @@ -183,9 +187,11 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) { // will be run anway in this case. if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) { if (last_run && !close_op_name_.empty()) { - UpdateStatus(RealRun(sess, close_op_name_)); + UpdateStatus(RealRun(sess, close_op_name_, false)); } } else if (!status.ok()) { + LOG(ERROR) << "Queue runner thread got a failure status: " + << status.ToString(); UpdateStatus(status); if (coord_) { coord_->RequestStop().IgnoreError(); @@ -198,34 +204,35 @@ Status QueueRunner::GetStatus() { return status_; } -Status QueueRunner::ExportRunMetadata(RunMetadata* metadata) const { - if (!rm_mu_) { +Status QueueRunner::ExportCostGraph(CostGraphDef* cost_graph) const { + if (!cg_mu_) { return Status(error::FAILED_PRECONDITION, - "This QueueRunner doesn't collect and store RunMetadata."); + "This QueueRunner doesn't collect a cost graph."); } - mutex_lock l(*rm_mu_); - metadata->MergeFrom(*run_metadata_); + mutex_lock l(*cg_mu_); + cost_graph->MergeFrom(*cost_graph_); return Status::OK(); } -void QueueRunner::SetRunArgumentsAndRunMetadata(const RunOptions* run_options) { - rm_mu_.reset(new mutex()); +void QueueRunner::SetRunArgumentsAndCostGraph(const RunOptions* run_options) { + cg_mu_.reset(new mutex()); { - mutex_lock l(*rm_mu_); - run_metadata_.reset(new RunMetadata()); + mutex_lock l(*cg_mu_); + cost_graph_.reset(new CostGraphDef()); } if (run_options) { run_options_ = *run_options; } } -Status QueueRunner::RealRun(Session* sess, const string& op) { +Status QueueRunner::RealRun(Session* sess, const string& op, + bool update_costs) { Status s; - if (rm_mu_) { + if (update_costs && cg_mu_) { RunMetadata metadata; s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata); - mutex_lock l(*rm_mu_); - run_metadata_->MergeFrom(metadata); + mutex_lock l(*cg_mu_); + cost_graph_->Swap(metadata.mutable_cost_graph()); } else { s = sess->Run({}, {}, {op}, nullptr); } diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index c69f28793a95990901961e835e004b019b98dbdc..71ed44c9c6064a4e0e4a61a8e2e649e7a8a235ec 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -60,15 +60,15 @@ class QueueRunner : public RunnerInterface { Status Start(Session* sess); /// Starts the queue runner with the given session and sets the run arguments - /// for sess->Run. It also collects and stores the run metedata. - Status StartAndCollectRunMetadata(Session* sess, - const RunOptions* run_options = nullptr); + /// for sess->Run. It also collects and stores the cost model. + Status StartAndCollectCostGraph(Session* sess, + const RunOptions* run_options = nullptr); /// Starts the queue runner with the given session, and wait for up to the /// specified time (in milliseconds) for the queues to start to fill up. Status Start(Session* sess, int wait_for_ms); - Status StartAndCollectRunMetadata(Session* session, int wait_for_ms, - const RunOptions* run_options = nullptr); + Status StartAndCollectCostGraph(Session* session, int wait_for_ms, + const RunOptions* run_options = nullptr); /// Requests to stop and runs the cancel op. It would be called in a separate /// thread when coordinator is set. If there is no coordinator it should be @@ -82,11 +82,11 @@ class QueueRunner : public RunnerInterface { /// Returns the latest status. Status GetStatus(); - // Returns the stored run metadata. - Status ExportRunMetadata(RunMetadata* metadata) const override; + // Returns the stored cost model. + Status ExportCostGraph(CostGraphDef* cost_graph) const override; private: - QueueRunner() : coord_(nullptr), stopped_(false), rm_mu_(nullptr) {} + QueueRunner() : coord_(nullptr), stopped_(false), cg_mu_(nullptr) {} // Initializes the instance with the QueueRunnerDef proto. Status Init(const QueueRunnerDef& queue_runner_def); @@ -105,9 +105,9 @@ class QueueRunner : public RunnerInterface { bool IsRunning() const override { return !stopped_; } - void SetRunArgumentsAndRunMetadata(const RunOptions* run_options); + void SetRunArgumentsAndCostGraph(const RunOptions* run_options); - Status RealRun(Session* sess, const string& op); + Status RealRun(Session* sess, const string& op, bool update_costs); string queue_name_; std::vector enqueue_op_names_; @@ -130,8 +130,8 @@ class QueueRunner : public RunnerInterface { mutex cb_mu_; std::vector> callbacks_; - mutable std::unique_ptr rm_mu_; - std::unique_ptr run_metadata_ GUARDED_BY(rm_mu_); + mutable std::unique_ptr cg_mu_; + std::unique_ptr cost_graph_ GUARDED_BY(cg_mu_); RunOptions run_options_; }; diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index c37a69a7f76b6d83634d0b01e2038c4e6b4fa22e..da2fc03b6c07ef3dec26434eaae8e3f70c07c5f1 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -44,6 +44,7 @@ using ops::FIFOQueue; using ops::QueueClose; using ops::QueueDequeue; using ops::QueueEnqueue; +using ops::RandomNormal; using ops::Square; using ops::Variable; @@ -84,7 +85,7 @@ QueueRunnerDef BuildQueueRunnerDef( const std::string& close_op, const std::string& cancel_op, const std::vector& queue_closed_error_codes) { QueueRunnerDef queue_runner_def; - *queue_runner_def.mutable_queue_name() = kQueueName; + *queue_runner_def.mutable_queue_name() = queue_name; for (const std::string& enqueue_op : enqueue_ops) { *queue_runner_def.mutable_enqueue_op_name()->Add() = enqueue_op; } @@ -345,37 +346,54 @@ TEST(QueueRunnerTest, CallbackCalledOnError) { } TEST(QueueRunnerTest, RunMetaDataTest) { + Scope root = Scope::NewRootScope(); + auto q0 = FIFOQueue(root.WithOpName(kQueueName), {DataType::DT_FLOAT}); + Output rnd = RandomNormal(root.WithOpName("rnd"), {1, 1}, DataType::DT_FLOAT); + Output square = Square(root.WithOpName(kSquareOpName), rnd); + auto enqueue0 = QueueEnqueue(root.WithOpName(kEnqueueOp0), q0, {square}); + auto close0 = QueueClose(root.WithOpName(kCloseOp0), q0); + auto cancel0 = QueueClose(root.WithOpName(kCancelOp0), q0, + QueueClose::CancelPendingEnqueues(true)); + auto dequeue0 = + QueueDequeue(root.WithOpName(kDequeueOp0), q0, {DataType::DT_FLOAT}); + + GraphDef graph_def; + TF_EXPECT_OK(root.ToGraphDef(&graph_def)); + for (auto& node : *graph_def.mutable_node()) { + node.set_device("/cpu:0"); + } SessionOptions sess_options; sess_options.config.mutable_graph_options()->set_build_cost_model(1); std::unique_ptr session(NewSession(sess_options)); - GraphDef graph_def = BuildSimpleGraph(); TF_CHECK_OK(session->Create(graph_def)); - TF_CHECK_OK(session->Run({}, {}, {kAssignOpName}, nullptr)); - RunOptions run_options; - run_options.set_trace_level(RunOptions::HARDWARE_TRACE); - - QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( - kQueueName, {kCountUpToOpName}, kSquareOpName, "", {}); + QueueRunnerDef queue_runner_def = + BuildQueueRunnerDef(kQueueName, {kEnqueueOp0}, kCloseOp0, kCancelOp0, {}); std::unique_ptr qr; TF_EXPECT_OK(QueueRunner::New(queue_runner_def, &qr)); - TF_CHECK_OK(qr->StartAndCollectRunMetadata(session.get(), &run_options)); + RunOptions run_options; + TF_CHECK_OK(qr->StartAndCollectCostGraph(session.get(), &run_options)); - TF_EXPECT_OK(qr->Join()); - RunMetadata run_metadata; - TF_CHECK_OK(qr->ExportRunMetadata(&run_metadata)); + // Make sure there was at least one element enqueued in q0: this prevents a + // race condition where we close the queue before it was populated. + std::vector dq0; + TF_EXPECT_OK(session->Run({}, {kDequeueOp0}, {}, &dq0)); + // Second call to run dequeue op is to make sure the cost graph has been + // stored. + TF_EXPECT_OK(session->Run({}, {kDequeueOp0}, {}, &dq0)); + + CostGraphDef cost_graph; + TF_CHECK_OK(qr->ExportCostGraph(&cost_graph)); + EXPECT_TRUE(cost_graph.node_size() > 0); - EXPECT_TRUE(run_metadata.has_cost_graph()); + qr->Stop(session.get()); } TEST(QueueRunnerTest, NoRunMetaDataTest) { GraphDef graph_def = BuildSimpleGraph(); auto session = BuildSessionAndInitVariable(graph_def); - RunOptions run_options; - run_options.set_trace_level(RunOptions::HARDWARE_TRACE); - QueueRunnerDef queue_runner_def = BuildQueueRunnerDef( kQueueName, {kCountUpToOpName}, kSquareOpName, "", {}); std::unique_ptr qr; @@ -383,8 +401,8 @@ TEST(QueueRunnerTest, NoRunMetaDataTest) { TF_CHECK_OK(qr->Start(session.get())); TF_EXPECT_OK(qr->Join()); - RunMetadata run_metadata; - EXPECT_EQ(qr->ExportRunMetadata(&run_metadata).code(), + CostGraphDef cost_graph; + EXPECT_EQ(qr->ExportCostGraph(&cost_graph).code(), error::FAILED_PRECONDITION); } diff --git a/tensorflow/cc/tutorials/example_trainer.cc b/tensorflow/cc/tutorials/example_trainer.cc index f2ecd2eddc28da94ac1c2404c02324e7782831c3..49d3cca3a4e2cc1aa16af2ac251b16b7a45753b1 100644 --- a/tensorflow/cc/tutorials/example_trainer.cc +++ b/tensorflow/cc/tutorials/example_trainer.cc @@ -227,7 +227,7 @@ int main(int argc, char* argv[]) { argv[dst++] = f; } argv[dst++] = nullptr; - argc = unknown_flags.size() + 1; + argc = static_cast(unknown_flags.size() + 1); tensorflow::port::InitMain(argv[0], &argc, &argv); tensorflow::example::ConcurrentSessions(opts); } diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index c52a56b6428fb8a8415ed53477ba3e81c57b0ded..c12005a4cab903c15a4f95efa0fdc3b8b2563942 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -73,7 +73,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/core:core_cpu", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 042a72745a78c4a11b22c85e3a094d78c4ab2ed5..bbdb342a623f5d4435e437fbb94e282b685751c9 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -152,8 +152,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, string RewriteWithName(const string& name, string code, const std::vector>& rewrites) { str_util::ReplaceAllPairs(&code, rewrites); - str_util::ReplaceAll(&code, "{{NAME}}", name); - return code; + return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true); } // Generate methods for args (inputs). @@ -366,7 +365,7 @@ Status GenerateHeader(const HeaderOpts& opts, const Config& config, #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -namespace Eigen { class ThreadPoolDevice; } +namespace Eigen { struct ThreadPoolDevice; } // (Implementation detail) Entry point to the function in the object file. extern "C" void {{ENTRY}}( diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 46d7c03006a1344df17fc99c8b837f31ee86feb9..01963c6df4682ec8c23a93201d7fbbab63558060 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -15,7 +15,7 @@ #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -namespace Eigen { class ThreadPoolDevice; } +namespace Eigen { struct ThreadPoolDevice; } // (Implementation detail) Entry point to the function in the object file. extern "C" void entry_point( diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 1284155c07b1a253d42e7641354626eb153f0c35..0c7b97b01f43ea255ed4b7773ab5268396e7c306 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/compile_only_client.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -203,14 +203,14 @@ Status RewriteAndPruneGraph(Graph* graph, const Config& config, for (const Node* n : graph->nodes()) { if (n->type_string() == kArgOp) { string feed_id; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFeedIdAttr, &feed_id)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id)); if (missing_feeds.erase(feed_id) == 0) { return errors::Aborted(kArgOp, " node found with unknown feed id: ", feed_id); } } else if (n->type_string() == kRetvalOp) { string fetch_id; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFetchIdAttr, &fetch_id)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id)); if (missing_fetches.erase(fetch_id) == 0) { return errors::Aborted(kRetvalOp, " node found with unknown fetch id: ", fetch_id); @@ -234,7 +234,7 @@ Status CollectArgNodes(const Graph& graph, std::vector* arg_nodes) { for (Node* n : graph.nodes()) { if (n->type_string() == kArgOp) { int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); auto insert_result = indexed_arg_nodes.insert({index, n}); if (!insert_result.second) { const Node* dup = insert_result.first->second; @@ -264,9 +264,9 @@ Status CreateXlaArgs(const Graph& graph, for (const Node* node : arg_nodes) { XlaCompiler::Argument arg; arg.kind = XlaCompiler::Argument::kParameter; - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &arg.type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kShapeAttr, &arg.shape)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kDebugNameAttr, &arg.name)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } return Status::OK(); @@ -274,8 +274,8 @@ Status CreateXlaArgs(const Graph& graph, // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. -Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr graph, - const FunctionLibraryDefinition* flib_def, +Status ConvertGraphToXla(xla::CompileOnlyClient* client, + std::unique_ptr graph, xla::Computation* computation, bool* has_context_arg) { // Create a device and context to convert the graph into an XLA computation. XlaOpRegistry::RegisterCompilationKernels(); @@ -289,18 +289,19 @@ Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr graph, // Compile the graph into an XLA computation. XlaCompiler::Options compiler_options; compiler_options.client = client; - compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); + DeviceType device_type(DEVICE_CPU_XLA_JIT); + compiler_options.device_type = &device_type; + compiler_options.flib_def = &graph->flib_def(); + compiler_options.graph_def_version = graph->versions().producer(); compiler_options.allow_cpu_custom_calls = true; XlaCompiler compiler(compiler_options); - std::unique_ptr flib_run(NewFunctionLibraryRuntime( - compiler.device_mgr(), Env::Default(), compiler.device(), - graph->versions().producer(), flib_def, OptimizerOptions())); XlaCompiler::CompilationResult result; - TF_RETURN_IF_ERROR(compiler.CompileGraph("tfcompile", std::move(graph), - flib_run.get(), xla_args, &result)); + TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(), + "tfcompile", std::move(graph), + xla_args, &result)); *has_context_arg = result.requires_runtime_context; - *computation = std::move(result.computation); + *computation = std::move(*result.computation); int num_const_results = 0; for (int i = 0; i < result.outputs.size(); ++i) { @@ -334,7 +335,8 @@ Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr graph, } // Compiles the XLA computation into executable code. -Status CompileXla(xla::LocalClient* client, const xla::Computation& computation, +Status CompileXla(xla::CompileOnlyClient* client, + const xla::Computation& computation, const xla::cpu::CpuAotCompilationOptions& aot_opts, CompileResult* compile_result) { // Retrieves arg and result layouts from the computation. @@ -351,7 +353,7 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation, for (int i = 0; i < pshape->parameters_size(); ++i) { arg_layouts.push_back(pshape->mutable_parameters(i)); } - xla::LocalClient::AheadOfTimeComputationInstance instance; + xla::CompileOnlyClient::AotComputationInstance instance; instance.computation = &computation; instance.argument_layouts = std::move(arg_layouts); instance.result_layout = &pshape->result(); @@ -366,17 +368,17 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation, std::move(aot_or.ValueOrDie().back())); compile_result->entry_point = aot_opts.entry_point_name(); compile_result->pointer_size = - xla::LocalClient::PointerSizeForTriple(aot_opts.triple()); + xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple()); return Status::OK(); } } // namespace Status InitGraph(const GraphDef& graph_def, const Config& config, - const MainFlags& flags, const FunctionLibraryDefinition* flib, - std::unique_ptr* graph) { + const MainFlags& flags, std::unique_ptr* graph) { TF_RETURN_IF_ERROR(ValidateConfig(config)); - std::unique_ptr g(new Graph(flib)); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library()); + std::unique_ptr g(new Graph(flib_def)); GraphDef copy_def(graph_def); TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(©_def, *g->op_registry(), 0 /*node_offset*/)); @@ -388,7 +390,6 @@ Status InitGraph(const GraphDef& graph_def, const Config& config, } Status CompileGraph(std::unique_ptr graph, const MainFlags& flags, - const FunctionLibraryDefinition* flib, CompileResult* compile_result) { // Converts the graph into an XLA computation, and compiles the // computation. @@ -396,11 +397,11 @@ Status CompileGraph(std::unique_ptr graph, const MainFlags& flags, namespace gpu = perftools::gputools; gpu::Platform* cpu_platform = gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie(); - xla::LocalClient* client = - xla::ClientLibrary::GetOrCreateLocalClient(cpu_platform).ValueOrDie(); + xla::CompileOnlyClient* client = + xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform) + .ValueOrDie(); xla::Computation computation; - TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), flib, - &computation, + TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation, &compile_result->has_context_arg)); if (!flags.debug_dir.empty()) { TF_ASSIGN_OR_RETURN(std::unique_ptr module, diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index 8e9c64820baf0cb672cead59954098f10a9c9a32..e929272b2e4760e39cddba7e585cb12a7d2d7e98 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -56,8 +56,7 @@ extern const char* const kDebugNameAttr; // compute the outputs. If dump_graphs is true, graph rewrites will be dumped // for debugging. Status InitGraph(const GraphDef& graph_def, const Config& config, - const MainFlags& flags, const FunctionLibraryDefinition* flib, - std::unique_ptr* graph); + const MainFlags& flags, std::unique_ptr* graph); // CompileResult describes the output of CompileGraph, where the object file // data and meta-information is available in aot. @@ -83,7 +82,6 @@ struct CompileResult { // // The XLA compilation options are specified in the flags. Status CompileGraph(std::unique_ptr graph, const MainFlags& flags, - const FunctionLibraryDefinition* flib, CompileResult* result); } // namespace tfcompile diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/aot/runtime.cc index 208de5498dbee6773683ac1aa2b33400a8a21f35..5772776666129ed55a479c8917e69df3f3ce2fc0 100644 --- a/tensorflow/compiler/aot/runtime.cc +++ b/tensorflow/compiler/aot/runtime.cc @@ -31,6 +31,8 @@ namespace { inline void* aligned_malloc(size_t size, int minimum_alignment) { #if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN) return memalign(minimum_alignment, size); +#elif defined(COMPILER_MSVC) + return _aligned_malloc(size, minimum_alignment); #else // !__ANDROID__ && !OS_ANDROID && !OS_CYGWIN void* ptr = nullptr; // posix_memalign requires that the requested alignment be at least @@ -45,7 +47,13 @@ inline void* aligned_malloc(size_t size, int minimum_alignment) { #endif } -inline void aligned_free(void* aligned_memory) { free(aligned_memory); } +inline void aligned_free(void* aligned_memory) { +#if defined(COMPILER_MSVC) + _aligned_free(aligned_memory); +#else + free(aligned_memory); +#endif +} size_t align_to(size_t n, size_t align) { return (((n - 1) / align) + 1) * align; diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index ecb071a416c330065b286c41467c302df40714db..59d13e5393445330ba5f1c5a54b73de6b3b4c0d8 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -51,6 +51,7 @@ genrule( "test_graph_tfgather.pb", "test_graph_tfmatmul.pb", "test_graph_tfmatmulandadd.pb", + "test_graph_tffunction.pb", ], cmd = "$(location :make_test_graphs) --out_dir $(@D)", tags = ["manual"], @@ -114,6 +115,15 @@ tf_library( tags = ["manual"], ) +tf_library( + name = "test_graph_tffunction", + testonly = 1, + config = "test_graph_tffunction.config.pbtxt", + cpp_class = "FunctionComp", + graph = "test_graph_tffunction.pb", + tags = ["manual"], +) + cc_test( name = "tfcompile_test", srcs = ["tfcompile_test.cc"], @@ -122,6 +132,7 @@ cc_test( ":test_graph_tfadd", ":test_graph_tfadd_with_ckpt", ":test_graph_tfadd_with_ckpt_saver", + ":test_graph_tffunction", ":test_graph_tfgather", ":test_graph_tfmatmul", ":test_graph_tfmatmulandadd", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 9279c45f3738e6b667bff5928849491f9d97dada..98c13958d3729bc6c7f554630e236892be130a4a 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -25,6 +25,7 @@ from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -95,6 +96,17 @@ def tfmatmulandadd(_): math_ops.add(x, y, name='x_y_sum') +def tffunction(_): + + @function.Defun(dtypes.int32, dtypes.int32) + def test_func(a, b): + return a + b + + x = constant_op.constant([1], name='x_const') + y = constant_op.constant([2], name='y_const') + test_func(x, y, name='func_call') # pylint: disable=unexpected-keyword-arg + + def write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() @@ -112,6 +124,7 @@ def main(_): write_graph(tfgather, FLAGS.out_dir) write_graph(tfmatmul, FLAGS.out_dir) write_graph(tfmatmulandadd, FLAGS.out_dir) + write_graph(tffunction, FLAGS.out_dir) if __name__ == '__main__': @@ -121,7 +134,6 @@ if __name__ == '__main__': '--out_dir', type=str, default='', - help='Output directory for graphs, checkpoints and savers.' - ) + help='Output directory for graphs, checkpoints and savers.') FLAGS, unparsed = parser.parse_known_args() app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..eb9c1cacb7ffe1ad60d985a2e5a1846707191fe7 --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt @@ -0,0 +1,16 @@ +# Text form of tensorflow.tfcompile.Config proto. +feed { + id { node_name: "x_const" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_const" } + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "func_call" } +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index f57d2859dfa4979fe0b04efea734817462af3bbf..76343b9752199fc4d26e4988452cd3c055bb5d96 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h" +#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h" #include "tensorflow/compiler/aot/tests/test_graph_tfgather.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h" @@ -376,6 +377,21 @@ TEST(TFCompileTest, MatMulAndAdd1) { } } +TEST(TFCompileTest, Function) { + // The function is equivalent to an addition + FunctionComp add_fn; + EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]); + EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]); + + add_fn.arg0() = 1; + add_fn.arg1() = 2; + EXPECT_TRUE(add_fn.Run()); + EXPECT_EQ(add_fn.error_msg(), ""); + EXPECT_EQ(add_fn.result0(), 3); + EXPECT_EQ(add_fn.result0_data()[0], 3); + EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]); +} + } // namespace } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 64e5bfd602cb2898dcbe57bfa0949c954f17acc1..7d61bee8caf7edcbbc1fa3cc1c79d7b5af2c942c 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -282,5 +282,6 @@ def target_llvm_triple(): "//tensorflow:android_arm": "armv7-none-android", "//tensorflow:android_arm64": "aarch64-none-android", "//tensorflow:android_x86": "i686-none-android", + "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", "//conditions:default": "x86_64-pc-linux", }) diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 85ef9560bbf1a7130dd6b140d552d96c2a0e21d6..4b7e22076937808334726d9f67c086696eab1b73 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -52,7 +52,8 @@ const char kUsageHeader[] = "header file that gives access to the functionality in the object file.\n" "A typical invocation looks like this:\n" "\n" - " $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt\n" + " $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt " + "--cpp_class=\"mynamespace::MyComputation\"\n" "\n"; Status ReadProtoFile(const string& kind, const string& fname, @@ -73,6 +74,9 @@ void ParseTensorId(const string& name, TensorId* id) { Status Main(const MainFlags& flags) { // Process config. Config config; + if (flags.config.empty()) { + return errors::InvalidArgument("Must specify --config"); + } TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config)); TF_RETURN_IF_ERROR(ValidateConfig(config)); if (flags.dump_fetch_nodes) { @@ -85,15 +89,16 @@ Status Main(const MainFlags& flags) { } // Read and initialize the graph. + if (flags.graph.empty()) { + return errors::InvalidArgument("Must specify --graph"); + } GraphDef graph_def; TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def)); std::unique_ptr graph; - FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library()); - TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &flib, &graph)); + TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &graph)); CompileResult compile_result; - TF_RETURN_IF_ERROR( - CompileGraph(std::move(graph), flags, &flib, &compile_result)); + TF_RETURN_IF_ERROR(CompileGraph(std::move(graph), flags, &compile_result)); // Write output files. Env* env = Env::Default(); @@ -101,6 +106,9 @@ Status Main(const MainFlags& flags) { TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object, StringPiece(obj.data(), obj.size()))); HeaderOpts header_opts; + if (flags.cpp_class.empty()) { + return errors::InvalidArgument("Must specify --cpp_class"); + } TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name, &header_opts.namespaces)); string header; @@ -131,12 +139,16 @@ int main(int argc, char** argv) { QCHECK(parsed_flags_ok) << "\n" << usage; tensorflow::port::InitMain(usage.c_str(), &argc, &argv); - QCHECK(argc == 1 && !flags.config.empty() && - (flags.dump_fetch_nodes || - (!flags.graph.empty() && !flags.entry_point.empty()))) - << "\n" - << usage; - - TF_QCHECK_OK(tensorflow::tfcompile::Main(flags)); + QCHECK(argc == 1) << "\nERROR: This command does not take any arguments " + "other than flags\n\n" + << usage; + tensorflow::Status status = tensorflow::tfcompile::Main(flags); + if (status.code() == tensorflow::error::INVALID_ARGUMENT) { + std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n" + << usage; + return 1; + } else { + TF_QCHECK_OK(status); + } return 0; } diff --git a/tensorflow/compiler/aot/tfcompile_util_test.cc b/tensorflow/compiler/aot/tfcompile_util_test.cc index 108ab1eab7bf3b087e8049c5b24d652d871789c8..c321d3ff4c779fbd2e9c67dfc1eb24c734a9103f 100644 --- a/tensorflow/compiler/aot/tfcompile_util_test.cc +++ b/tensorflow/compiler/aot/tfcompile_util_test.cc @@ -24,7 +24,7 @@ namespace tensorflow { namespace tfcompile { namespace { -void ExpectErrorContains(Status status, StringPiece str) { +void ExpectErrorContains(const Status& status, StringPiece str) { EXPECT_NE(Status::OK(), status); EXPECT_TRUE(StringPiece(status.error_message()).contains(str)) << "expected error: " << status.error_message() << " to contain: " << str; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index c16fe56122fca8cf8a88d6098b2374285f33e9f2..04f15a6a0b44cdbc54dea3d2963047bbcff1be77 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -18,7 +18,23 @@ package( default_visibility = [":internal"], ) +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") + +# This target can be used by XLA device plugins to prevent circular +# dependencies, and provides access to all of the required headers +# for building a device library. +cc_header_only_library( + name = "xla_jit_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":xla_cpu_device", + ":xla_cpu_jit", + ":xla_gpu_device", + ":xla_gpu_jit", + ], +) # Target that bundles up the XLA CPU and GPU JIT devices. cc_library( @@ -29,6 +45,7 @@ cc_library( ":xla_cpu_jit", ":xla_gpu_device", ":xla_gpu_jit", + "//tensorflow/compiler/plugin", ], alwayslink = 1, ) @@ -48,12 +65,12 @@ cc_library( cc_library( name = "xla_gpu_jit", visibility = [":friends"], - deps = [ + deps = if_cuda([ ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_local_launch_op", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:gpu_plugin", - ], + ]), alwayslink = 1, ) @@ -125,7 +142,6 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", @@ -133,9 +149,9 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:tensorflow_opensource", - "//tensorflow/core/kernels:assign_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:identity_op", @@ -176,22 +192,33 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "graph_to_functiondef", + srcs = ["graph_to_functiondef.cc"], + hdrs = ["graph_to_functiondef.h"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "compilation_passes", srcs = [ "build_xla_launch_ops_pass.cc", "encapsulate_subgraphs_pass.cc", - "graph_to_functiondef.cc", "mark_for_compilation_pass.cc", ], hdrs = [ "build_xla_launch_ops_pass.h", "encapsulate_subgraphs_pass.h", - "graph_to_functiondef.h", "mark_for_compilation_pass.h", ], deps = [ ":common", + ":graph_to_functiondef", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/kernels:parallel_check_op", "//tensorflow/compiler/jit/kernels:xla_local_launch_op", @@ -222,6 +249,7 @@ cc_test( deps = [ ":common", ":compilation_passes", + ":graph_to_functiondef", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc index abb68f73d7e3870f733c350be0dc99ab21a6b083..48eed7fce07f0855934600890e157b2752d38838 100644 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc @@ -66,9 +66,9 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { int num_constant_args, num_resource_args; TF_RETURN_IF_ERROR( - GetNodeAttr(node->def(), kXlaNumConstantArgsAttr, &num_constant_args)); + GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args)); TF_RETURN_IF_ERROR( - GetNodeAttr(node->def(), kXlaNumResourceArgsAttr, &num_resource_args)); + GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args)); if (num_constant_args < 0 || num_resource_args < 0 || num_constant_args + num_resource_args > node->num_inputs()) { @@ -88,7 +88,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { Node* launch_node; TF_RETURN_IF_ERROR(BuildLaunchNode( graph->NewName(node->name()), node->type_string(), node->def().attr(), - node->def().device(), const_dtypes, num_resource_args, arg_dtypes, + node->requested_device(), const_dtypes, num_resource_args, arg_dtypes, node->output_types(), graph, &launch_node)); launch_node->set_assigned_device_name(node->assigned_device_name()); @@ -173,7 +173,8 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef, FunctionLibraryRuntime::Handle handle; // If ndef is not instantiable, e.g., the function does not exist, // simply bail out. - TF_RETURN_IF_ERROR(flr->Instantiate(ndef.op(), ndef.attr(), &handle)); + TF_RETURN_IF_ERROR( + flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle)); const FunctionBody* fbody = flr->GetFunctionBody(handle); CHECK(fbody); // Can't be nullptr since we just instantiated it. std::vector const_args(fbody->arg_types.size()); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 1d2793d3c55f4436a07e4f632887561202d0498e..88ec45f8d86643aa4f7c643ac5bee333fb2ec559 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -88,9 +88,12 @@ class Encapsulator { // Build a FunctionDef for each subgraph, and add it 'library'. The values of // the 'group_attribute' annotations become the function names. + // If 'reuse_existing_functions' is set, use an existing function with the + // same name, if any. // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before // function conversion. Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn, + bool reuse_existing_functions, FunctionLibraryDefinition* library); // Write a copy of the input graph to 'graph_out', where the subgraphs are @@ -162,7 +165,7 @@ static const char* const kRetValOp = "_Retval"; // none. string Encapsulator::GetFunctionNameAttr(Node const* node) const { string attr; - if (!GetNodeAttr(node->def(), group_attribute_, &attr).ok()) { + if (!GetNodeAttr(node->attrs(), group_attribute_, &attr).ok()) { attr.clear(); } return attr; @@ -192,7 +195,7 @@ Status Encapsulator::SplitIntoSubgraphs() { // Check the device matches any existing device. string device = node->assigned_device_name().empty() - ? node->def().device() + ? node->requested_device() : node->assigned_device_name(); if (subgraph.device.empty()) { @@ -236,9 +239,16 @@ Status Encapsulator::SplitIntoSubgraphs() { // Create a new _Retval node DataType dtype = edge->src()->output_type(edge->src_output()); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "Ref Tensors (e.g., Variables) are not supported: tensor ", + edge->src()->name(), ":", edge->src_output()); + } + NodeDef ret_def; ret_def.set_op(kRetValOp); - ret_def.set_name(src_subgraph.graph->NewName("output")); + ret_def.set_name(strings::StrCat(edge->src()->name(), "_", + edge->src_output(), "_retval")); AddNodeAttr("T", dtype, &ret_def); AddNodeAttr("index", ret_index, &ret_def); Node* ret = src_subgraph.graph->AddNode(ret_def, &s); @@ -263,8 +273,16 @@ Status Encapsulator::SplitIntoSubgraphs() { // This is the first time we have seen this tensor. Create an _Arg node. DataType dtype = edge->dst()->input_type(edge->dst_input()); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "Ref Tensors (e.g., Variables) are not supported: tensor ", + edge->src()->name(), ":", edge->src_output()); + } + NodeDef arg_def; - NodeDefBuilder builder(dst_subgraph.graph->NewName("input"), kArgOp); + NodeDefBuilder builder(strings::StrCat(edge->src()->name(), "_", + edge->src_output(), "_arg"), + kArgOp); builder.Attr("T", dtype); builder.Attr("index", arg_index); s = builder.Finalize(&arg_def); @@ -291,11 +309,11 @@ Status Encapsulator::SplitIntoSubgraphs() { } Status Encapsulator::BuildFunctionDefs( - const RewriteSubgraphFn& rewrite_subgraph_fn, + const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, FunctionLibraryDefinition* library) { // For each subgraph, build a FunctionDef. for (auto& subgraph_entry : subgraphs_) { - const string& name = subgraph_entry.first; + string name = subgraph_entry.first; Subgraph& subgraph = subgraph_entry.second; subgraph.call_node_def.set_op(name); @@ -332,6 +350,8 @@ Status Encapsulator::BuildFunctionDefs( for (auto& result : subgraph.results) { result.second = output_permutation[result.second]; } + + name = subgraph.call_node_def.op(); } FunctionDef fdef; @@ -346,7 +366,9 @@ Status Encapsulator::BuildFunctionDefs( strings::StrCat("encapsulate_fdef_", name), fdef); } - TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + if (!reuse_existing_functions || library->Find(name) == nullptr) { + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + } } return Status::OK(); } @@ -545,14 +567,16 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, Status EncapsulateSubgraphsInFunctions( string group_attribute, const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking, - std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { + bool reuse_existing_functions, std::unique_ptr* graph_out, + FunctionLibraryDefinition* library) { Status s; Encapsulator encapsulator(std::move(group_attribute), &graph_in); s = encapsulator.SplitIntoSubgraphs(); if (!s.ok()) return s; - s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, library); + s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, + reuse_existing_functions, library); if (!s.ok()) return s; std::unique_ptr out(new Graph(library)); @@ -569,7 +593,7 @@ static Status GetArgTypes(const Graph& graph, DataTypeVector* types) { for (Node* n : graph.nodes()) { if (n->type_string() == kArgOp) { int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); if (index < 0 || index >= types->size()) { return errors::InvalidArgument("Invalid argument number"); } @@ -586,7 +610,7 @@ static Status RenumberArguments(Graph* graph, for (Node* n : graph->nodes()) { if (n->type_string() == kArgOp) { int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); if (index < 0 || index >= permutation.size()) { return errors::InvalidArgument("Invalid argument number"); } @@ -674,7 +698,8 @@ Status EncapsulateSubgraphsPass::Run( TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( kXlaClusterAttr, **options.graph, rewrite_subgraph, - flags->tf_xla_parallel_checking, &graph_out, library)); + flags->tf_xla_parallel_checking, /*reuse_existing_functions=*/false, + &graph_out, library)); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, @@ -688,7 +713,7 @@ Status EncapsulateSubgraphsPass::Run( bool IsXlaCompiledKernel(const Node& node) { bool is_compiled = false; bool has_compilation_attr = - GetNodeAttr(node.def(), kXlaCompiledKernelAttr, &is_compiled).ok() && + GetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled).ok() && is_compiled; return has_compilation_attr ? is_compiled : false; } diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 3ca7dfbf6a0ec29d9517139ffb952298d503cabc..b0987f76c91ed48df52fab303ea6052ebd8fd336 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -34,6 +34,8 @@ namespace tensorflow { // 'input_permutation' and 'output_permutation' are initialized to the identity // permutation. 'nodedef' is the NodeDef for the call to the function under // construction, provided to allow additional attributes to be set. +// The rewrite may also change the NodeDef's operator name, and that +// name will be used as the name of the generated function. typedef std::function* graph, std::vector* input_permutation, std::vector* output_permutation, NodeDef* node_def)> @@ -53,6 +55,9 @@ typedef std::function* graph_out, FunctionLibraryDefinition* library); + bool reuse_existing_functions, std::unique_ptr* graph_out, + FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate // subgraphs pass and that should in turn be compiled via _XlaLaunch operators. diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index faab7bd3d25d2491cf74faeb3b06acf4c2d6a054..a8869c8e2a7c164f97917cdae312289efb8b2663 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -76,7 +76,7 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, #define TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(expected, actual) \ do { \ string diff; \ - EXPECT_TRUE(EqualFunctionDefLibrary(actual, expected, &diff)) \ + EXPECT_TRUE(EqualFunctionDefLibrary(expected, actual, &diff)) \ << diff << "\nActual: " << actual.DebugString(); \ } while (false) @@ -109,7 +109,7 @@ Node* Binary(ops::NodeOut a, ops::NodeOut b, return ops::BinaryOp("BinaryTest", a, b, opts); } -Node* AddNLike(std::vector inputs, +Node* AddNLike(const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("AddN"), "AddNLikeTest", @@ -144,8 +144,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { std::unique_ptr graph_out; s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph, - /* rewrite_subgraph_fn= */ {}, - /* parallel_checking= */ false, + /*rewrite_subgraph_fn=*/{}, + /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph_out, lib_def.get()); if (!s.ok()) return s; @@ -205,12 +206,12 @@ TEST(EncapsulateSubgraphsTest, OneFunction) { *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"input__0:float", "input__1:float"}, {"output__2:float"}, {}, + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"c_0_retval:float"}, {}, { - {{"C"}, "UnaryTest", {"input__0"}}, - {{"c"}, "BinaryTest", {"input__1", "C:o:0"}, {}, {"C"}}, + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}}, }, - {{"output__2", "c:o:0"}}); + {{"c_0_retval", "c:o:0"}}); { std::unique_ptr lib_def( @@ -261,17 +262,17 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) { *library_expected.add_function() = test::function::XTimesTwo(); *library_expected.add_function() = FunctionDefHelper::Create( - "F1", {"input__0:float"}, {"output__1:float"}, {}, + "F1", {"a_0_arg:float"}, {"c_0_retval:float"}, {}, { - {{"C"}, "UnaryTest", {"input__0"}}, + {{"C"}, "UnaryTest", {"a_0_arg"}}, }, - {{"output__1", "C:o:0"}}); + {{"c_0_retval", "C:o:0"}}); *library_expected.add_function() = FunctionDefHelper::Create( - "F2", {"input__0:float", "input__1:float"}, {"output__2:float"}, {}, + "F2", {"b_0_arg:float", "c_0_arg:float"}, {"d_0_retval:float"}, {}, { - {{"D"}, "BinaryTest", {"input__0", "input__1"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "c_0_arg"}}, }, - {{"output__2", "D:o:0"}}); + {{"d_0_retval", "D:o:0"}}); { std::unique_ptr lib_def( @@ -340,7 +341,8 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/false, &graph, &library)); + /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph, + &library)); std::vector expected_nodes = {"cluster1", "cluster2", "mul", "x"}; EXPECT_EQ(expected_nodes, GraphNodes(*graph)); @@ -371,7 +373,8 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) { std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/true, &graph, &library)); + /*parallel_checking=*/true, /*reuse_existing_functions=*/false, &graph, + &library)); std::vector expected_nodes = { "add1", "add2", "cluster1", "cluster1_parallel_check/_0", diff --git a/tensorflow/compiler/jit/graph_to_functiondef.cc b/tensorflow/compiler/jit/graph_to_functiondef.cc index ce943471fb07fe02f18596247ccfddb94bd35158..83c23385008d56859b81abee7d292276036a45ee 100644 --- a/tensorflow/compiler/jit/graph_to_functiondef.cc +++ b/tensorflow/compiler/jit/graph_to_functiondef.cc @@ -126,8 +126,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, if (node->type_string() == kArgOp) { int index; DataType type; - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index)); while (fdef->signature().input_arg_size() <= index) { fdef->mutable_signature()->add_input_arg(); } @@ -143,8 +143,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, if (node->type_string() == kRetValOp) { int index; DataType type; - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &type)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &index)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index)); while (fdef->signature().output_arg_size() <= index) { fdef->mutable_signature()->add_output_arg(); } @@ -161,9 +161,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, } NodeDef* node_def = fdef->add_node_def(); - node_def->CopyFrom(node->def()); + *node_def = node->def(); node_def->set_name(node_names.Uniquify(node->name())); - node_def->clear_device(); // Reset input names based on graph rather than the NodeDef. node_def->clear_input(); @@ -204,8 +203,8 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, // Populate tensor_renaming. NameRangeMap output_ranges; - TF_RETURN_IF_ERROR(NameRangesForNode(node->def(), node->op_def(), nullptr, - &output_ranges)); + TF_RETURN_IF_ERROR( + NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges)); for (const auto& output : output_ranges) { for (int i = output.second.first; i < output.second.second; ++i) { const string tensor_name = strings::StrCat( diff --git a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc index c741ccfb31efa8794ae745e2e52e3c91b20cfcfc..29c5ff724299ec84d31268c4227259ec02d10742 100644 --- a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc @@ -34,7 +34,7 @@ namespace tensorflow { namespace { -Status BuildCompilationCache(ResourceMgr* rm, XlaCompilationCache** compiler) { +Status BuildCompilationCache(ResourceMgr* rm, XlaCompilationCache** cache) { XlaDevice::Metadata* metadata; Status s = rm->Lookup(rm->default_container(), "xla_metadata", &metadata); @@ -42,12 +42,8 @@ Status BuildCompilationCache(ResourceMgr* rm, XlaCompilationCache** compiler) { return s; } core::ScopedUnref metadata_ref(metadata); - XlaCompiler::Options options; - options.device_type = metadata->jit_device_type(); - options.client = metadata->client(); - options.allow_cpu_custom_calls = false; - options.local_executable_has_hybrid_result = false; - *compiler = new XlaCompilationCache(options); + *cache = + new XlaCompilationCache(metadata->client(), metadata->jit_device_type()); return Status::OK(); } @@ -59,7 +55,7 @@ XlaDeviceLaunchOp::XlaDeviceLaunchOp(OpKernelConstruction* ctx) OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func)); function_ = *func; VLOG(1) << "XlaDeviceLaunch created function=" - << Canonicalize(function_.name(), function_.attr()); + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); DataTypeVector constant_types; OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types)); num_constant_args_ = constant_types.size(); @@ -85,29 +81,37 @@ std::vector SnapshotResourceVariables(OpKernelContext* ctx, void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XlaDeviceLaunch::Compute " - << Canonicalize(function_.name(), function_.attr()); + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. ResourceMgr* rm = ctx->resource_manager(); OP_REQUIRES(ctx, rm, errors::Internal("No resource manager.")); - XlaCompilationCache* compiler; + XlaCompilationCache* cache; OP_REQUIRES_OK(ctx, rm->LookupOrCreate( - rm->default_container(), "xla_compiler", &compiler, - [rm](XlaCompilationCache** compiler) { - return BuildCompilationCache(rm, compiler); + rm->default_container(), "xla_compiler", &cache, + [rm](XlaCompilationCache** cache) { + return BuildCompilationCache(rm, cache); })); // Holds the reference to the JIT during evaluation. (We could probably // free it sooner because the ResourceMgr will retain a reference, but // this is more obviously correct.) - core::ScopedUnref compiler_ref(compiler); + core::ScopedUnref cache_ref(cache); std::vector variables = SnapshotResourceVariables(ctx, num_resource_args_); + XlaCompiler::Options options; + options.client = cache->client(); + options.device_type = &cache->device_type(); + options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + options.graph_def_version = ctx->function_library()->graph_def_version(); + options.allow_cpu_custom_calls = false; + options.local_executable_has_hybrid_result = false; + const XlaCompiler::CompilationResult* kernel; - OP_REQUIRES_OK(ctx, compiler->Compile(function_, num_constant_args_, - variables, ctx, &kernel, nullptr)); + OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, + variables, ctx, &kernel, nullptr)); VLOG(1) << "XLA compilation complete..."; @@ -117,7 +121,7 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { // Runs the computation, if any. There might not be a computation if all // outputs were compile-time constants. std::vector> outputs; - if (!kernel->computation.IsNull()) { + if (!kernel->computation->IsNull()) { auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); // Builds the inputs to the computation. @@ -148,8 +152,8 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { Env* env = Env::Default(); auto start_time = env->NowMicros(); VLOG(1) << "Executing XLA Computation..."; - auto result = compiler->client()->Execute(kernel->computation, arg_ptrs, - &execution_options, &profile); + auto result = cache->client()->Execute(*kernel->computation, arg_ptrs, + &execution_options, &profile); auto elapsed = env->NowMicros() - start_time; OP_REQUIRES(ctx, result.ok(), result.status()); @@ -158,7 +162,7 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { if (xla::ShapeUtil::IsTuple(kernel->xla_output_shape)) { auto outputs_or_error = - compiler->client()->DeconstructTuple(*result.ValueOrDie()); + cache->client()->DeconstructTuple(*result.ValueOrDie()); OP_REQUIRES(ctx, outputs_or_error.ok(), outputs_or_error.status()); outputs = outputs_or_error.ConsumeValueOrDie(); } else { diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc index 8b43c7c1564a340b70e8cfa271a3ef50379b46bc..40acc0d81d08230b373823e333cd5e3e407b9c4f 100644 --- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc @@ -148,24 +148,28 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) OP_REQUIRES(ctx, num_resource_args == 0, errors::Unimplemented( "XlaLocalLaunchOp does not support resource variables")); -} - -Status XlaLocalLaunchOp::BuildCompilationCache(XlaCompilationCache** compiler) { - gpu::Platform::Id platform_id; if (device_type_ == DeviceType(DEVICE_CPU)) { - platform_id = gpu::host::kHostPlatformId; + platform_id_ = gpu::host::kHostPlatformId; } else if (device_type_ == DeviceType(DEVICE_GPU)) { - platform_id = gpu::cuda::kCudaPlatformId; + platform_id_ = gpu::cuda::kCudaPlatformId; } else { - return errors::InvalidArgument("Unknown device type for local _XlaLaunch"); + ctx->SetStatus( + errors::InvalidArgument("Unknown device type for local _XlaLaunch")); + return; } +} - auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id); +Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache) { + auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id_); if (!platform.ok()) { return StreamExecutorUtil::ConvertStatus(platform.status()); } - auto client = - xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie()); + xla::LocalClientOptions client_options; + client_options.set_platform(platform.ValueOrDie()); + client_options.set_intra_op_parallelism_threads( + ctx->device()->tensorflow_cpu_worker_threads()->num_threads); + auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); if (!client.ok()) { return client.status(); } @@ -175,18 +179,14 @@ Status XlaLocalLaunchOp::BuildCompilationCache(XlaCompilationCache** compiler) { return errors::InvalidArgument("No JIT device registered for ", device_type_.type()); } - XlaCompiler::Options options; - options.device_type = DeviceType(registration->compilation_device_name); - options.client = client.ValueOrDie(); - options.allow_cpu_custom_calls = (platform_id == gpu::host::kHostPlatformId); - options.local_executable_has_hybrid_result = true; - *compiler = new XlaCompilationCache(options); + *cache = new XlaCompilationCache( + client.ValueOrDie(), DeviceType(registration->compilation_device_name)); return Status::OK(); } void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XlaLocalLaunchOp::Compute " - << Canonicalize(function_.name(), function_.attr()); + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. ResourceMgr* rm = ctx->resource_manager(); @@ -195,23 +195,31 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { gpu::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; - XlaCompilationCache* compiler; + XlaCompilationCache* cache; OP_REQUIRES_OK(ctx, rm->LookupOrCreate( - rm->default_container(), "xla_compiler", &compiler, - [this](XlaCompilationCache** compiler) { - return BuildCompilationCache(compiler); + rm->default_container(), "xla_cache", &cache, + [this, ctx](XlaCompilationCache** cache) { + return BuildCompilationCache(ctx, cache); })); // Hold the reference to the JIT during evaluation. (We could probably // free it sooner because the ResourceMgr will retain a reference, but // this is more obviously correct.) - core::ScopedUnref compiler_ref(compiler); + core::ScopedUnref cache_ref(cache); + + xla::LocalClient* client = static_cast(cache->client()); - xla::LocalClient* client = static_cast(compiler->client()); + XlaCompiler::Options options; + options.client = client; + options.device_type = &cache->device_type(); + options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + options.graph_def_version = ctx->function_library()->graph_def_version(); + options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId); + options.local_executable_has_hybrid_result = true; const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; - OP_REQUIRES_OK(ctx, compiler->Compile(function_, num_constant_args_, {}, ctx, - &kernel, &executable)); + OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, {}, + ctx, &kernel, &executable)); VLOG(1) << "Executing XLA Computation..."; @@ -221,7 +229,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { std::unique_ptr output; bool output_is_tuple; - if (!kernel->computation.IsNull()) { + if (!kernel->computation->IsNull()) { // Build xla::ShapedBuffers that point directly to the Tensor buffers. std::vector> arg_buffers; arg_buffers.reserve(kernel->xla_input_shapes.size() + 1); @@ -260,8 +268,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { xla::ExecutableRunOptions run_options; run_options.set_stream(stream); run_options.set_allocator(&xla_allocator); - run_options.set_inter_op_thread_pool( - ctx->device()->tensorflow_cpu_worker_threads()->workers); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); Env* env = Env::Default(); auto start_time = env->NowMicros(); diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.h b/tensorflow/compiler/jit/kernels/xla_local_launch_op.h index 8023206762951a4dafba900dd291f2ee9bdbbdf3..5e4d3336a91001fac1d222709f64300e777247c7 100644 --- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.h +++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/stream_executor_util.h" namespace tensorflow { @@ -43,11 +44,15 @@ class XlaLocalLaunchOp : public OpKernel { private: // Builds a XlaCompilationCache class suitable for the current device. - Status BuildCompilationCache(XlaCompilationCache** compiler); + Status BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** compiler); DeviceType device_type_; NameAttrList function_; int num_constant_args_; + + perftools::gputools::Platform::Id platform_id_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); }; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 22dbf7ec99fe93fb7fe8c524a3dc84ac1a97f015..73c4e80551485189d1e43fd93eed39083bd6b6b7 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -50,22 +50,24 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { } // Make sure we don't recurse infinitely on recursive functions. -const int kMaxRecursionDepth = 5; +const int kMaxRecursionDepth = 10; -bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type, - int depth, FunctionLibraryRuntime* lib_runtime); +bool IsCompilableCall(const NodeDef& call_def, + const DeviceType& jit_device_type, int depth, + FunctionLibraryRuntime* lib_runtime); -// Tests whether 'while_def' is a completely compilable loop. +// Tests whether 'while_node' is a completely compilable loop. // Every operator in the condition and body functions must be compilable for a // while loop to be compilable. -bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type, - int depth, FunctionLibraryRuntime* lib_runtime) { - VLOG(2) << "Loop marking: " << while_def.op(); +bool IsCompilableWhile(const Node& while_node, + const DeviceType& jit_device_type, int depth, + FunctionLibraryRuntime* lib_runtime) { + VLOG(2) << "Loop marking: " << while_node.type_string(); const NameAttrList* name_attr; NodeDef call; Status status; - status = GetNodeAttr(while_def, "cond", &name_attr); + status = GetNodeAttr(while_node.attrs(), "cond", &name_attr); if (!status.ok()) { VLOG(2) << "Missing 'cond' attribute on While node."; return false; @@ -78,7 +80,7 @@ bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type, VLOG(2) << "Can't compile loop condition: " << cond_func; return false; } - status = GetNodeAttr(while_def, "body", &name_attr); + status = GetNodeAttr(while_node.attrs(), "body", &name_attr); if (!status.ok()) { VLOG(2) << "Missing 'body' attribute on While node."; return false; @@ -98,8 +100,9 @@ bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type, // Tests whether 'call_def' is a call to a completely compilable function. // Every operator in the function must be compilable for a function to be // compilable. -bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type, - int depth, FunctionLibraryRuntime* lib_runtime) { +bool IsCompilableCall(const NodeDef& call_def, + const DeviceType& jit_device_type, int depth, + FunctionLibraryRuntime* lib_runtime) { VLOG(2) << "Function marking: " << call_def.op(); if (depth > kMaxRecursionDepth) { @@ -109,7 +112,7 @@ bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type, FunctionLibraryRuntime::Handle handle; Status status = - lib_runtime->Instantiate(call_def.op(), call_def.attr(), &handle); + lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle); if (!status.ok()) { VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status; return false; @@ -131,11 +134,11 @@ bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type, for (Node* node : fbody->graph->nodes()) { if (node->IsSource() || node->IsSink()) continue; - if (node->def().op() == "_Arg" || node->def().op() == "_Retval") continue; - if (node->def().op() == "While") { + if (node->type_string() == "_Arg" || node->type_string() == "_Retval") + continue; + if (node->type_string() == "While") { // Handle functional While loop (not in open source build). - return IsCompilableWhile(node->def(), jit_device_type, depth + 1, - lib_runtime); + return IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime); } if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, depth + 1, @@ -189,17 +192,16 @@ Status FindCompilationCandidates( if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) { VLOG(2) << "Compilation rejected node: unsupported op " << node->name() - << ": " << node->def().op(); + << ": " << node->type_string(); continue; } if (!registration->compile_resource_ops && HasResourceArgument(*node)) { VLOG(2) << "Compilation rejected node: resource argument " << node->name() - << ": " << node->def().op(); + << ": " << node->type_string(); continue; } - if (node->def().op() == "While" && - !IsCompilableWhile(node->def(), jit_device_type, 0, - lib_runtime.get())) { + if (node->type_string() == "While" && + !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime.get())) { continue; } candidates->insert(node); @@ -316,10 +318,10 @@ Status MarkForCompilationPass::Run( // If there is a _XlaCompile annotation, use its value. bool compile = false; - Status status = GetNodeAttr(node->def(), kXlaCompileAttr, &compile); + Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); if (status.ok()) return compile; - status = fld->GetAttr(node->def(), kXlaCompileAttr, &compile); + status = fld->GetAttr(*node, kXlaCompileAttr, &compile); if (status.ok()) return compile; // Otherwise use the value of global_jit_level. @@ -482,8 +484,8 @@ Status MarkForCompilationPass::RunImpl( // all nodes marked with _XlaCompile=true to also have a // _XlaScope property set (and raise an error otherwise); but // for now we don't do this. - if (GetNodeAttr(node_from->def(), kXlaScopeAttr, &from_scope).ok() && - GetNodeAttr(node_to->def(), kXlaScopeAttr, &to_scope).ok() && + if (GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() && + GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() && from_scope != to_scope) { continue; } @@ -538,10 +540,9 @@ Status MarkForCompilationPass::RunImpl( // Compile if the user marked this node _XlaCompile=true bool compile_attr = false; bool marked_for_compilation = false; - if (GetNodeAttr(n->def(), kXlaCompileAttr, &compile_attr).ok()) { + if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) { marked_for_compilation = compile_attr; - } else if (options.flib_def - ->GetAttr(n->def(), kXlaCompileAttr, &compile_attr) + } else if (options.flib_def->GetAttr(*n, kXlaCompileAttr, &compile_attr) .ok()) { marked_for_compilation = compile_attr; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 91e4a2b41c7026b6ca028ed6a7e61588d57e9e50..9f30e12e0e30fef6b4bcd0ea3c091842b008c29a 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -57,7 +57,7 @@ std::unordered_map GetClusters(const Graph& graph) { std::unordered_map ids; for (Node* node : graph.nodes()) { string cluster; - if (GetNodeAttr(node->def(), kXlaClusterAttr, &cluster).ok()) { + if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) { CHECK(!cluster.empty()); ids[node->name()] = cluster; } diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 41abea02eb2d17423744dfb719ee9a3f6b8f1198..63ca77f9a912acce2078f3da43d64f2e10049380 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -37,9 +37,9 @@ limitations under the License. namespace tensorflow { -XlaCompilationCache::XlaCompilationCache(const XlaCompiler::Options& options) - : compiler_(options) {} - +XlaCompilationCache::XlaCompilationCache(xla::Client* client, + DeviceType device_type) + : client_(client), device_type_(std::move(device_type)) {} XlaCompilationCache::~XlaCompilationCache() = default; string XlaCompilationCache::DebugString() { @@ -95,7 +95,7 @@ Status XlaCompilationCache::BuildSignature( const NameAttrList& function, int num_constant_args, const std::vector& variable_args, OpKernelContext* ctx, Signature* signature) { - signature->name = Canonicalize(function.name(), function.attr()); + signature->name = Canonicalize(function.name(), AttrSlice(&function.attr())); signature->arg_values.resize(num_constant_args); signature->arg_types.reserve(ctx->num_inputs() - num_constant_args); @@ -205,8 +205,9 @@ Status BuildArguments(int num_constant_args, } // namespace Status XlaCompilationCache::Compile( - const NameAttrList& function, int num_constant_args, - const std::vector& variable_args, OpKernelContext* ctx, + const XlaCompiler::Options& options, const NameAttrList& function, + int num_constant_args, const std::vector& variable_args, + OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable) { VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); @@ -263,21 +264,18 @@ Status XlaCompilationCache::Compile( TF_RETURN_IF_ERROR( BuildArguments(num_constant_args, variable_args, ctx, &args)); - std::unique_ptr flr(NewFunctionLibraryRuntime( - compiler_.device_mgr(), ctx->env(), compiler_.device(), - TF_GRAPH_DEF_VERSION, - ctx->function_library()->GetFunctionLibraryDefinition(), - OptimizerOptions(), nullptr /* custom_kernel_creator */)); - + XlaCompiler compiler(options); entry->compiled = true; - entry->compilation_status = compiler_.CompileFunction( - flr.get(), function, args, &entry->compilation_result); + entry->compilation_status = + compiler.CompileFunction(XlaCompiler::CompileOptions(), function, args, + &entry->compilation_result); } *compilation_result = &entry->compilation_result; if (entry->compilation_status.ok() && executable) { if (entry->executable == nullptr && - !entry->compilation_result.computation.IsNull()) { - entry->compilation_status = compiler_.BuildExecutable( + !entry->compilation_result.computation->IsNull()) { + XlaCompiler compiler(options); + entry->compilation_status = compiler.BuildExecutable( entry->compilation_result, &entry->executable); } *executable = entry->executable.get(); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index ff67e48d1a9a9f16881c2e141b23ce8c479aef50..4ffcb68a3220b2354a3542e4c2a4d3e000969e0b 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -46,7 +46,7 @@ struct OptionalTensor { // bound. class XlaCompilationCache : public ResourceBase { public: - explicit XlaCompilationCache(const XlaCompiler::Options& options); + XlaCompilationCache(xla::Client* client, DeviceType device_type); ~XlaCompilationCache() override; // Compiles a function into a XlaCompiler::CompilationResult that can be used @@ -61,19 +61,21 @@ class XlaCompilationCache : public ResourceBase { // xla::LocalExecutable and sets `executable to point to it. The resulting // executable pointer may be null if the computation has no non-constant // outputs. - Status Compile(const NameAttrList& function, int num_constant_args, + Status Compile(const XlaCompiler::Options& options, + const NameAttrList& function, int num_constant_args, const std::vector& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable); - xla::Client* client() const { return compiler_.client(); } + xla::Client* client() const { return client_; } + const DeviceType& device_type() const { return device_type_; } string DebugString() override; private: - XlaCompiler compiler_; - std::unique_ptr function_library_runtime_; + xla::Client* const client_; + const DeviceType device_type_; // Describes the types, shapes and any compile-time constant arguments // to a kernel. Key that uniquely identifies a compilation output. diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 3c6793b89420ed61259070f7bf637d6f4aa097d0..5e336c5287bd9e2067e93cd8db8a5a1b62b62bd2 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" @@ -108,12 +109,23 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { string XlaDevice::Metadata::DebugString() { return "XLA device metadata"; } +/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, + Metadata** metadata) { + ResourceMgr* rm = ctx->resource_manager(); + if (rm == nullptr) { + return errors::Internal("No resource manager."); + } + TF_RETURN_IF_ERROR( + rm->Lookup(rm->default_container(), "xla_metadata", metadata)); + return Status::OK(); +} + XlaDevice::XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, perftools::gputools::Platform* platform, Allocator* xla_allocator) - : LocalDevice(options, attrs, xla_allocator), + : LocalDevice(options, attrs), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(xla_allocator), @@ -163,6 +175,10 @@ Status XlaDevice::FillContextMap(const Graph* graph, void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); + // When TraceMe profiling is off (which is the default), the + // following TraceMe constructor is simply a conditional test of + // false value. Measurements show that its overhead is negligible. + port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string()); op_kernel->Compute(context); } @@ -170,6 +186,7 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) { VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" << op_kernel->type_string(); + port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string()); op_kernel->ComputeAsync(context, done); } diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 3de14f306168937bb0483e0c442984a02e2b1442..0badb390c6b7785b36f58c786e1d32a8d10d7c29 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -67,6 +67,10 @@ class XlaDevice : public LocalDevice { perftools::gputools::Platform* platform_; // Not owned. }; + // Sets `*metadata` to the XlaDevice Metadata in the resource manager of + // `ctx`. + static Status GetMetadata(OpKernelContext* ctx, Metadata** metadata); + // Factory function. 'platform_name' is the name of the XLA platform. // 'device_name' is the name of the Tensorflow device to create. // 'jit_device_name' is the name of the corresponding JIT device. diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index a52239df252b2b556987fa9701f43047765c60de..8699006ebc5aacafd46046a7c3f093356f687280 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -63,30 +63,10 @@ class XlaDeviceDummyOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ PlaceholderOp); \ \ - REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \ - ControlTriggerOp); \ - REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \ - REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \ - REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \ - NextIterationOp); \ - REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \ - SwitchOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \ - REGISTER_KERNEL_BUILDER(Name("LoopCond") \ - .Device(DEVICE) \ - .HostMemory("input") \ - .HostMemory("output"), \ - IdentityOp); \ - \ REGISTER_KERNEL_BUILDER( \ Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \ ResourceHandleOp); -// TODO(b/32507444): the registrations for the control flow operators are -// temporary and exist primarily to work around a bug in the graph partitioning -// code. - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_ diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..056f2228ca64b083bf05a8728dc25213e99e4cd8 --- /dev/null +++ b/tensorflow/compiler/plugin/BUILD @@ -0,0 +1,34 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Configuration file for an XLA plugin. +- please don't check in changes to this file +- to prevent changes appearing in git status, use: + git update-index --assume-unchanged tensorflow/compiler/plugin/BUILD + +To add additional devices to the XLA subsystem, add targets to the +dependency list in the 'plugin' target. For instance: + + deps = ["//tensorflow/compiler/plugin/example:plugin_lib"], +""" + +package( + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "plugin", + deps = [], +) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 03e255e6b842668a491d254953926500ce3a50ec..19f7ff835456855a2b2ab7d5856f1d3e6f7f9733 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -65,6 +65,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "adam_test", + size = "small", + srcs = ["adam_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "binary_ops_test", size = "small", @@ -156,6 +170,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "slice_ops_test", + size = "small", + srcs = ["slice_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "function_test", size = "small", @@ -305,6 +332,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "spacetobatch_op_test", + size = "medium", + srcs = ["spacetobatch_op_test.py"], + shard_count = 3, + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "ternary_ops_test", size = "small", diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index 0a2c9e26c6fbd827d5ab669dea5419f9fa50025b..a5c5885b4284aee167ae4cb18f7e42820c6d251d 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functional tests for aggregate operations.""" +"""Tests for Adagrad.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3215dc36e5b2d517aa951db1b0d41188185ef93a --- /dev/null +++ b/tensorflow/compiler/tests/adam_test.py @@ -0,0 +1,176 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Adam.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adam + + +def adam_update_numpy(param, + g_t, + t, + m, + v, + alpha=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8): + alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon) + return param_t, m_t, v_t + + +class AdamOptimizerTest(XLATestCase): + + def testBasic(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + variable_scope.get_variable_scope().set_use_resource(True) + + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = array_ops.placeholder(dtype) + grads1 = array_ops.placeholder(dtype) + opt = adam.AdamOptimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testTensorLearningRate(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + variable_scope.get_variable_scope().set_use_resource(True) + + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = array_ops.placeholder(dtype) + grads1 = array_ops.placeholder(dtype) + opt = adam.AdamOptimizer(constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSharing(self): + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + variable_scope.get_variable_scope().set_use_resource(True) + + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = array_ops.placeholder(dtype) + grads1 = array_ops.placeholder(dtype) + opt = adam.AdamOptimizer() + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of intertwined Adam1 and Adam2. + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + if t % 2 == 0: + update1.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) + else: + update2.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 9efdaee7ab66f7cfc84bc1c30a9ba700e268abe2..7221a0a3c745f939b88cae0f66af2421922dcd68 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -107,6 +107,12 @@ class BinaryOpsTest(XLATestCase): np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([-75, -48, -21, 0], dtype=dtype)) + self._testBinary( + gen_nn_ops._elu_grad, + np.array([1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype), + expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype)) + self._testBinary( gen_nn_ops._relu_grad, np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py index e89c411d01f8eb27f39bf65f3d3d21ec817c3ddf..2660e1d5728caf88e2b9ae73b3e3fde2aee71ed8 100644 --- a/tensorflow/compiler/tests/nary_ops_test.py +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -116,13 +116,14 @@ class NAryOpsTest(XLATestCase): np.array([1, 1], dtype=np.int32)], expected=np.array([[], []], dtype=np.float32)) - if (np.int64 in self.int_types): - self._testNAry(lambda x: array_ops.strided_slice(*x), - [np.array([[], [], []], dtype=np.float32), - np.array([1, 0], dtype=np.int64), - np.array([3, 0], dtype=np.int64), - np.array([1, 1], dtype=np.int64)], - expected=np.array([[], []], dtype=np.float32)) + if np.int64 in self.int_types: + self._testNAry( + lambda x: array_ops.strided_slice(*x), [ + np.array([[], [], []], dtype=np.float32), np.array( + [1, 0], dtype=np.int64), np.array([3, 0], dtype=np.int64), + np.array([1, 1], dtype=np.int64) + ], + expected=np.array([[], []], dtype=np.float32)) self._testNAry(lambda x: array_ops.strided_slice(*x), [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index c3e8ff724c178c5e635fedae0c3295cf598b2b00..2a71543f3febe3cb692fdcd563772c3bd2d3724a 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -94,7 +94,7 @@ class OpTestBuilder { explicit OpTestBuilder(const string& op_name); // Adds an input 'tensor'. - OpTestBuilder& Input(Tensor tensor); + OpTestBuilder& Input(const Tensor& tensor); // Sets an attribute. template @@ -111,8 +111,8 @@ class OpTestBuilder { // sets it to the NodeDef of the operator under test. Fills 'inputs' and // 'outputs' with the names of the input placeholder nodes and the output // identity nodes, respectively. - Status BuildGraph(string name_prefix, string device, bool use_jit, - GraphDef* graphdef, NodeDef** test_node_def, + Status BuildGraph(const string& name_prefix, const string& device, + bool use_jit, GraphDef* graphdef, NodeDef** test_node_def, std::vector* inputs, std::vector* outputs) const; @@ -127,7 +127,7 @@ OpTestBuilder::OpTestBuilder(const string& op_name) { node_def_.set_op(op_name); } -OpTestBuilder& OpTestBuilder::Input(Tensor tensor) { +OpTestBuilder& OpTestBuilder::Input(const Tensor& tensor) { VLOG(1) << "Adding input: " << tensor.DebugString(); inputs_.push_back(tensor); return *this; @@ -146,9 +146,9 @@ OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, return *this; } -Status OpTestBuilder::BuildGraph(string name_prefix, string device, - bool use_jit, GraphDef* graphdef, - NodeDef** test_node_def, +Status OpTestBuilder::BuildGraph(const string& name_prefix, + const string& device, bool use_jit, + GraphDef* graphdef, NodeDef** test_node_def, std::vector* inputs, std::vector* outputs) const { OpRegistryInterface* op_registry = OpRegistry::Global(); @@ -209,7 +209,7 @@ class OpTest : public ::testing::Test { // Runs 'fn' up to --tf_xla_test_repetitions times, or until a failure occurs; // whichever happens first. - void Repeatedly(std::function fn); + void Repeatedly(const std::function& fn); // Select a random element from 'candidates'. template @@ -218,12 +218,11 @@ class OpTest : public ::testing::Test { static constexpr int kDefaultMaxRank = 5; static constexpr int64 kDefaultMaxDimensionSize = 20LL; - // Returns a random dimension size. + // Returns a random dimension size, in the range [min, max). int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize); // Returns a random shape. The tensor has rank in the range [min_rank, - // max_rank). - // Each dimension has size [0, kDefaultMaxDimensionSize]. + // max_rank). Each dimension has size [min_size, max_size). std::vector RandomDims(int min_rank = 0, int max_rank = kDefaultMaxRank, int64 min_size = 0, @@ -316,7 +315,7 @@ OpTest::OpTest() { TF_CHECK_OK(session_->Create(def)); } -void OpTest::Repeatedly(std::function fn) { +void OpTest::Repeatedly(const std::function& fn) { int const max_repetitions = tf_xla_test_repetitions; for (int i = 0; !HasFailure() && i < max_repetitions; ++i) { fn(); @@ -668,6 +667,9 @@ void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder, VLOG(1) << "Expected graph failed with status: " << s << ". Skipping test"; return; } + for (const Tensor& expected : expected_outputs) { + VLOG(1) << "Expected: " << expected.DebugString(); + } VLOG(1) << "Running test graph"; TF_ASSERT_OK(session_->Run(test_feeds, test_fetches, {}, &test_outputs)); @@ -877,6 +879,79 @@ TEST_F(OpTest, BatchMatMul) { }); } +TEST_F(OpTest, BatchToSpace) { + Repeatedly([this]() { + const int num_block_dims = 2; + std::vector block_dims = + RandomDims(num_block_dims, num_block_dims, 0, 5); + int64 block_size = RandomDim(0, 4); + + std::vector input_dims(1 + num_block_dims + 1); + input_dims[0] = RandomDim(); + for (int i = 0; i < num_block_dims; ++i) { + input_dims[0] *= block_size; + input_dims[1 + i] = block_dims[i]; + } + input_dims[1 + num_block_dims] = RandomDim(); + + std::vector crop_vals; + std::uniform_int_distribution distribution(0, 4); + for (int i = 0; i < num_block_dims; ++i) { + // Chooses crop values; does not always choose legal values. + crop_vals.push_back(distribution(generator())); + crop_vals.push_back(distribution(generator())); + } + Tensor crops; + CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals), + TensorShape({num_block_dims, 2}))); + + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace") + .Input(RandomTensor(DT_FLOAT, input_dims)) + .Input(crops) + .Attr("T", DT_FLOAT) + .Attr("block_size", block_size)); + }); +} + +TEST_F(OpTest, BatchToSpaceND) { + Repeatedly([this]() { + std::vector block_dims = RandomDims(1, 3, 0, 5); + int num_block_dims = block_dims.size(); + std::vector remaining_dims = RandomDims(0, 3); + std::vector block_multipliers = + RandomDims(block_dims.size(), block_dims.size(), 0, 4); + + std::vector input_dims(1 + num_block_dims + remaining_dims.size()); + input_dims[0] = RandomDim(); + for (int i = 0; i < num_block_dims; ++i) { + input_dims[0] *= block_dims[i]; + } + std::copy(block_multipliers.begin(), block_multipliers.end(), + input_dims.begin() + 1); + std::copy(remaining_dims.begin(), remaining_dims.end(), + input_dims.begin() + 1 + num_block_dims); + + std::vector crop_vals; + std::uniform_int_distribution distribution(0, 3); + for (int i = 0; i < num_block_dims; ++i) { + // Chooses crop values; does not always choose legal values. + crop_vals.push_back(distribution(generator())); + crop_vals.push_back(distribution(generator())); + } + Tensor crops; + CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals), + TensorShape({num_block_dims, 2}))); + + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("BatchToSpaceND") + .Input(RandomTensor(DT_FLOAT, input_dims)) + .Input(test::AsTensor( + std::vector(block_dims.begin(), block_dims.end()))) + .Input(crops) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, BiasAdd) { Repeatedly([this]() { auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank)); @@ -1214,6 +1289,23 @@ TEST_F(OpTest, DynamicStitch) { }); } +TEST_F(OpTest, Elu) { + Repeatedly([this]() { + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Elu").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, EluGrad) { + Repeatedly([this]() { + auto dims = RandomDims(); + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("EluGrad") + .Input(RandomTensor(DT_FLOAT, dims)) + .Input(RandomTensor(DT_FLOAT, dims)) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, Equal) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); @@ -2019,6 +2111,87 @@ TEST_F(OpTest, SoftplusGrad) { }); } +TEST_F(OpTest, SpaceToBatch) { + Repeatedly([this]() { + std::vector block_dims = RandomDims(4, 4, 0, 5); + const int num_block_dims = 2; + int64 block_size = RandomDim(0, 4); + + std::vector input_dims(1 + num_block_dims + 1); + input_dims[0] = RandomDim(); + for (int i = 0; i < num_block_dims; ++i) { + input_dims[1 + i] = block_dims[i] * block_size; + } + input_dims[1 + num_block_dims] = RandomDim(); + + std::vector padding_vals; + std::uniform_int_distribution distribution(0, 7); + for (int i = 0; i < num_block_dims; ++i) { + int64 pad_before; + int64 pad_after; + do { + pad_before = distribution(generator()); + pad_after = distribution(generator()); + } while (pad_before + pad_after > input_dims[1 + i]); + input_dims[1 + i] -= pad_before + pad_after; + padding_vals.push_back(pad_before); + padding_vals.push_back(pad_after); + } + Tensor paddings; + CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals), + TensorShape({num_block_dims, 2}))); + + ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch") + .Input(RandomTensor(DT_FLOAT, input_dims)) + .Input(paddings) + .Attr("T", DT_FLOAT) + .Attr("block_size", block_size)); + }); +} + +TEST_F(OpTest, SpaceToBatchND) { + Repeatedly([this]() { + std::vector block_dims = RandomDims(1, 3, 0, 5); + int num_block_dims = block_dims.size(); + std::vector remaining_dims = RandomDims(0, 3); + std::vector block_multipliers = + RandomDims(block_dims.size(), block_dims.size(), 0, 4); + + std::vector input_dims(1 + num_block_dims + remaining_dims.size()); + input_dims[0] = RandomDim(); + for (int i = 0; i < num_block_dims; ++i) { + input_dims[1 + i] = block_dims[i] * block_multipliers[i]; + } + std::copy(remaining_dims.begin(), remaining_dims.end(), + input_dims.begin() + 1 + num_block_dims); + + std::vector padding_vals; + std::uniform_int_distribution distribution(0, 7); + for (int i = 0; i < num_block_dims; ++i) { + int64 pad_before; + int64 pad_after; + do { + pad_before = distribution(generator()); + pad_after = distribution(generator()); + } while (pad_before + pad_after > input_dims[1 + i]); + input_dims[1 + i] -= pad_before + pad_after; + padding_vals.push_back(pad_before); + padding_vals.push_back(pad_after); + } + Tensor paddings; + CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals), + TensorShape({num_block_dims, 2}))); + + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("SpaceToBatchND") + .Input(RandomTensor(DT_FLOAT, input_dims)) + .Input(test::AsTensor( + std::vector(block_dims.begin(), block_dims.end()))) + .Input(paddings) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, SparseMatMul) { Repeatedly([this]() { int64 x = RandomDim(); @@ -2339,6 +2512,14 @@ TEST_F(OpTest, ZerosLike) { }); } +TEST_F(OpTest, OnesLike) { + Repeatedly([this]() { + DataType type = Choose({DT_INT32, DT_FLOAT}); + ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("OnesLike").Input(RandomTensor(type)).Attr("T", type)); + }); +} + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4ddf2ee0dcb2b5f514ff9820c07f7cc10609ff66 --- /dev/null +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -0,0 +1,145 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slicing.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + + +class SliceTest(XLATestCase): + + def test1D(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[10]) + with self.test_scope(): + o = array_ops.slice(i, [2], [4]) + params = { + i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([2, 3, 4, 5], result) + + def test3D(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[3, 3, 10]) + with self.test_scope(): + o = array_ops.slice(i, [1, 2, 2], [1, 1, 4]) + params = { + i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [5, 3, 1, 7, 9, 2, 4, 6, 8, 0]], + [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [8, 7, 6, 5, 4, 3, 2, 1, 8, 7]], + [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5], + [1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + [9, 8, 7, 9, 8, 7, 9, 8, 7, 9]]] + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[[6, 5, 4, 3]]], result) + + + +class StridedSliceTest(XLATestCase): + + def test1D(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[10]) + with self.test_scope(): + o = array_ops.strided_slice(i, [2], [6], [2]) + params = { + i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([2, 4], result) + + def test1DNegtiveStride(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[10]) + with self.test_scope(): + o = array_ops.strided_slice(i, [6], [2], [-2]) + params = { + i: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([6, 4], result) + + def test3D(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[3, 3, 10]) + with self.test_scope(): + o = array_ops.strided_slice(i, [0, 2, 2], [2, 3, 6], [1, 1, 2]) + params = { + i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [5, 3, 1, 7, 9, 2, 4, 6, 8, 0]], + [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [8, 7, 6, 5, 4, 3, 2, 1, 8, 7]], + [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5], + [1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + [9, 8, 7, 9, 8, 7, 9, 8, 7, 9]]] + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[[1, 9]], [[6, 4]]], result) + + def test3DNegativeStride(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[3, 4, 10]) + with self.test_scope(): + o = array_ops.strided_slice(i, [2, 2, 6], [0, 0, 2], [-1, -1, -2]) + params = { + i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [5, 3, 1, 7, 9, 2, 4, 6, 8, 0], + [4, 5, 2, 4, 3, 7, 6, 8, 9, 4]], + [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5], + [4, 3, 4, 5, 7, 6, 5, 3, 4, 5], + [8, 7, 6, 5, 4, 3, 2, 1, 8, 7], + [7, 1, 7, 1, 8, 1, 8, 1, 3, 1]], + [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5], + [1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + [9, 8, 7, 9, 8, 7, 9, 8, 7, 9], + [9, 9, 5, 5, 6, 6, 3, 3, 6, 6]]] + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[[9, 8], + [1, 1]], + [[2, 4], + [5, 7]]], result) + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3b86c84b2b92089da0dfc0070a4a7b8a03c81a --- /dev/null +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -0,0 +1,266 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for SpaceToBatch and BatchToSpace ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.platform import test + + +def space_to_batch_direct(input_array, block_shape, paddings): + """Direct Python implementation of space-to-batch conversion. + + This is used for tests only. + + Args: + input_array: N-D array + block_shape: 1-D array of shape [num_block_dims]. + paddings: 2-D array of shape [num_block_dims, 2]. + + Returns: + Converted tensor. + """ + input_array = np.array(input_array) + block_shape = np.array(block_shape) + num_block_dims = len(block_shape) + paddings = np.array(paddings).reshape((len(block_shape), 2)) + + padded = np.pad(input_array, + pad_width=([[0, 0]] + list(paddings) + [[0, 0]] * + (input_array.ndim - 1 - num_block_dims)), + mode="constant") + reshaped_padded_shape = [input_array.shape[0]] + output_shape = [input_array.shape[0] * np.prod(block_shape)] + for block_dim, block_shape_value in enumerate(block_shape): + reduced_size = padded.shape[block_dim + 1] // block_shape_value + reshaped_padded_shape.append(reduced_size) + output_shape.append(reduced_size) + reshaped_padded_shape.append(block_shape_value) + reshaped_padded_shape.extend(input_array.shape[num_block_dims + 1:]) + output_shape.extend(input_array.shape[num_block_dims + 1:]) + + reshaped_padded = padded.reshape(reshaped_padded_shape) + permuted_reshaped_padded = np.transpose(reshaped_padded, ( + list(np.arange(num_block_dims) * 2 + 2) + [0] + + list(np.arange(num_block_dims) * 2 + 1) + list( + np.arange(input_array.ndim - num_block_dims - 1) + 1 + num_block_dims + * 2))) + return permuted_reshaped_padded.reshape(output_shape) + + +class SpaceToBatchTest(XLATestCase): + """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops.""" + + def _testPad(self, inputs, paddings, block_size, outputs): + with self.test_session() as sess, self.test_scope(): + for dtype in self.float_types: + # outputs = space_to_batch(inputs) + placeholder = array_ops.placeholder(dtype) + x_tf = gen_array_ops._space_to_batch( + placeholder, paddings, block_size=block_size) + self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs) + # inputs = batch_to_space(outputs) + x_tf = gen_array_ops._batch_to_space( + placeholder, paddings, block_size=block_size) + self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs) + + def _testOne(self, inputs, block_size, outputs): + paddings = np.zeros((2, 2), dtype=np.int32) + self._testPad(inputs, paddings, block_size, outputs) + + # [1, 2, 2, 1] <-> [4, 1, 1, 1] + def testSmallInput2x2(self): + x_np = [[[[1], [2]], [[3], [4]]]] + block_size = 2 + x_out = [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] + self._testOne(x_np, block_size, x_out) + + # [1, 2, 2, 1] <-> [1, 3, 3, 1] (padding) <-> [9, 1, 1, 1] + def testSmallInput2x2Pad1x0(self): + x_np = [[[[1], [2]], [[3], [4]]]] + paddings = np.array([[1, 0], [1, 0]], dtype=np.int32) + block_size = 3 + x_out = [[[[0]]], [[[0]]], [[[0]]], [[[0]]], [[[1]]], [[[2]]], [[[0]]], + [[[3]]], [[[4]]]] + self._testPad(x_np, paddings, block_size, x_out) + + # Test with depth larger than 1. + # [1, 2, 2, 3] <-> [4, 1, 1, 3] + def testDepthInput2x2(self): + x_np = [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]] + block_size = 2 + x_out = [[[[1, 2, 3]]], [[[4, 5, 6]]], [[[7, 8, 9]]], [[[10, 11, 12]]]] + self._testOne(x_np, block_size, x_out) + + # Test for larger input dimensions. + # [1, 4, 4, 1] <-> [4, 2, 2, 1] + def testLargerInput2x2(self): + x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]], + [[9], [10], [11], [12]], [[13], [14], [15], [16]]]] + block_size = 2 + x_out = [[[[1], [3]], [[9], [11]]], [[[2], [4]], [[10], [12]]], + [[[5], [7]], [[13], [15]]], [[[6], [8]], [[14], [16]]]] + self._testOne(x_np, block_size, x_out) + + # Test with batch larger than 1. + # [2, 2, 4, 1] <-> [8, 1, 2, 1] + def testBatchInput2x2(self): + x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]]], + [[[9], [10], [11], [12]], [[13], [14], [15], [16]]]] + block_size = 2 + x_out = [[[[1], [3]]], [[[9], [11]]], [[[2], [4]]], [[[10], [12]]], + [[[5], [7]]], [[[13], [15]]], [[[6], [8]]], [[[14], [16]]]] + self._testOne(x_np, block_size, x_out) + + # Tests for larger input spatial dimensions AND batch larger than 1, to ensure + # that elements are correctly laid out spatially and properly interleaved + # along the batch dimension. + # [2, 4, 4, 1] <-> [8, 2, 2, 1] + def testLargerInputBatch2x2(self): + x_np = [[[[1], [2], [3], [4]], [[5], [6], [7], [8]], + [[9], [10], [11], [12]], [[13], [14], [15], [16]]], + [[[17], [18], [19], [20]], [[21], [22], [23], [24]], + [[25], [26], [27], [28]], [[29], [30], [31], [32]]]] + x_out = [[[[1], [3]], [[9], [11]]], [[[17], [19]], [[25], [27]]], + [[[2], [4]], [[10], [12]]], [[[18], [20]], [[26], [28]]], + [[[5], [7]], [[13], [15]]], [[[21], [23]], [[29], [31]]], + [[[6], [8]], [[14], [16]]], [[[22], [24]], [[30], [32]]]] + block_size = 2 + self._testOne(x_np, block_size, x_out) + + +class SpaceToBatchNDTest(XLATestCase): + """Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops.""" + + def _testPad(self, inputs, block_shape, paddings, outputs): + block_shape = np.array(block_shape) + paddings = np.array(paddings).reshape((len(block_shape), 2)) + with self.test_session() as sess, self.test_scope(): + for dtype in self.float_types: + placeholder = array_ops.placeholder(dtype) + # outputs = space_to_batch(inputs) + x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, paddings) + self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs) + # inputs = batch_to_space(outputs) + placeholder = array_ops.placeholder(dtype) + x_tf = array_ops.batch_to_space_nd(placeholder, block_shape, paddings) + self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs) + + def _testDirect(self, input_shape, block_shape, paddings): + inputs = np.arange(np.prod(input_shape), dtype=np.float32) + inputs = inputs.reshape(input_shape) + self._testPad(inputs, block_shape, paddings, + space_to_batch_direct(inputs, block_shape, paddings)) + + def testZeroBlockDimsZeroRemainingDims(self): + self._testPad( + inputs=[1, 2], + block_shape=[], + paddings=[], + outputs=[1, 2],) + + def testZeroBlockDimsOneRemainingDim(self): + self._testPad( + inputs=[[1, 2], [3, 4]], + block_shape=[], + paddings=[], + outputs=[[1, 2], [3, 4]]) + + # Same thing, but with a no-op block dim. + self._testPad( + inputs=[[1, 2], [3, 4]], + block_shape=[1], + paddings=[[0, 0]], + outputs=[[1, 2], [3, 4]]) + + def testZeroBlockDimsTwoRemainingDims(self): + self._testPad( + inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + block_shape=[], + paddings=[], + outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + + # Same thing, but with a no-op block dim. + self._testPad( + inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + block_shape=[1], + paddings=[[0, 0]], + outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + + # Same thing, but with two no-op block dims. + self._testPad( + inputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + block_shape=[1, 1], + paddings=[[0, 0], [0, 0]], + outputs=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + + def testOneBlockDimZeroRemainingDims(self): + self._testPad( + inputs=[[1, 2, 3], [4, 5, 6]], + block_shape=[2], + paddings=[1, 0], + outputs=[[0, 2], [0, 5], [1, 3], [4, 6]]) + + def testOneBlockDimOneRemainingDim(self): + self._testPad( + inputs=[[[1, 11], [2, 21], [3, 31]], [[4, 41], [5, 51], [6, 61]]], + block_shape=[2], + paddings=[1, 0], + outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]], + [[4, 41], [6, 61]]]) + + def testDirect(self): + # Test with zero-size remaining dimension. + self._testDirect( + input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]]) + + # Test with zero-size blocked dimension. + self._testDirect( + input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]]) + + # Test with padding up from zero size. + self._testDirect( + input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[1, 2]]) + + self._testDirect( + input_shape=[3, 3, 4, 5, 2], + block_shape=[3, 4, 2], + paddings=[[1, 2], [0, 0], [3, 0]]) + + self._testDirect( + input_shape=[3, 3, 4, 5, 2], + block_shape=[3, 4, 2, 2], + paddings=[[1, 2], [0, 0], [3, 0], [0, 0]]) + + self._testDirect( + input_shape=[3, 2, 2, 3, 4, 5, 2, 5], + block_shape=[1, 1, 3, 4, 2, 2], + paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0]]) + + self._testDirect( + input_shape=[3, 2, 2, 3, 4, 5, 2, 5], + block_shape=[1, 1, 3, 4, 2, 2, 1], + paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0], [0, 0]]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index c96826fd0a64b2d8fb02da22cfdc72edbb674317..51d8786ce3d7148e6863be7e1557a8bb23153d63 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -182,6 +182,11 @@ class UnaryOpsTest(XLATestCase): [0.7310586, 0.880797, 0.95257413, 0.98201376]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.sigmoid, + np.array([-300, -150, 0, 150, 300], dtype=dtype), + expected=np.array([0, 0, 0.5, 1, 1], dtype=dtype)) + self._assertOpOutputMatchesExpected( math_ops.sqrt, np.array([[4, 9]], dtype=dtype), @@ -209,6 +214,11 @@ class UnaryOpsTest(XLATestCase): [-3.4401896, -2.4401896, -1.4401897, -0.44018969]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + nn_ops.elu, + np.array([[-1, 0, 1]], dtype=dtype), + expected=np.array([[-0.63212056, 0, 1]], dtype=dtype)) + self._assertOpOutputMatchesExpected( nn_ops.relu, np.array([[-1, 1]], dtype=dtype), @@ -257,6 +267,11 @@ class UnaryOpsTest(XLATestCase): np.array([[4, 3], [2, 1]], dtype=dtype), expected=np.array([[0, 0], [0, 0]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.ones_like, + np.array([[4, 3], [2, 1]], dtype=dtype), + expected=np.array([[1, 1], [1, 1]], dtype=dtype)) + def testLogicalOps(self): self._assertOpOutputMatchesExpected( math_ops.logical_not, diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index dcb9e2db2f8ca7ef6e89cb9c6493d15dcaacd46e..70dacd9de4b95dfb77986dfaf177c16b758406f1 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for XLA JIT compiler.""" +"""Tests for reading and writing variables.""" from __future__ import absolute_import from __future__ import division @@ -21,11 +21,14 @@ from __future__ import print_function import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -36,6 +39,21 @@ from tensorflow.python.training.gradient_descent import GradientDescentOptimizer class VariableOpsTest(XLATestCase): """Test cases for resource variable operators.""" + def testOneWriteOneOutput(self): + # Regression test for a bug where computations with one non-constant + # output and one variable update were mishandled. + for dtype in self.numeric_types: + init = np.array([[1, 2], [3, 4]], dtype=dtype) + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + sess.run(variables.variables_initializer([v])) + p = array_ops.placeholder(dtype) + x = v.assign_add(p) + with ops.control_dependencies([x]): + y = v.read_value() + self.assertAllClose(np.array([[2, 3], [4, 5]], dtype=dtype), + sess.run(y, {p: 1})) + def testReadWrite(self): """Tests initialization, reading, and writing a resource variable.""" with self.test_session() as session: @@ -98,5 +116,68 @@ class VariableOpsTest(XLATestCase): self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4) +class StridedSliceAssignChecker(object): + """Compares the results of a slice assignment using Tensorflow and numpy.""" + + def __init__(self, test, x, dtype): + self.dtype = dtype + self.test = test + self.x_np = np.array(x).astype(dtype) + + def __setitem__(self, index, value): + value = np.array(value).astype(self.dtype) + + with self.test.test_session() as sess, self.test.test_scope(): + x = constant_op.constant(self.x_np, dtype=self.dtype) + var = resource_variable_ops.ResourceVariable(x) + sess.run(variables.variables_initializer([var])) + val = sess.run(var[index].assign(value)) + # val_copy is used to check that tf.assign works equivalently to the + # assign method above. + val_copy = sess.run(state_ops.assign(var[index], value)) + valnp = np.copy(self.x_np) + valnp[index] = np.array(value) + self.test.assertAllEqual(val, valnp) + self.test.assertAllEqual(val_copy, valnp) + + +class SliceAssignTest(XLATestCase): + + def testSliceAssign(self): + for dtype in self.numeric_types: + checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]], + dtype=dtype) + # No-op assignment + checker[:] = [[10, 20, 30], [40, 50, 60]] + # Checks trivial (1,1) shape tensor + checker[1:2, 1:2] = [[66]] + # shrink shape changes + checker[1:2, 1] = [66] + checker[1, 1:2] = [66] + checker[1, 1] = 66 + # newaxis shape changes + checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]] + # shrink and newaxis + checker[None, None, 0, 0:1] = [[[99]]] + # Non unit strides + checker[::1, 1::-1] = [[3, 33], [4, 44]] + # degenerate interval + checker[8:10, 0] = [] + checker[8:10, 8:10] = [[]] + + # Assign vector to scalar (rank-0) using newaxis + checker2 = StridedSliceAssignChecker(self, 222, dtype=dtype) + checker2[()] = 6 # no indices + checker2[...] = 6 # ellipsis + checker2[None] = [6] # new axis + + def testUninitialized(self): + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "uninitialized variable"): + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable([1, 2]) + sess.run(v[:].assign([1, 2])) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index 1388a892ba5a1d07c05eedf277085099923ae901..f5c228f8305d740b994dadc34c93b4e0ae32d785 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -18,15 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -48,34 +43,6 @@ class XlaDeviceTest(test.TestCase): result = sess.run(w, {x: [1.5, 0.5]}) self.assertAllClose(result, [12., 2.], rtol=1e-3) - def testLoops(self): - """Tests that loops work on XLA devices.""" - - with session_lib.Session() as session: - x = array_ops.placeholder(dtypes.float32) - with ops.device("device:XLA_CPU:0"): - c = lambda i, _: math_ops.less(i, 5) - b = lambda i, x: (i + 1, x * 2.0 + 1.0) - _, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x)) - - result = session.run(y, {x: np.float32(2)}) - self.assertAllClose(result, np.float32(95), rtol=1e-3) - - def testCond(self): - """Tests that tf.cond works on XLA devices.""" - - with session_lib.Session() as session: - x = array_ops.placeholder(dtypes.float32) - y = array_ops.placeholder(dtypes.float32) - c = array_ops.placeholder(dtypes.bool) - with ops.device("device:XLA_CPU:0"): - z = x + 1.0 - w = control_flow_ops.cond(c, lambda: z, lambda: y) - t = math_ops.add(z, w) - - result = session.run(t, {x: np.float32(2), y: np.float32(4), c: True}) - self.assertAllClose(result, np.float32(6), rtol=1e-3) - if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 7a18c1e3750afa276d6721ffea9a4d481cb37136..12537b9765469da6d906d556ff69685149e2cc32 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -45,6 +45,7 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 53aa749a0a90bf3fad06ed4bc57c4327c5d24dcc..c4cbaebb258fc19552227a51de616429e2e6221b 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -35,6 +35,9 @@ Status BackwardsConstAnalysis(const Graph& g, {"Any", "reduction_indices"}, {"ArgMax", "dimension"}, {"AvgPoolGrad", "orig_input_shape"}, + {"BatchToSpace", "crops"}, + {"BatchToSpaceND", "block_shape"}, + {"BatchToSpaceND", "crops"}, {"BroadcastGradientArgs", "s0"}, {"BroadcastGradientArgs", "s1"}, {"Concat", "concat_dim"}, @@ -65,10 +68,16 @@ Status BackwardsConstAnalysis(const Graph& g, {"Range", "limit"}, {"Range", "delta"}, {"Reshape", "shape"}, + {"ResourceStridedSliceAssign", "begin"}, + {"ResourceStridedSliceAssign", "end"}, + {"ResourceStridedSliceAssign", "strides"}, {"Reverse", "dims"}, {"ReverseV2", "axis"}, {"Slice", "begin"}, {"Slice", "size"}, + {"SpaceToBatch", "paddings"}, + {"SpaceToBatchND", "block_shape"}, + {"SpaceToBatchND", "paddings"}, {"Split", "split_dim"}, {"SplitV", "split_dim"}, {"SplitV", "size_splits"}, @@ -102,7 +111,7 @@ Status BackwardsConstAnalysis(const Graph& g, if (must_be_const.find(node) != must_be_const.end()) { if (node->type_string() == "_Arg") { int index; - status = GetNodeAttr(node->def(), "index", &index); + status = GetNodeAttr(node->attrs(), "index", &index); if (!status.ok()) return; compile_time_const_args->at(index) = true; return; @@ -118,8 +127,8 @@ Status BackwardsConstAnalysis(const Graph& g, if (range.first == range.second) return; NameRangeMap input_name_ranges; - status = NameRangesForNode(node->def(), node->op_def(), &input_name_ranges, - nullptr); + status = + NameRangesForNode(*node, node->op_def(), &input_name_ranges, nullptr); if (!status.ok()) return; for (auto it = range.first; it != range.second; ++it) { diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 2ee80a41e820b5ecc92816c84b6de9625f319b19..81b065689da4d8314c6ae9480d73745830fc31f5 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -14,18 +14,21 @@ tf_kernel_library( name = "xla_ops", srcs = [ "aggregate_ops.cc", + "arg_op.cc", "batch_matmul_op.cc", + "batchtospace_op.cc", "bcast_ops.cc", "bias_ops.cc", "binary_ops.cc", "cast_op.cc", "concat_op.cc", + "const_op.cc", "conv_ops.cc", "cwise_ops.cc", - "declaration_op.cc", "depthwise_conv_ops.cc", "diag_op.cc", "dynamic_stitch_op.cc", + "elu_op.cc", "fill_op.cc", "function_ops.cc", "identity_op.cc", @@ -49,6 +52,7 @@ tf_kernel_library( "shape_op.cc", "slice_op.cc", "softmax_op.cc", + "spacetobatch_op.cc", "split_op.cc", "strided_slice_op.cc", "tile_ops.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/declaration_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc similarity index 56% rename from tensorflow/compiler/tf2xla/kernels/declaration_op.cc rename to tensorflow/compiler/tf2xla/kernels/arg_op.cc index be2ce038016e852e48c312e26bf959ca5b9215af..d6897d6e3313414a5fd781f8a71ce143d5db2614 100644 --- a/tensorflow/compiler/tf2xla/kernels/declaration_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -23,58 +23,6 @@ limitations under the License. namespace tensorflow { namespace { -// This OpKernel implements the Constant Op for XLA JIT -// devices. It extracts the constant Tensor from the Proto at kernel -// construction time, and then every time the Constant Op is executed -// an expression containing the constant is compiled. -class ConstantDeclarationOp : public XlaOpKernel { - public: - explicit ConstantDeclarationOp(OpKernelConstruction* ctx) - : XlaOpKernel(ctx), tensor_(ctx->output_type(0)) { - const TensorProto* proto = nullptr; - OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto)); - // MakeTensorFromProto uses the cpu_allocator, so tensor_ is a - // "real" tensor backed by CPU memory, holding the value of the - // constant. - OP_REQUIRES_OK(ctx, MakeTensorFromProto(*proto, &tensor_)); - OP_REQUIRES( - ctx, ctx->output_type(0) == tensor_.dtype(), - errors::InvalidArgument( - "Type mismatch between value (", DataTypeString(tensor_.dtype()), - ") and dtype (", DataTypeString(ctx->output_type(0)), ")")); - } - - void Compile(XlaOpKernelContext* ctx) override { - ctx->SetConstantOutput(0, tensor_); - } - - private: - // Extract the value of the constant from the Proto during Op kernel - // construction. The constant must be stored in a Tensor allocated - // using the cpu_allocator so that it is backed by real memory. The - // OpKernelConstruction's default allocator is the JITAllocator - // which only allocates enough space for metadata for each Tensor. - static Status MakeTensorFromProto(const TensorProto& tensor_proto, - Tensor* tensor) { - Tensor parsed(tensor_proto.dtype()); - if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - tensor_proto.DebugString()); - } - *tensor = parsed; - return Status::OK(); - } - - // This is a "real" tensor backed by CPU memory, containing the - // constant values. - Tensor tensor_; - TF_DISALLOW_COPY_AND_ASSIGN(ConstantDeclarationOp); -}; - -// XLA_* devices also register a "real" Identity operator so we suppress the -// dummy operator using CompilationOnly(). -REGISTER_XLA_OP(Name("Const").CompilationOnly(), ConstantDeclarationOp); - // This OpKernel implements the _Arg Op for XLA JIT devices. It // associates its output with one of the arguments to a // subcomputation. diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb4bd47ee50090722801329466cc88d34cd2449b --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -0,0 +1,186 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +void BatchToSpace(XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, DataType input_dtype, + const TensorShape& input_tensor_shape, + gtl::ArraySlice block_shape, + const xla::Literal& crops) { + const int input_rank = input_tensor_shape.dims(); + const gtl::InlinedVector input_shape = + input_tensor_shape.dim_sizes(); + const int block_rank = block_shape.size(); + + OP_REQUIRES( + ctx, input_rank >= 1 + block_rank, + errors::InvalidArgument("input rank should be >= ", 1 + block_rank, + " instead of ", input_rank)); + gtl::ArraySlice remainder_shape(input_shape); + remainder_shape.remove_prefix(1 + block_rank); + + OP_REQUIRES( + ctx, + xla::ShapeUtil::Rank(crops.shape()) == 2 && + block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) && + 2 == xla::ShapeUtil::GetDimension(crops.shape(), 1), + errors::InvalidArgument("crops should have shape [", block_rank, + ", 2] instead of ", + xla::ShapeUtil::HumanString(crops.shape()))); + + xla::ComputationBuilder* b = ctx->builder(); + const int64 batch_size = input_shape[0]; + + // Compute the product of the block_shape values. + int64 block_num_elems = 1; + for (int i = 0; i < block_rank; ++i) { + block_num_elems *= block_shape[i]; + } + OP_REQUIRES(ctx, block_num_elems > 0, + errors::InvalidArgument( + "The product of the block dimensions must be positive")); + + // 1. Reshape `input` to `reshaped` of shape: + // [block_shape[0], ..., block_shape[M-1], + // batch / prod(block_shape), + // input_shape[1], ..., input_shape[N-1]] + + OP_REQUIRES( + ctx, batch_size % block_num_elems == 0, + errors::InvalidArgument("Input batch dimension (", batch_size, + ") is not divisible by product of block sizes (", + block_num_elems, ")")); + std::vector reshaped_shape(input_rank + block_rank); + std::copy(block_shape.begin(), block_shape.end(), reshaped_shape.begin()); + reshaped_shape[block_rank] = batch_size / block_num_elems; + std::copy(input_shape.begin() + 1, input_shape.end(), + reshaped_shape.begin() + block_rank + 1); + xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); + + // 2. Permute dimensions of `reshaped` to produce `permuted` of shape + // [batch / prod(block_shape), + // + // input_shape[1], block_shape[0], + // ..., + // input_shape[M], block_shape[M-1], + // + // input_shape[M+1], ..., input_shape[N-1]] + std::vector permutation(reshaped_shape.size()); + permutation[0] = block_rank; + for (int i = 0; i < block_rank; ++i) { + permutation[1 + 2 * i] = block_rank + 1 + i; + permutation[1 + 2 * i + 1] = i; + } + std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), + 1 + block_rank * 2); + xla::ComputationDataHandle permuted = b->Transpose(reshaped, permutation); + + // 3. Reshape `permuted` to produce `reshaped_permuted` of shape + // [batch / prod(block_shape), + // + // input_shape[1] * block_shape[0], + // ..., + // input_shape[M] * block_shape[M-1], + // + // input_shape[M+1], + // ..., + // input_shape[N-1]] + std::vector reshaped_permuted_shape(input_rank); + reshaped_permuted_shape[0] = batch_size / block_num_elems; + for (int i = 0; i < block_rank; ++i) { + reshaped_permuted_shape[1 + i] = block_shape[i] * input_shape[1 + i]; + } + std::copy(remainder_shape.begin(), remainder_shape.end(), + reshaped_permuted_shape.begin() + 1 + block_rank); + + xla::ComputationDataHandle reshaped_permuted = + b->Reshape(permuted, reshaped_permuted_shape); + + // 4. Crop the start and end of dimensions `[1, ..., M]` of + // `reshaped_permuted` according to `crops` to produce the output of shape: + // [batch / prod(block_shape), + // + // input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], + // ..., + // input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], + // + // input_shape[M+1], ..., input_shape[N-1]] + std::vector start_indices(input_rank, 0); + std::vector end_indices = reshaped_permuted_shape; + for (int i = 0; i < block_rank; ++i) { + int64 crop_start = xla::LiteralUtil::Get(crops, {i, 0}); + int64 crop_end = xla::LiteralUtil::Get(crops, {i, 1}); + OP_REQUIRES(ctx, crop_start >= 0 && crop_end >= 0, + errors::InvalidArgument("Crops must be non-negative")); + start_indices[1 + i] = crop_start; + end_indices[1 + i] -= crop_end; + OP_REQUIRES( + ctx, start_indices[1 + i] <= end_indices[1 + i], + errors::InvalidArgument( + "Cropped size must be non-negative: start: ", crop_start, + " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i])); + } + xla::ComputationDataHandle output = + b->Slice(reshaped_permuted, start_indices, end_indices); + ctx->SetOutput(0, output); +} + +class BatchToSpaceNDOp : public XlaOpKernel { + public: + explicit BatchToSpaceNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + std::vector block_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape)); + + xla::Literal crops; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &crops)); + + BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0), + block_shape, crops); + } +}; +REGISTER_XLA_OP(Name("BatchToSpaceND"), BatchToSpaceNDOp); + +class BatchToSpaceOp : public XlaOpKernel { + public: + explicit BatchToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); + OP_REQUIRES( + ctx, block_size_ > 1, + errors::InvalidArgument("Block size should be > 1: ", block_size_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::Literal crops; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &crops)); + + BatchToSpace(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0), + {block_size_, block_size_}, crops); + } + + private: + int block_size_; +}; +REGISTER_XLA_OP(Name("BatchToSpace"), BatchToSpaceOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ad676e7a2bb3d3f28ecb98164323cbf1e32f61a9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -0,0 +1,111 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +class ConstOp : public XlaOpKernel { + public: + explicit ConstOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + const TensorProto* proto = nullptr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto)); + proto_ = *proto; + OP_REQUIRES( + ctx, ctx->output_type(0) == proto_.dtype(), + errors::InvalidArgument("Type mismatch between value (", + DataTypeString(proto_.dtype()), ") and dtype (", + DataTypeString(ctx->output_type(0)), ")")); + OP_REQUIRES_OK(ctx, TensorShape::IsValidShape(proto_.tensor_shape())); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape(proto_.tensor_shape()); + + xla::ComputationBuilder* b = ctx->builder(); + + // To avoid blowups for large constants filled with the same value, + // recognize that case and emit a scalar broadcast instead. + if (shape.num_elements() > 1) { + switch (proto_.dtype()) { + case DT_BOOL: + if (proto_.bool_val_size() == 1) { + ctx->SetOutput(0, + b->Broadcast(b->ConstantR0(proto_.bool_val(0)), + shape.dim_sizes())); + return; + } + break; + case DT_FLOAT: + if (proto_.float_val_size() == 1) { + ctx->SetOutput( + 0, b->Broadcast(b->ConstantR0(proto_.float_val(0)), + shape.dim_sizes())); + return; + } + break; + case DT_DOUBLE: + if (proto_.double_val_size() == 1) { + ctx->SetOutput( + 0, b->Broadcast(b->ConstantR0(proto_.double_val(0)), + shape.dim_sizes())); + return; + } + break; + case DT_INT32: + if (proto_.int_val_size() == 1) { + ctx->SetOutput(0, + b->Broadcast(b->ConstantR0(proto_.int_val(0)), + shape.dim_sizes())); + return; + } + break; + case DT_INT64: + if (proto_.int64_val_size() == 1) { + ctx->SetOutput( + 0, b->Broadcast(b->ConstantR0(proto_.int64_val(0)), + shape.dim_sizes())); + return; + } + break; + default: + break; + } + } + + // General case + Tensor tensor(proto_.dtype()); + OP_REQUIRES(ctx, tensor.FromProto(cpu_allocator(), proto_), + errors::InvalidArgument("Cannot parse tensor from proto: ", + proto_.DebugString())); + ctx->SetConstantOutput(0, tensor); + } + + private: + TensorProto proto_; + TF_DISALLOW_COPY_AND_ASSIGN(ConstOp); +}; + +// XLA_* devices also register a "real" Const operator so we suppress the +// dummy operator using CompilationOnly(). +REGISTER_XLA_OP(Name("Const").CompilationOnly(), ConstOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..62a5e1bd421a75fb0a8fa6eacd58e4aaa2f02236 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Native XLA implementations of XLA Elu Ops + +#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/no_op.h" + +namespace tensorflow { +namespace { + +class EluOp : public XlaOpKernel { + public: + explicit EluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Computes the max of the scalar input x and 0. + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + const auto zero = XlaHelpers::Zero(b, input_type(0)); + const auto one = XlaHelpers::One(b, input_type(0)); + const auto pred = b->Gt(ctx->Input(0), zero); + const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1)); + } +}; + +class EluGradOp : public XlaOpKernel { + public: + explicit EluGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + // Return the lhs (incoming gradient) if the rhs (input feature) > 0, + // otherwise return lhs * (1 + rhs). + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + const auto zero = XlaHelpers::Zero(b, input_type(0)); + const auto one = XlaHelpers::One(b, input_type(0)); + const auto grad = ctx->Input(0); + const auto activation = ctx->Input(1); + const auto exp_grad = b->Mul(grad, b->Add(activation, one)); + const auto pred = b->Gt(activation, zero); + ctx->SetOutput(0, b->Select(pred, grad, exp_grad)); + } +}; + +REGISTER_XLA_OP(Name("Elu"), EluOp); +REGISTER_XLA_OP(Name("EluGrad"), EluGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/function_ops.cc b/tensorflow/compiler/tf2xla/kernels/function_ops.cc index d718f98545f66cb79a77d758a3fb7ee486d87b4b..8dacb6627bde516c92cb07b747207adbe85ada5b 100644 --- a/tensorflow/compiler/tf2xla/kernels/function_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/function_ops.cc @@ -68,7 +68,8 @@ class SymbolicGradientOp : public AsyncOpKernel { done); OP_REQUIRES_OK_ASYNC( - ctx, lib->Instantiate(kGradientOp, def().attr(), &handle_), done); + ctx, lib->Instantiate(kGradientOp, AttrSlice(&def().attr()), &handle_), + done); FunctionLibraryRuntime::Options opts; opts.step_id = ctx->step_id(); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc index eff23bd77d23afc882c67f8168270d1cb4413977..691a0b972d5c09ad632d706d72a1b60988730986 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/gather_functor.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -63,7 +64,6 @@ EIGEN_STRONG_INLINE void gather_float_int32_xla_impl(float* out, void** data) { // Implements gather on CPU. This is called by an XLA custom call, set up by // gather_op.cc. -extern "C" void __attribute__((visibility("default"))) -gather_float_int32_xla_impl(float* out, void** data) { +extern "C" void TF_EXPORT gather_float_int32_xla_impl(float* out, void** data) { tensorflow::gather_float_int32_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc index ae31f6f2006959c03941a1eb04b31aecf52424b0..3dff6e2737bf1af7f5d646928e740fa895692a03 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/gather_functor.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -63,7 +64,6 @@ EIGEN_STRONG_INLINE void gather_float_int64_xla_impl(float* out, void** data) { // Implements gather on CPU. This is called by an XLA custom call, set up by // gather_op.cc. -extern "C" void __attribute__((visibility("default"))) -gather_float_int64_xla_impl(float* out, void** data) { +extern "C" void TF_EXPORT gather_float_int64_xla_impl(float* out, void** data) { tensorflow::gather_float_int64_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc index 0033a949a372684caadce70bf46a996a942e9ec4..afbd64ca5038378d48744d6d773e0dfb1376e1f9 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc @@ -18,6 +18,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -43,7 +44,6 @@ EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) { // Implements argmax on CPU. This is called by an XLA custom call, set up by // index_ops.cc. -extern "C" void __attribute__((visibility("default"))) -argmax_float_1d_xla_impl(void* out, void** data) { +extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) { tensorflow::argmax_float_1d_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc index be8ad2317c9ba6a39f839c4a535440fb94365aa9..841ff2f4df79fdd790ee3aace9e38aaeb01a3080 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc @@ -18,6 +18,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -45,7 +46,6 @@ EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) { // Implements argmax on CPU. This is called by an XLA custom call, set up by // index_ops.cc. -extern "C" void __attribute__((visibility("default"))) -argmax_float_2d_xla_impl(void* out, void** data) { +extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) { tensorflow::argmax_float_2d_xla_impl(out, data); } diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 74e3297dc3340d9e98e149065a738c3d2e73cf45..24a99f253d6dc8bb699fff587c363b12c227e821 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -241,5 +241,19 @@ class ZerosLikeOp : public XlaOpKernel { REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp); +class OnesLikeOp : public XlaOpKernel { + public: + explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + + auto one = XlaHelpers::One(ctx->builder(), input_type(0)); + ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes())); + } +}; + +REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f15b354cb26d390352d866a8e827970f7c8b0c7f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -0,0 +1,190 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +void SpaceToBatch(XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, DataType input_dtype, + const TensorShape& input_tensor_shape, + gtl::ArraySlice block_shape, + const xla::Literal& paddings) { + const int input_rank = input_tensor_shape.dims(); + const gtl::InlinedVector input_shape = + input_tensor_shape.dim_sizes(); + const int block_rank = block_shape.size(); + + OP_REQUIRES( + ctx, input_rank >= 1 + block_rank, + errors::InvalidArgument("input rank should be >= ", 1 + block_rank, + " instead of ", input_rank)); + gtl::ArraySlice remainder_shape(input_shape); + remainder_shape.remove_prefix(1 + block_rank); + + OP_REQUIRES( + ctx, + xla::ShapeUtil::Rank(paddings.shape()) == 2 && + block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) && + 2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1), + errors::InvalidArgument("paddings should have shape [", block_rank, + ", 2] instead of ", + xla::ShapeUtil::HumanString(paddings.shape()))); + + xla::ComputationBuilder* b = ctx->builder(); + + // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the + // input according to `paddings` to produce `padded` of shape `padded_shape`. + xla::PaddingConfig padding_config; + std::vector padded_shape(input_shape.begin(), input_shape.end()); + int64 block_num_elems = 1LL; + padding_config.add_dimensions(); // Don't pad the batch dimension. + for (int i = 0; i < block_rank; ++i) { + auto* dim = padding_config.add_dimensions(); + int64 pad_start = xla::LiteralUtil::Get(paddings, {i, 0}); + int64 pad_end = xla::LiteralUtil::Get(paddings, {i, 1}); + OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0, + errors::InvalidArgument("Paddings must be non-negative")); + dim->set_edge_padding_low(pad_start); + dim->set_edge_padding_high(pad_end); + padded_shape[1 + i] += pad_start + pad_end; + block_num_elems *= block_shape[i]; + } + // Don't pad the remainder dimensions. + for (int i = 0; i < remainder_shape.size(); ++i) { + padding_config.add_dimensions(); + } + OP_REQUIRES(ctx, block_num_elems > 0, + errors::InvalidArgument( + "The product of the block dimensions must be positive")); + + xla::ComputationDataHandle padded = + b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config); + + // 2. Reshape `padded` to `reshaped_padded` of shape: + // + // [batch] + + // [padded_shape[1] / block_shape[0], + // block_shape[0], + // ..., + // padded_shape[M] / block_shape[M-1], + // block_shape[M-1]] + + // remaining_shape + const int64 batch_size = input_shape[0]; + std::vector reshaped_padded_shape(input_rank + block_rank); + reshaped_padded_shape[0] = batch_size; + for (int i = 0; i < block_rank; ++i) { + OP_REQUIRES(ctx, padded_shape[1 + i] % block_shape[i] == 0, + errors::InvalidArgument("padded_shape[", 1 + i, + "]=", padded_shape[1 + i], + " is not divisible by block_shape[", i, + "]=", block_shape[i])); + + reshaped_padded_shape[1 + i * 2] = padded_shape[1 + i] / block_shape[i]; + reshaped_padded_shape[1 + i * 2 + 1] = block_shape[i]; + } + std::copy(remainder_shape.begin(), remainder_shape.end(), + reshaped_padded_shape.begin() + 1 + 2 * block_rank); + + xla::ComputationDataHandle reshaped_padded = + b->Reshape(padded, reshaped_padded_shape); + + // 3. Permute dimensions of `reshaped_padded` to produce + // `permuted_reshaped_padded` of shape: + // + // block_shape + + // [batch] + + // [padded_shape[1] / block_shape[0], + // ..., + // padded_shape[M] / block_shape[M-1]] + + // remaining_shape + std::vector permutation(reshaped_padded_shape.size()); + for (int i = 0; i < block_rank; ++i) { + permutation[i] = 1 + 2 * i + 1; + permutation[block_rank + 1 + i] = 1 + 2 * i; + } + permutation[block_rank] = 0; + std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), + 1 + block_rank * 2); + xla::ComputationDataHandle permuted_reshaped_padded = + b->Transpose(reshaped_padded, permutation); + + // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the + // batch dimension, producing an output tensor of shape: + // + // [batch * prod(block_shape)] + + // [padded_shape[1] / block_shape[0], + // ..., + // padded_shape[M] / block_shape[M-1]] + + // remaining_shape + // Determine the length of the prefix of block dims that can be combined + // into the batch dimension due to having no padding and block_shape=1. + std::vector output_shape(input_rank); + output_shape[0] = batch_size * block_num_elems; + for (int i = 0; i < block_rank; ++i) { + output_shape[1 + i] = padded_shape[1 + i] / block_shape[i]; + } + std::copy(remainder_shape.begin(), remainder_shape.end(), + output_shape.begin() + 1 + block_rank); + + xla::ComputationDataHandle output = + b->Reshape(permuted_reshaped_padded, output_shape); + ctx->SetOutput(0, output); +} + +class SpaceToBatchNDOp : public XlaOpKernel { + public: + explicit SpaceToBatchNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + std::vector block_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape)); + + xla::Literal paddings; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &paddings)); + + SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0), + block_shape, paddings); + } +}; +REGISTER_XLA_OP(Name("SpaceToBatchND"), SpaceToBatchNDOp); + +class SpaceToBatchOp : public XlaOpKernel { + public: + explicit SpaceToBatchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); + OP_REQUIRES( + ctx, block_size_ > 1, + errors::InvalidArgument("Block size should be > 1: ", block_size_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::Literal paddings; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &paddings)); + + SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0), + {block_size_, block_size_}, paddings); + } + + private: + int block_size_; +}; +REGISTER_XLA_OP(Name("SpaceToBatch"), SpaceToBatchOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 03e02299e33a4e2bf62e757b2092db35288b0bea..a6cac62ca4bcb7e2d1c722862208f673d0a2c86f 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -77,11 +77,9 @@ class StridedSliceOp : public XlaOpKernel { gtl::InlinedVector dimensions_to_reverse; gtl::InlinedVector slice_begin, slice_end; + bool simple_strides = true; for (int i = 0; i < begin.size(); ++i) { - // TODO(phawkins): implement strides != 1 when b/30878775 is fixed. - OP_REQUIRES( - ctx, strides[i] == 1 || strides[i] == -1, - errors::Unimplemented("Strides != 1 or -1 are not yet implemented")); + simple_strides &= (std::abs(strides[i]) == 1); if (strides[i] > 0) { slice_begin.push_back(begin[i]); slice_end.push_back(end[i]); @@ -99,6 +97,35 @@ class StridedSliceOp : public XlaOpKernel { slice = ctx->builder()->Rev(slice, dimensions_to_reverse); } + // If at least one of the strides is > 1 (or < -1) then use Slice + // to pull out each of the strided slices, and Concat to put them + // together again. + if (!simple_strides) { + // Re-adjust the begin and end now that the periphery has been + // sliced away. + for (int d = 0; d < strides.size(); ++d) { + slice_end[d] -= slice_begin[d]; + slice_begin[d] = 0; + } + + for (int d = 0; d < strides.size(); ++d) { + int64 stride = std::abs(strides[d]); + if (stride > 1) { + std::vector to_concat; + int64 end = slice_end[d]; + for (int64 i = 0; i < end; i += stride) { + slice_begin[d] = i; + slice_end[d] = i + 1; + to_concat.push_back( + ctx->builder()->Slice(slice, slice_begin, slice_end)); + } + slice = ctx->builder()->ConcatInDim(to_concat, d); + slice_begin[d] = 0; + slice_end[d] = to_concat.size(); + } + } + } + slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes()); ctx->SetOutput(0, slice); } @@ -219,5 +246,118 @@ class StridedSliceGradOp : public XlaOpKernel { REGISTER_XLA_OP(Name("StridedSliceGrad"), StridedSliceGradOp); +class StridedSliceAssignOp : public XlaOpKernel { + public: + explicit StridedSliceAssignOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape final_shape; + gtl::InlinedVector begin; + gtl::InlinedVector end; + gtl::InlinedVector strides; + + xla::Literal begin_literal, end_literal, strides_literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal)); + OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal)); + + Tensor begin_tensor, end_tensor, strides_tensor; + OP_REQUIRES_OK( + ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor)); + OP_REQUIRES_OK(ctx, + LiteralToHostTensor(end_literal, index_type_, &end_tensor)); + OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, + &strides_tensor)); + + DataType lhs_type; + TensorShape lhs_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape)); + + const TensorShape rhs_shape = ctx->InputShape(4); + + TensorShape dummy_processing_shape; + ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape); + ShapeReadWriteFromTensorShape wrapped_dummy_processing_shape( + &dummy_processing_shape); + bool dummy = false; + OP_REQUIRES_OK( + ctx, ValidateStridedSliceOp( + &begin_tensor, &end_tensor, strides_tensor, + ShapeReadWriteFromTensorShape(&lhs_shape), begin_mask_, + end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, + &wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy, + &dummy, &dummy, &begin, &end, &strides)); + + if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) { + // DynamicUpdateSlice does not allow 0-element updates. We should probably + // check that rhs_shape can be broadcast to final_shape, but that is + // probably better handled when implementing broadcasting more generally. + return; + } + + // TODO(aselle): This check is too strong, we only should need + // input_shape to be broadcastable to final_shape + OP_REQUIRES(ctx, final_shape == rhs_shape, + errors::Unimplemented( + "sliced l-value shape ", final_shape.DebugString(), + " does not match r-value shape ", rhs_shape.DebugString(), + ". Automatic broadcasting not yet implemented.")); + + xla::ComputationDataHandle lhs; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs)); + + xla::ComputationDataHandle rhs = ctx->Input(4); + + gtl::InlinedVector dimensions_to_reverse; + gtl::InlinedVector slice_begin, slice_dims; + for (int i = 0; i < begin.size(); ++i) { + // TODO(phawkins): implement strides != 1 + OP_REQUIRES( + ctx, strides[i] == 1 || strides[i] == -1, + errors::Unimplemented("Strides != 1 or -1 are not yet implemented")); + if (strides[i] > 0) { + slice_begin.push_back(begin[i]); + slice_dims.push_back(end[i] - begin[i]); + } else { + // Negative stride: swap begin and end, add 1 because the interval + // is semi-open, and mark the dimension to be reversed. + slice_begin.push_back(end[i] + 1); + slice_dims.push_back(begin[i] - end[i]); + dimensions_to_reverse.push_back(i); + } + } + + if (!dimensions_to_reverse.empty()) { + rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse); + } + rhs = ctx->builder()->Reshape(rhs, slice_dims); + + if (lhs_shape.dims() == 0) { + // TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix + // and remove this workaround. + lhs = rhs; + } else { + lhs = ctx->builder()->DynamicUpdateSlice( + lhs, rhs, ctx->builder()->ConstantR1(slice_begin)); + } + + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs)); + } + + private: + int32 begin_mask_, end_mask_; + int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; + DataType index_type_; +}; + +REGISTER_XLA_OP(Name("ResourceStridedSliceAssign"), StridedSliceAssignOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index f1d81f871423b220c6859c1dedf79b1c36a43e65..ddd81cb490cd76065735a5b7e78d04fd76c05f82 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -165,6 +165,106 @@ class ResourceApplyAdagrad : public XlaOpKernel { }; REGISTER_XLA_OP(Name("ResourceApplyAdagrad"), ResourceApplyAdagrad); +class ResourceApplyAdam : public XlaOpKernel { + public: + explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType var_type, m_type, v_type; + TensorShape var_shape, m_shape, v_shape; + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &m_type, &m_shape)); + OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &v_type, &v_shape)); + + OP_REQUIRES( + ctx, dtype_ == var_type && dtype_ == m_type && dtype_ == v_type, + errors::InvalidArgument( + "Types of variable arguments to ResourceApplyRMSProp must match: ", + DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " vs. ", + DataTypeString(m_type), " vs. ", DataTypeString(v_type))); + + TensorShape beta1_power_shape = ctx->InputShape(3); + TensorShape beta2_power_shape = ctx->InputShape(4); + TensorShape lr_shape = ctx->InputShape(5); + TensorShape beta1_shape = ctx->InputShape(6); + TensorShape beta2_shape = ctx->InputShape(7); + TensorShape epsilon_shape = ctx->InputShape(8); + TensorShape grad_shape = ctx->InputShape(9); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape), + errors::InvalidArgument("beta1_power is not a scalar: ", + beta1_power_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power_shape), + errors::InvalidArgument("beta2_power is not a scalar: ", + beta2_power_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar : ", + lr_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape), + errors::InvalidArgument("beta1 is not a scalar: ", + beta1_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape), + errors::InvalidArgument("beta2 is not a scalar: ", + beta2_shape.DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon_shape.DebugString())); + + OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape), + errors::InvalidArgument("var and m do not have the same shape", + var_shape.DebugString(), " ", + m_shape.DebugString())); + OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape), + errors::InvalidArgument("var and v do not have the same shape", + var_shape.DebugString(), " ", + v_shape.DebugString())); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + xla::ComputationDataHandle var, m, v; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &m)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &v)); + xla::ComputationDataHandle beta1_power = ctx->Input(3); + xla::ComputationDataHandle beta2_power = ctx->Input(4); + xla::ComputationDataHandle lr = ctx->Input(5); + xla::ComputationDataHandle beta1 = ctx->Input(6); + xla::ComputationDataHandle beta2 = ctx->Input(7); + xla::ComputationDataHandle epsilon = ctx->Input(8); + xla::ComputationDataHandle grad = ctx->Input(9); + + // alpha <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) + // m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t + // v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t + // variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon) + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); + xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); + xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); + + xla::ComputationDataHandle alpha = + b->Div(b->Mul(lr, b->Pow(b->Sub(one, beta2_power), half)), + b->Sub(one, beta1_power)); + m = b->Add(m, b->Mul(b->Sub(grad, m), b->Sub(one, beta1))); + v = b->Add(v, b->Mul(b->Sub(b->Pow(grad, two), v), b->Sub(one, beta2))); + var = + b->Sub(var, b->Div(b->Mul(m, alpha), b->Add(b->Pow(v, half), epsilon))); + + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v)); + } + + private: + DataType dtype_; +}; +REGISTER_XLA_OP(Name("ResourceApplyAdam"), ResourceApplyAdam); + class ResourceApplyRMSProp : public XlaOpKernel { public: explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} diff --git a/tensorflow/compiler/tf2xla/str_util.cc b/tensorflow/compiler/tf2xla/str_util.cc index ce25d631271b54a36078cd0d3ac4d318d58db9fa..2b0834fe7b6c4d2199267dbe0ec1f7c2785aa9c7 100644 --- a/tensorflow/compiler/tf2xla/str_util.cc +++ b/tensorflow/compiler/tf2xla/str_util.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { namespace str_util { -void ReplaceAll(string* text, StringPiece from, StringPiece to) { +static void ReplaceAll(string* text, StringPiece from, StringPiece to) { size_t pos = 0; while ((pos = text->find(from.data(), pos, from.size())) != string::npos) { text->replace(pos, from.size(), to.data(), to.size()); diff --git a/tensorflow/compiler/tf2xla/str_util.h b/tensorflow/compiler/tf2xla/str_util.h index 4920b1a4d4875192d6f06988b810ad388bc6293b..51f25009d7003db0d72296619a469ecbbbb1808d 100644 --- a/tensorflow/compiler/tf2xla/str_util.h +++ b/tensorflow/compiler/tf2xla/str_util.h @@ -29,10 +29,6 @@ limitations under the License. namespace tensorflow { namespace str_util { -// Replace all non-overlapping occurrences of from with to in-place in text. If -// from is empty, it matches at the beginning of the text and after every byte. -void ReplaceAll(string* text, StringPiece from, StringPiece to); - // Replace all non-overlapping occurrences of the given (from,to) pairs in-place // in text. If from is empty, it matches at the beginning of the text and after // every byte. Each (from,to) replacement pair is processed in the order it is diff --git a/tensorflow/compiler/tf2xla/str_util_test.cc b/tensorflow/compiler/tf2xla/str_util_test.cc index f992007a34532157f86c90c717a5e24c3923f22d..8817f6902a8e58e796ca5240a9a24d7506d38793 100644 --- a/tensorflow/compiler/tf2xla/str_util_test.cc +++ b/tensorflow/compiler/tf2xla/str_util_test.cc @@ -25,36 +25,6 @@ limitations under the License. namespace tensorflow { namespace str_util { -class ReplaceAllTest : public ::testing::Test { - protected: - void ExpectReplaceAll(string text, StringPiece from, StringPiece to, - StringPiece want) { - ReplaceAll(&text, from, to); - EXPECT_EQ(text, want); - } -}; - -TEST_F(ReplaceAllTest, Simple) { - ExpectReplaceAll("", "", "", ""); - ExpectReplaceAll("", "", "X", "X"); - ExpectReplaceAll("", "", "XYZ", "XYZ"); - ExpectReplaceAll("banana", "", "", "banana"); - ExpectReplaceAll("banana", "", "_", "_b_a_n_a_n_a_"); - ExpectReplaceAll("banana", "", "__", "__b__a__n__a__n__a__"); - ExpectReplaceAll("banana", "a", "a", "banana"); - ExpectReplaceAll("banana", "a", "", "bnn"); - ExpectReplaceAll("banana", "a", "X", "bXnXnX"); - ExpectReplaceAll("banana", "a", "XX", "bXXnXXnXX"); - ExpectReplaceAll("banana", "an", "an", "banana"); - ExpectReplaceAll("banana", "an", "", "ba"); - ExpectReplaceAll("banana", "an", "X", "bXXa"); - ExpectReplaceAll("banana", "an", "XY", "bXYXYa"); - ExpectReplaceAll("banana", "an", "XYZ", "bXYZXYZa"); - ExpectReplaceAll("foo {{bar}} baz {{bar}}", "{{bar}}", "X", "foo X baz X"); - ExpectReplaceAll("foo {{bar}} baz {{bar}}", "{{bar}}", "ABCDEFGHIJKLMNOP", - "foo ABCDEFGHIJKLMNOP baz ABCDEFGHIJKLMNOP"); -} - class ReplaceAllPairsTest : public ::testing::Test { protected: void ExpectReplaceAllPairs( diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index d86e741b69e08652bac2dd7b5295c8ab2d94433a..362a1018955f9b6adbdea5ba718b81e9a2389957 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -76,8 +76,7 @@ XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options, options, Device::BuildDeviceAttributes( "", type, Bytes(256 << 20), DeviceLocality(), - strings::StrCat("device: XLA compilation device ", type.type())), - cpu_allocator()), + strings::StrCat("device: XLA compilation device ", type.type()))), allocator_(new XlaCompilationAllocator()) {} XlaCompilationDevice::~XlaCompilationDevice() {} diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index ba975d617dcd52de74830b3e69446c752fce1fcb..f8a9c5e9bc6f9ce778594209c9f974328cdb4b8f 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -57,11 +57,38 @@ Status CheckSignature(const DataTypeVector& types, } // namespace +bool XlaCompiler::Argument::operator==( + const XlaCompiler::Argument& other) const { + if (std::tie(kind, type, shape, name) != + std::tie(other.kind, other.type, other.shape, other.name)) { + return false; + } + if (constant_value.shape() != other.constant_value.shape()) { + return false; + } + return constant_value.tensor_data() == other.constant_value.tensor_data(); +} + XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(std::move(options)), + initialization_status_(Status::OK()), next_step_id_(1), - device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), - device_mgr_({device_}) {} + device_( + new XlaCompilationDevice(SessionOptions(), *options_.device_type)), + device_mgr_({device_}) { + // We no longer need the device_type. + options_.device_type = nullptr; + + if (options_.populate_resource_manager) { + initialization_status_ = + (*options_.populate_resource_manager)(device_->resource_manager()); + } + + flib_runtime_.reset(NewFunctionLibraryRuntime( + &device_mgr_, Env::Default(), device_, options.graph_def_version, + options.flib_def, OptimizerOptions(), + nullptr /* custom_kernel_creator */)); +} XlaCompiler::~XlaCompiler() = default; @@ -70,37 +97,35 @@ int64 XlaCompiler::NextStepId() { return next_step_id_++; } -// Prunes any nodes from a function that are not dependencies of the _Retval -// nodes. Used to prune stateful ops from within a function body, such as -// variable initializers, that should not be executed unless requested. -static void PruneUnreachableNodes(Graph* graph) { - std::unordered_set nodes; - for (Node* node : graph->nodes()) { - if (node->type_string() == "_Retval" || - StringPiece(node->type_string()).ends_with("Send")) { - nodes.insert(node); - } - } - PruneForReverseReachability(graph, nodes); +uint64 XlaCompiler::SignatureHash::operator()( + const std::pair>& signature) const { + return std::hash()(signature.first); } Status XlaCompiler::CompileFunction( - FunctionLibraryRuntime* flr, const NameAttrList& function, + const XlaCompiler::CompileOptions& options, const NameAttrList& function, const std::vector& args, XlaCompiler::CompilationResult* result) { - const string function_id = Canonicalize(function.name(), function.attr()); + const string function_id = + Canonicalize(function.name(), AttrSlice(&function.attr())); VLOG(1) << "XlaCompiler::CompileFunction " << function_id; + auto it = cache_.find({function_id, args}); + if (it != cache_.end()) { + *result = it->second; + return Status::OK(); + } + FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR( - flr->Instantiate(function.name(), function.attr(), &handle)); + TF_RETURN_IF_ERROR(flib_runtime_->Instantiate( + function.name(), AttrSlice(&function.attr()), &handle)); - const FunctionBody* fbody = flr->GetFunctionBody(handle); + const FunctionBody* fbody = flib_runtime_->GetFunctionBody(handle); CHECK(fbody); TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, args)); - std::unique_ptr graph(new Graph(flr->GetFunctionLibraryDefinition())); + std::unique_ptr graph(new Graph(options_.flib_def)); CopyGraph(*fbody->graph, graph.get()); if (VLOG_IS_ON(1)) { @@ -109,11 +134,13 @@ Status XlaCompiler::CompileFunction( } // Optimize the graph before running the compiler. - // TODO(pbar): The constant folder currently does not simplify int32 - // operations for devices other than CPU. OptimizerOptions opts; + opts.set_do_common_subexpression_elimination(true); + opts.set_do_function_inlining(true); + opts.set_do_constant_folding(true); GraphOptimizer optimizer(opts); - OptimizeGraph(flr, &graph); + optimizer.Optimize(flib_runtime_.get(), flib_runtime_->env(), + /*device=*/nullptr, &graph); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile( @@ -123,9 +150,10 @@ Status XlaCompiler::CompileFunction( VLOG(1) << "===================================================="; TF_RETURN_IF_ERROR( - CompileGraph(function_id, std::move(graph), flr, args, result)); + CompileGraph(options, function_id, std::move(graph), args, result)); VLOG(1) << "===================================================="; + cache_[{function_id, args}] = *result; return Status::OK(); } @@ -152,7 +180,7 @@ Status XlaCompiler::BuildExecutable( build_options.set_has_hybrid_result( options_.local_executable_has_hybrid_result); - auto compile_result = local_client->Compile(result.computation, + auto compile_result = local_client->Compile(*result.computation, argument_layouts, build_options); if (!compile_result.ok()) { return compile_result.status(); @@ -372,44 +400,45 @@ Status BuildComputation( } // namespace -Status XlaCompiler::CompileGraph(string const& name, +Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, + string const& name, std::unique_ptr graph, - FunctionLibraryRuntime* flib, const std::vector& args, CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate ComputationBuilder."; + // Report the error here if initialization failed. + TF_RETURN_IF_ERROR(initialization_status_); + xla::ComputationBuilder builder(client(), name); XlaContext* context = new XlaContext(this, &builder, options_.allow_cpu_custom_calls, options_.resolve_compile_time_constants); core::ScopedUnref context_unref(context); - result->tuple_arg = options_.use_tuple_arg; + result->tuple_arg = options.use_tuple_arg; std::vector context_args; - TF_RETURN_IF_ERROR(BuildArguments(args, options_.use_tuple_arg, &builder, + TF_RETURN_IF_ERROR(BuildArguments(args, options.use_tuple_arg, &builder, &context_args, &result->input_mapping, &result->xla_input_shapes)); context->set_args(std::move(context_args)); - if (options_.prune_unreachable_nodes) { - PruneUnreachableNodes(graph.get()); - } - - TF_RETURN_IF_ERROR( - ExecuteGraph(context, std::move(graph), device_, flib, NextStepId())); + TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_, + flib_runtime_.get(), NextStepId())); int num_nonconst_outputs; + result->computation = std::make_shared(); TF_RETURN_IF_ERROR(BuildComputation( context->retvals(), context->variables(), context->has_side_effects(), - options_.return_updated_values_for_all_variables, &builder, - &result->computation, &num_nonconst_outputs, &result->variable_updates)); + options.return_updated_values_for_all_variables, &builder, + result->computation.get(), &num_nonconst_outputs, + &result->variable_updates)); result->requires_runtime_context = context->has_context_parameter(); // Tuple arguments and runtime context parameters are incompatible. - CHECK(!(options_.use_tuple_arg && result->requires_runtime_context)); + CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; @@ -425,19 +454,21 @@ Status XlaCompiler::CompileGraph(string const& name, } } - if (result->computation.IsNull()) { + if (result->computation->IsNull()) { return Status::OK(); } // Compute the output shapes, if there is a computation with non-constant // outputs. - auto computation_shape = client()->GetComputationShape(result->computation); + auto computation_shape = client()->GetComputationShape(*result->computation); if (!computation_shape.ok()) { return computation_shape.status(); } result->xla_output_shape.Swap( computation_shape.ValueOrDie()->mutable_result()); + VLOG(2) << "XLA output shape: " + << xla::ShapeUtil::HumanString(result->xla_output_shape); auto num_computation_outputs = (xla::ShapeUtil::IsTuple(result->xla_output_shape)) @@ -463,10 +494,10 @@ Status XlaCompiler::CompileGraph(string const& name, i < context->retvals().size(); ++i) { const XlaContext::HandleOrConstant& retval = context->retvals()[i]; if (!retval.is_constant) { - CHECK_LT(computation_output, num_nonconst_outputs); + CHECK_LT(computation_output, num_computation_outputs); OutputDescription& output = result->outputs[i]; output.is_constant = false; - if (num_nonconst_outputs > 1) { + if (num_computation_outputs > 1) { output.shape = XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape( result->xla_output_shape, computation_output)); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 3ed920521b229c1ddac9ffffe924066624f3de5c..15f723ad782376b99ae7d72a5f15129e7880e9b1 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" @@ -112,6 +113,8 @@ class XlaCompiler { // The name of this argument, used for debugging. string name; + + bool operator==(const Argument& other) const; }; struct OutputDescription { @@ -172,15 +175,22 @@ class XlaCompiler { // The XLA computation built from the tensorflow subgraph. May be null // if the output consists solely of compile-time constants. - xla::Computation computation; + std::shared_ptr computation; }; struct Options { - // Name of the compilation device to use. - DeviceType device_type = DeviceType(""); + // Name of the compilation device to use. Needs to be live only during + // XlaCompiler's constructor. + const DeviceType* device_type = nullptr; xla::Client* client = nullptr; + // Function library in which to find function definitions. Must be non-null. + const FunctionLibraryDefinition* flib_def = nullptr; + + // The graph def version to be compiled. + int graph_def_version = TF_GRAPH_DEF_VERSION; + // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall() // for CPU; additionally, an optional XlaLocalRuntimeContext* may be passed // to the computation. @@ -198,6 +208,19 @@ class XlaCompiler { // computation. bool resolve_compile_time_constants = true; + // If not nullptr, populate_resource_manager is called with the + // compilation device's resource manager when the compilation + // device is created, and can be used to create metadata objects + // that can be accessed by XLA op kernels. + std::function* populate_resource_manager = nullptr; + }; + + explicit XlaCompiler(Options options); + ~XlaCompiler(); + + // Options pertaining to an individual call to CompileGraph() or + // CompileFunction(). + struct CompileOptions { // If `use_tuple_arg` is true, a single tuple parameter will be used for all // arguments; if false, each argument gets its own parameter. bool use_tuple_arg = false; @@ -208,17 +231,8 @@ class XlaCompiler { // modified by the computation. Used when compiling loop bodies to ensure // the input and output signatures match. bool return_updated_values_for_all_variables = false; - - // If 'prune_unreachable_nodes' is true, then nodes that are not - // dependencies of graph's _Retval nodes will be pruned before compilation. - // This is useful to prune stateful operators that should not be executed - // from a function body. - bool prune_unreachable_nodes = false; }; - explicit XlaCompiler(Options options); - ~XlaCompiler(); - // Compiles a Tensorflow function `fn_name_attrs` into an XLA computation. // `args` describes the arguments to the function, each of which must either // be a runtime-parameter to the XLA computation, a compile-time constant, or @@ -229,7 +243,7 @@ class XlaCompiler { // arguments are returned as host memory tensors in the output list and are // not included in the XLA computation's outputs. The XLA computation is // null if there are no data-dependent outputs and no side effects. - Status CompileFunction(FunctionLibraryRuntime* flr, + Status CompileFunction(const CompileOptions& options, const NameAttrList& fn_name_attrs, const std::vector& args, CompilationResult* result); @@ -237,8 +251,8 @@ class XlaCompiler { // Compiles a tensorflow::Graph into an xla::Computation. // Similar to CompileFunction, but takes a Graph as input rather than a // function. - Status CompileGraph(string const& name, std::unique_ptr graph, - FunctionLibraryRuntime* flr, + Status CompileGraph(const CompileOptions& options, string const& name, + std::unique_ptr graph, const std::vector& args, CompilationResult* result); @@ -247,9 +261,11 @@ class XlaCompiler { Status BuildExecutable(const CompilationResult& result, std::unique_ptr* executable); + const Options& options() const { return options_; } xla::Client* client() const { return options_.client; } XlaCompilationDevice* device() const { return device_; } const DeviceMgr* device_mgr() const { return &device_mgr_; } + FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_.get(); } // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. @@ -260,6 +276,9 @@ class XlaCompiler { private: Options options_; + // Status set to non-OK in the constructor if initialization fails. + Status initialization_status_; + // Returns the next step sequence number. int64 NextStepId(); @@ -271,6 +290,17 @@ class XlaCompiler { XlaCompilationDevice* device_; // Owned by device_mgr_ DeviceMgr device_mgr_; + std::unique_ptr flib_runtime_; + + struct SignatureHash { + uint64 operator()( + const std::pair>& signature) const; + }; + + std::unordered_map>, + CompilationResult, SignatureHash> + cache_; + std::unordered_map channels_ GUARDED_BY(mu_); TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index aa809f85a150cbff1b4504fced467c21e0314f6f..58d74057d101cdef89fca24ec6c0858291d825fa 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -17,12 +17,14 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -33,8 +35,69 @@ limitations under the License. namespace tensorflow { namespace { +// Helper class to test the ability to pass resources through to XLA +// compiled kernels. +class DummyResourceForTest : public ResourceBase { + public: + string DebugString() override { return "dummy"; } + void Increment() { ++value_; } + int Get() { return value_; } + + private: + int value_ = 0; +}; + +class DummyReadResourceOp : public XlaOpKernel { + public: + explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + ResourceMgr* rm = ctx->op_kernel_context()->resource_manager(); + OP_REQUIRES(ctx, rm, errors::Internal("No resource manager.")); + DummyResourceForTest* dummy; + OP_REQUIRES_OK(ctx, rm->Lookup( + rm->default_container(), "dummy", &dummy)); + dummy->Increment(); + dummy->Unref(); + + ctx->SetOutput(0, ctx->Input(0)); + } +}; + +class DummyReadResourceCC { + public: + DummyReadResourceCC(const Scope& scope, const Input& value) { + if (!scope.ok()) return; + auto _value = ops::AsNodeOut(scope, value); + if (!scope.ok()) return; + Node* ret; + const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource"); + auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value); + scope.UpdateBuilder(&builder); + scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); + if (!scope.ok()) return; + this->output_ = Output(ret, 0); + } + Node* node() const { return output_.node(); } + + Output output_; +}; + +REGISTER_OP("DummyReadResource") + .Input("input: int32") + .Output("output: int32") + .Doc(R"doc( +A dummy Op. + +input: dummy input. +output: dummy output. +)doc"); + +REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp); + class XlaCompilerTest : public ::testing::Test { protected: + XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} + void SetUp() override { client_ = xla::ClientLibrary::LocalClientOrDie(); @@ -46,19 +109,13 @@ class XlaCompilerTest : public ::testing::Test { XlaCompiler::Options DefaultOptions() { XlaCompiler::Options options; - options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); + options.device_type = &cpu_device_type_; options.client = client_; + options.flib_def = flib_def_.get(); return options; } - std::unique_ptr BuildFunctionLibraryRuntime( - const XlaCompiler& compiler) { - return std::unique_ptr(NewFunctionLibraryRuntime( - compiler.device_mgr(), /*env=*/nullptr, compiler.device(), - TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions(), - /*custom_kernel_creator=*/nullptr)); - } - + DeviceType cpu_device_type_; xla::Client* client_; std::unique_ptr flib_def_; }; @@ -66,15 +123,15 @@ class XlaCompilerTest : public ::testing::Test { // Tests compilation of an empty graph. TEST_F(XlaCompilerTest, EmptyReturnValues) { XlaCompiler compiler(DefaultOptions()); - auto flr = BuildFunctionLibraryRuntime(compiler); std::unique_ptr graph(new Graph(OpRegistry::Global())); XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph("add", std::move(graph), flr.get(), + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), /*args=*/{}, &result)); // No computation should be generated. - EXPECT_EQ(0, result.computation.handle().handle()); + EXPECT_EQ(0, result.computation->handle().handle()); } // Tests compilation and execution of a graph that adds two tensors. @@ -99,11 +156,10 @@ TEST_F(XlaCompilerTest, Simple) { // Compiles the graph. XlaCompiler compiler(DefaultOptions()); - auto flr = BuildFunctionLibraryRuntime(compiler); XlaCompiler::CompilationResult result; - TF_ASSERT_OK( - compiler.CompileGraph("add", std::move(graph), flr.get(), args, &result)); + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); // Tests that the generated computation works. std::unique_ptr param0_literal = @@ -117,7 +173,7 @@ TEST_F(XlaCompilerTest, Simple) { std::unique_ptr actual = client_ - ->Execute(result.computation, {param0_data.get(), param1_data.get()}) + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); @@ -152,14 +208,14 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { XlaCompiler::Options options = DefaultOptions(); options.resolve_compile_time_constants = true; XlaCompiler compiler(options); - auto flr = BuildFunctionLibraryRuntime(compiler); std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph("constants", std::move(graph_copy), - flr.get(), args, &result)); + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), + "constants", std::move(graph_copy), args, + &result)); ASSERT_EQ(2, result.outputs.size()); EXPECT_TRUE(result.outputs[0].is_constant); @@ -174,7 +230,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr actual = - client_->Execute(result.computation, {param0_data.get()}) + client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); @@ -189,14 +245,14 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { XlaCompiler::Options options = DefaultOptions(); options.resolve_compile_time_constants = false; XlaCompiler compiler(options); - auto flr = BuildFunctionLibraryRuntime(compiler); std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph("constants", std::move(graph_copy), - flr.get(), args, &result)); + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), + "constants", std::move(graph_copy), args, + &result)); ASSERT_EQ(2, result.outputs.size()); EXPECT_FALSE(result.outputs[0].is_constant); @@ -209,7 +265,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr actual = - client_->Execute(result.computation, {param0_data.get()}) + client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); @@ -224,5 +280,44 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { } } +// Tests compilation and execution of a graph that adds two tensors. +TEST_F(XlaCompilerTest, ResourceManager) { + // Builds a graph that calls the dummy resource Op. + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = DummyReadResourceCC(scope.WithOpName("B"), a); + auto c = ops::_Retval(scope.WithOpName("C"), b.output_, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the argument. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + + DummyResourceForTest* resource = new DummyResourceForTest(); + + // Compiles the graph. + auto options = DefaultOptions(); + std::function populate_function = + [resource](ResourceMgr* rm) { + resource->Ref(); + return rm->Create(rm->default_container(), "dummy", resource); + }; + options.populate_resource_manager = &populate_function; + XlaCompiler compiler(options); + + EXPECT_EQ(0, resource->Get()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy", + std::move(graph), args, &result)); + + EXPECT_EQ(1, resource->Get()); + + resource->Unref(); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 57d946509b65a6d5ebf013857cf52297559431ea..3592680303c95e310b8da85294ed961a5350e09c 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -183,9 +184,14 @@ const xla::Computation* XlaContext::GetOrCreateSigmoid(const DataType type) { xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); - auto one = b.ConstantLiteral(xla::LiteralUtil::One(xla_type)); - auto minus_one = b.Neg(one); - b.Div(one, b.Add(b.Exp(b.Mul(x, minus_one)), one)); + // Clamp the inputs to the range [-18, 18] since anything outside + // this range is 0.0f or 1.0f in single-precision. We must clamp the range + // of x to avoid incorrect outputs due to fast-math optimizations for large + // negative x. + x = b.Clamp(XlaHelpers::IntegerLiteral(&b, type, -18), x, + XlaHelpers::IntegerLiteral(&b, type, 18)); + auto one = XlaHelpers::One(&b, type); + b.Div(one, b.Add(b.Exp(b.Neg(x)), one)); return b.Build().ConsumeValueOrDie(); }); } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 10d8b67bbd2d0e897e3ca55e584f575448a3a4fd..f060f8f2f178b2bc56caf7a3df9df32c8a407473 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -89,7 +90,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::F16: - LOG(FATAL) << "f16 literals not yet implemented"; + literal = + *xla::LiteralUtil::CreateR0(static_cast(value)); + break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; case xla::OPAQUE: @@ -107,6 +110,9 @@ xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); switch (type) { + case xla::F16: + return b->ConstantR0(static_cast(value)); + break; case xla::F32: return b->ConstantR0(static_cast(value)); break; diff --git a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h index cd773d64ed4154aa2a05ac2d15e9358614239b1f..dca420d6ee3fec45f88ac3b450ab0cb4fb83d38a 100644 --- a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h +++ b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h @@ -23,7 +23,7 @@ limitations under the License. // actually used. E.g. some ahead-of-time compiled computations don't need a // thread pool. namespace Eigen { -class ThreadPoolDevice; +struct ThreadPoolDevice; } namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index dc5a342bcdd2cc3e47e873c4e495730eb4d0fcde..4de69ee43c355621c429bcd1ba3f4d623e9b0d78 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -186,6 +186,31 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index, return LiteralToInt64Vector(literal, out); } +Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, + xla::Literal* out) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); + switch (literal.shape().element_type()) { + case xla::S32: + out->Clear(); + *out->mutable_shape() = literal.shape(); + out->mutable_shape()->set_element_type(xla::S64); + for (int32 x : literal.s32s()) { + out->add_s64s(x); + } + return Status::OK(); + + case xla::S64: + out->Swap(&literal); + return Status::OK(); + + default: + return errors::InvalidArgument( + "Invalid argument to ConstantInputAsInt64Literal: ", + xla::ShapeUtil::HumanString(literal.shape())); + } +} + // TODO(phawkins): validate that the dimensions form a valid shape, fail // gracefully if they do not. Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { @@ -332,6 +357,7 @@ void XlaOpKernelContext::SetVariableOutput(int index, int variable_id) { Status XlaOpKernelContext::AssignVariable( int index, DataType type, const xla::ComputationDataHandle& handle) { + TF_RET_CHECK(handle.handle() != 0); SetOpHasSideEffects(); const XlaExpression* expression = @@ -354,6 +380,10 @@ void XlaOpKernelContext::SetOpHasSideEffects() { XlaContext::Get(context_).AddSideEffects(); } +XlaCompiler* XlaOpKernelContext::compiler() const { + return XlaContext::Get(context_).compiler(); +} + void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); } void XlaOpKernelContext::CtxFailureWithWarning(Status s) { context_->CtxFailureWithWarning(s); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index d214879e3cc9a86e6499d0afa68f572b6c6a3a15..0a8a9284186e5b72a8a376ad159eb7b2482699c5 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/macros.h" @@ -109,6 +110,9 @@ class XlaOpKernelContext { // Converts a constant 1D int32 or int64 tensor into a vector of int64s. Status ConstantInputAsIntVector(int index, std::vector* out); + // Converts a constant int32 or int64 Tensor into an xla int64 Literal. + Status ConstantInputAsInt64Literal(int index, xla::Literal* out); + // Converts a constant 1D int32 or int64 tensor into a TensorShape. Status ConstantInputAsShape(int index, TensorShape* shape); @@ -182,6 +186,10 @@ class XlaOpKernelContext { // Returns the underlying OpKernelContext. Use rarely. OpKernelContext* op_kernel_context() const { return context_; } + // Returns the XlaCompiler that is performing the compilation. Used for, e.g., + // While to compile nested computations. + XlaCompiler* compiler() const; + // TODO(phawkins): find a better home for these helpers. // Get an XLA lambda to compute Max. This is cached in the diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 5b895bfdf60d2206316af1b023e5ed91e7eec424..13fdfc3b0c82e3d0018c72eebaaf7fa313111648 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -167,6 +167,8 @@ void XlaOpRegistry::RegisterCompilationKernels() { !backend.second.op_filter(kdef.get())) { continue; } + VLOG(2) << "XLA op registration: device: " << backend.first + << " op: " << op.first; registry.kernel_registrars_.emplace_back( new kernel_factory::OpKernelRegistrar( new KernelDef(*kdef), "XlaJitOp", op.second->factory)); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index e73a29ddee1cd5b02453524618d5f3623b331cf8..6b424d23092b138e9f4d32062d575f17ab4791cb 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -6,6 +6,7 @@ package_group( name = "friends", packages = [ "//tensorflow/compiler/...", + "//tensorflow/contrib/xla_tf_graph/...", ], ) @@ -16,6 +17,7 @@ package_group( ], ) +load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") # Filegroup used to collect source files for dependency checking. @@ -43,11 +45,42 @@ xla_proto_library( ], ) +# This is a headers target that extra XLA devices can use to prevent +# circular dependencies. Devices that are compiled as separate shared +# objects can also use it to prevent linking of library code. +cc_header_only_library( + name = "xla_headers_lib", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:stream_executor_headers_lib", + ], +) + +cc_library( + name = "test", + testonly = 1, + hdrs = ["test.h"], + visibility = [":friends"], + deps = [ + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + ], +) + cc_library( name = "types", hdrs = ["types.h"], visibility = [":friends"], - deps = ["//tensorflow/core:lib"], + deps = [ + "//tensorflow/core:lib", + "//third_party/eigen3", + ], ) cc_library( @@ -80,6 +113,7 @@ cc_test( deps = [ ":status_macros", ":statusor", + ":test", ":test_helpers", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -115,6 +149,7 @@ cc_test( srcs = ["statusor_test.cc"], deps = [ ":statusor", + ":test", ":types", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -157,9 +192,9 @@ cc_test( name = "util_test", srcs = ["util_test.cc"], deps = [ + ":test", ":types", ":util", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -198,10 +233,11 @@ cc_test( srcs = ["shape_util_test.cc"], deps = [ ":shape_util", + ":test", ":test_helpers", ":types", ":util", - "//tensorflow/core:test", + ":xla_data_proto", "//tensorflow/core:test_main", ], ) @@ -211,6 +247,7 @@ cc_test( srcs = ["layout_util_test.cc"], deps = [ ":shape_util", + ":test", ":test_helpers", "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", "//tensorflow/core:test", @@ -223,9 +260,9 @@ cc_test( srcs = ["index_util_test.cc"], deps = [ ":shape_util", + ":test", ":test_helpers", ":xla_data_proto", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -240,6 +277,7 @@ cc_library( ":array3d", ":array4d", ":shape_util", + ":status_macros", ":types", ":util", ":xla_data_proto", @@ -255,7 +293,7 @@ cc_test( ":array4d", ":literal_util", ":shape_util", - ":test_helpers", + ":test", ":types", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -270,7 +308,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":util", - ":xla_data_proto", "//tensorflow/core:lib", ], ) @@ -303,7 +340,7 @@ cc_test( srcs = ["array2d_test.cc"], deps = [ ":array2d", - "//tensorflow/core:test", + ":test", "//tensorflow/core:test_main", ], ) @@ -323,8 +360,8 @@ cc_test( srcs = ["array3d_test.cc"], deps = [ ":array3d", + ":test", ":types", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -345,8 +382,8 @@ cc_test( srcs = ["array4d_test.cc"], deps = [ ":array4d", + ":test", "//tensorflow/core:lib", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -378,7 +415,6 @@ cc_library( cc_library( name = "test_helpers", testonly = 1, - srcs = ["test_helpers.cc"], hdrs = ["test_helpers.h"], visibility = [":internal"], deps = [ @@ -414,11 +450,11 @@ cc_test( deps = [ ":literal_util", ":shape_util", + ":test", ":text_literal_reader", ":types", ":xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -443,6 +479,7 @@ cc_test( srcs = ["text_literal_writer_test.cc"], deps = [ ":literal_util", + ":test", ":test_helpers", ":text_literal_writer", ":types", @@ -471,8 +508,8 @@ cc_test( deps = [ ":shape_tree", ":shape_util", + ":test", ":xla_data_proto", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -530,11 +567,11 @@ cc_test( ":array4d", ":literal_util", ":reference_util", + ":test", ":util", ":xla_data_proto", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index f885821210eb68dfb599303830c814c309e0a24d..593084a0c111690d9e239ed5837f6f0c6c713048 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -45,11 +45,15 @@ class Array2D { // Creates an array of dimensions n1 x n2, uninitialized values. Array2D(const int64 n1, const int64 n2) - : n1_(n1), n2_(n2), values_(n1 * n2) {} + : n1_(n1), n2_(n2), values_(new T[n1 * n2]()) { + Fill(T()); + } // Creates an array of dimensions n1 x n2, initialized to value. Array2D(const int64 n1, const int64 n2, const T value) - : n1_(n1), n2_(n2), values_(n1 * n2, value) {} + : n1_(n1), n2_(n2), values_(new T[n1 * n2]()) { + Fill(value); + } // Creates an array from the given nested initializer list. The outer // initializer list is the first dimension; the inner is the second dimension. @@ -65,16 +69,30 @@ class Array2D { } } - T& operator()(const int64 n1, const int64 n2) { - CHECK_LT(n1, n1_); - CHECK_LT(n2, n2_); - return values_[n1 * n2_ + n2]; + Array2D(const Array2D& other) : Array2D(other.n1(), other.n2()) { + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + } + + Array2D& operator=(const Array2D& other) { + n1_ = other.n1(); + n2_ = other.n2(); + values_.reset(new T[num_elements()]); + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + return *this; + } + + T& operator()(const int64 i1, const int64 i2) { + CHECK_LT(i1, n1_); + CHECK_LT(i2, n2_); + return values_[i1 * n2_ + i2]; } - const T& operator()(const int64 n1, const int64 n2) const { - CHECK_LT(n1, n1_); - CHECK_LT(n2, n2_); - return values_[n1 * n2_ + n2]; + const T& operator()(const int64 i1, const int64 i2) const { + CHECK_LT(i1, n1_); + CHECK_LT(i2, n2_); + return values_[i1 * n2_ + i2]; } // Access to the array's dimensions. height() and width() provide the @@ -84,15 +102,15 @@ class Array2D { int64 n2() const { return n2_; } int64 height() const { return n1_; } int64 width() const { return n2_; } - int64 num_elements() const { return values_.size(); } + int64 num_elements() const { return n1_ * n2_; } // Low-level accessor for stuff like memcmp, handle with care. Returns pointer // to the underlying storage of the array (similarly to std::vector::data()). - T* data() const { return const_cast(this)->values_.data(); } + T* data() const { return const_cast(this)->values_.get(); } // Fills the array with the given value. void Fill(const T& value) { - std::fill(values_.begin(), values_.end(), value); + std::fill(&values_[0], &values_[0] + num_elements(), value); } // Applies f to all cells in this array, in row-major order. @@ -124,8 +142,8 @@ class Array2D { std::mt19937 g(seed); std::normal_distribution distribution(mean, static_cast(value)); - for (auto& v : values_) { - v = static_cast(distribution(g)); + for (int64 i = 0; i < num_elements(); ++i) { + values_[i] = static_cast(distribution(g)); } } @@ -150,7 +168,7 @@ class Array2D { private: int64 n1_; int64 n2_; - std::vector values_; + std::unique_ptr values_; }; // Returns a linspace-populated Array2D in the range [from, to] (inclusive) diff --git a/tensorflow/compiler/xla/array2d_test.cc b/tensorflow/compiler/xla/array2d_test.cc index ac107b1c0d426c676629762dbc8191c74e2e1c7e..795d50ca5b56a60c34279a33e65aa635a65fa5ec 100644 --- a/tensorflow/compiler/xla/array2d_test.cc +++ b/tensorflow/compiler/xla/array2d_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/core/platform/test.h" +#include "tensorflow/compiler/xla/test.h" namespace xla { namespace { @@ -84,6 +84,17 @@ TEST(Array2dTest, IndexingReadWrite) { EXPECT_EQ(arr(1, 2), 61); } +TEST(Array2dTest, IndexingReadWriteBool) { + Array2D arr = {{false, true, false}, {true, true, false}}; + + EXPECT_EQ(arr(1, 1), true); + EXPECT_EQ(arr(1, 2), false); + arr(1, 1) = false; + arr(1, 2) = true; + EXPECT_EQ(arr(1, 1), false); + EXPECT_EQ(arr(1, 2), true); +} + TEST(Array2dTest, Fill) { Array2D fullof7(2, 3, 7); for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) { diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h index 654af8f03074f30dd1561db412ad36f43a33aab9..124ccd1975b3a9ab047e9bbbfb38921fe7386fe4 100644 --- a/tensorflow/compiler/xla/array3d.h +++ b/tensorflow/compiler/xla/array3d.h @@ -20,9 +20,9 @@ limitations under the License. #include #include #include +#include #include #include -#include #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" @@ -39,11 +39,15 @@ class Array3D { public: // Creates an array of dimensions n1 x n2 x n3, uninitialized values. Array3D(const int64 n1, const int64 n2, const int64 n3) - : n1_(n1), n2_(n2), n3_(n3), values_(n1 * n2 * n3) {} + : n1_(n1), n2_(n2), n3_(n3), values_(new T[n1 * n2 * n3]) { + Fill(T()); + } // Creates an array of dimensions n1 x n2 x n3, initialized to value. Array3D(const int64 n1, const int64 n2, const int64 n3, const T value) - : n1_(n1), n2_(n2), n3_(n3), values_(n1 * n2 * n3, value) {} + : n1_(n1), n2_(n2), n3_(n3), values_(new T[n1 * n2 * n3]) { + Fill(value); + } // Creates an array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. @@ -69,34 +73,50 @@ class Array3D { } } - T& operator()(const int64 n1, const int64 n2, const int64 n3) { - CHECK_LT(n1, n1_); - CHECK_LT(n2, n2_); - CHECK_LT(n3, n3_); - return values_[n1 * n2_ * n3_ + n2 * n3_ + n3]; + Array3D(const Array3D& other) + : Array3D(other.n1(), other.n2(), other.n3()) { + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + } + + Array3D& operator=(const Array3D& other) { + n1_ = other.n1(); + n2_ = other.n2(); + n3_ = other.n3(); + values_.reset(new T[num_elements()]); + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + return *this; + } + + T& operator()(const int64 i1, const int64 i2, const int64 i3) { + CHECK_LT(i1, n1_); + CHECK_LT(i2, n2_); + CHECK_LT(i3, n3_); + return values_[i1 * n2_ * n3_ + i2 * n3_ + i3]; } - const T& operator()(const int64 n1, const int64 n2, const int64 n3) const { - CHECK_LT(n1, n1_); - CHECK_LT(n2, n2_); - CHECK_LT(n3, n3_); - return values_[n1 * n2_ * n3_ + n2 * n3_ + n3]; + const T& operator()(const int64 i1, const int64 i2, const int64 i3) const { + CHECK_LT(i1, n1_); + CHECK_LT(i2, n2_); + CHECK_LT(i3, n3_); + return values_[i1 * n2_ * n3_ + i2 * n3_ + i3]; } // Access to the array's dimensions. int64 n1() const { return n1_; } int64 n2() const { return n2_; } int64 n3() const { return n3_; } - int64 num_elements() const { return values_.size(); } + int64 num_elements() const { return n1_ * n2_ * n3_; } // Fills the array with the given value. void Fill(const T& value) { - std::fill(values_.begin(), values_.end(), value); + std::fill(&values_[0], &values_[0] + num_elements(), value); } // Fills the array with sequentially increasing values. void FillIota(const T& value) { - std::iota(values_.begin(), values_.end(), value); + std::iota(&values_[0], &values_[0] + num_elements(), value); } // Fills the array with random normal values with a mean of 0 and standard @@ -106,8 +126,8 @@ class Array3D { std::mt19937 g(seed); std::normal_distribution distribution(mean, static_cast(value)); - for (auto& v : values_) { - v = static_cast(distribution(g)); + for (int64 i = 0; i < num_elements(); ++i) { + values_[i] = static_cast(distribution(g)); } } @@ -115,7 +135,7 @@ class Array3D { int64 n1_; int64 n2_; int64 n3_; - std::vector values_; + std::unique_ptr values_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/array3d_test.cc b/tensorflow/compiler/xla/array3d_test.cc index fa4435dfc48edcd5b88230e7d2de21e29e269b7e..6b5f4b343b2113652758bbd5ce0fc803239c1266 100644 --- a/tensorflow/compiler/xla/array3d_test.cc +++ b/tensorflow/compiler/xla/array3d_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index 199ad2baaeb7999349fd6bb201a476706bb12ce7..56b638d9782a6c9db5206c070d69c5b2b367313f 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -60,15 +61,15 @@ class Array4D { depth_(depth), height_(height), width_(width), - values_(planes * depth * height * width) {} + values_(new T[planes * depth * height * width]) { + Fill(T()); + } - // Creates a 4D array, initalized to value. + // Creates a 4D array, initialized to value. Array4D(int64 planes, int64 depth, int64 height, int64 width, T value) - : planes_(planes), - depth_(depth), - height_(height), - width_(width), - values_(planes * depth * height * width, value) {} + : Array4D(planes, depth, height, width) { + Fill(value); + } // Creates a 4D array, filled with values. // @@ -111,6 +112,23 @@ class Array4D { } } + Array4D(const Array4D& other) + : Array4D(other.planes(), other.depth(), other.height(), other.width()) { + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + } + + Array4D& operator=(const Array4D& other) { + planes_ = other.planes(); + depth_ = other.depth(); + height_ = other.height(); + width_ = other.width(); + values_.reset(new T[num_elements()]); + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + return *this; + } + T& operator()(int64 plane, int64 depth, int64 height, int64 width) { CHECK_LT(plane, planes_); CHECK_LT(depth, depth_); @@ -135,24 +153,24 @@ class Array4D { int64 n3() const { return height_; } int64 n2() const { return depth_; } int64 n1() const { return planes_; } - int64 num_elements() const { return values_.size(); } + int64 num_elements() const { return width_ * height_ * depth_ * planes_; } // Sets all the values in the array to values. template > void SetValues(const Container& container) { CHECK_EQ(std::distance(std::begin(container), std::end(container)), num_elements()); - values_.assign(std::begin(container), std::end(container)); + std::copy(std::begin(container), std::end(container), &values_[0]); } // Fills the array with the given value. void Fill(const T& value) { - std::fill(values_.begin(), values_.end(), value); + std::fill(&values_[0], &values_[0] + num_elements(), value); } // Fills the array with iota. void FillIota(const T& value) { - std::iota(values_.begin(), values_.end(), value); + std::iota(&values_[0], &values_[0] + num_elements(), value); } // Fills the array with random variable with a deviation of value and a mean @@ -162,8 +180,8 @@ class Array4D { std::mt19937 g(seed); std::normal_distribution distribution(mean, static_cast(value)); - for (auto& v : values_) { - v = static_cast(distribution(g)); + for (int64 i = 0; i < num_elements(); ++i) { + values_[i] = static_cast(distribution(g)); } } @@ -268,7 +286,7 @@ class Array4D { int64 depth_; int64 height_; int64 width_; - std::vector values_; + std::unique_ptr values_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/array4d_test.cc b/tensorflow/compiler/xla/array4d_test.cc index 72ada467e515eff98a2e5845dc6a3714a770650e..3bc8148c911df0aeade364e4ac2e2ee828bacb53 100644 --- a/tensorflow/compiler/xla/array4d_test.cc +++ b/tensorflow/compiler/xla/array4d_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 3e9dfe2a922c913c528d586413c11e2da8cbdc39..2d96128e259da316a41e83bea221ae201ad88a13 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -99,6 +99,26 @@ cc_library( ], ) +cc_library( + name = "compile_only_client", + srcs = ["compile_only_client.cc"], + hdrs = ["compile_only_client.h"], + deps = [ + ":client", + ":computation", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:compile_only_service", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "@llvm//:support", + ], +) + # This target is used to instantiate the XLA service in-process and create # a client for it. cc_library( @@ -106,12 +126,14 @@ cc_library( srcs = ["client_library.cc"], hdrs = ["client_library.h"], deps = [ + ":compile_only_client", ":local_client", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:backend", + "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index 93437023bc8956e449f828f5bf6dea7a6bff8610..8238261e1c90cadeda9005e437d684d3770bd67b 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -43,6 +43,16 @@ int LocalClientOptions::number_of_replicas() const { return number_of_replicas_; } +LocalClientOptions& LocalClientOptions::set_intra_op_parallelism_threads( + int num_threads) { + intra_op_parallelism_threads_ = num_threads; + return *this; +} + +int LocalClientOptions::intra_op_parallelism_threads() const { + return intra_op_parallelism_threads_; +} + /* static */ ClientLibrary& ClientLibrary::Singleton() { static ClientLibrary* c = new ClientLibrary; return *c; @@ -69,22 +79,24 @@ ClientLibrary::~ClientLibrary() = default; TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } - auto it = client_library.instances_.find(platform->id()); - if (it != client_library.instances_.end()) { + auto it = client_library.local_instances_.find(platform->id()); + if (it != client_library.local_instances_.end()) { return it->second->client.get(); } ServiceOptions service_options; service_options.set_platform(platform); service_options.set_number_of_replicas(replica_count); + service_options.set_intra_op_parallelism_threads( + options.intra_op_parallelism_threads()); - std::unique_ptr instance = MakeUnique(); + auto instance = MakeUnique(); TF_ASSIGN_OR_RETURN(instance->service, LocalService::NewService(service_options)); instance->client = MakeUnique(instance->service.get()); LocalClient* cl = instance->client.get(); - client_library.instances_.insert( + client_library.local_instances_.insert( std::make_pair(platform->id(), std::move(instance))); return cl; } @@ -99,9 +111,35 @@ ClientLibrary::~ClientLibrary() = default; perftools::gputools::Platform* platform) { ClientLibrary& client_library = Singleton(); tensorflow::mutex_lock lock(client_library.service_mutex_); - auto it = client_library.instances_.find(platform->id()); - CHECK(it != client_library.instances_.end()); + auto it = client_library.local_instances_.find(platform->id()); + CHECK(it != client_library.local_instances_.end()); return it->second->service.get(); } +/* static */ StatusOr +ClientLibrary::GetOrCreateCompileOnlyClient( + perftools::gputools::Platform* platform) { + ClientLibrary& client_library = Singleton(); + tensorflow::mutex_lock lock(client_library.service_mutex_); + + if (platform == nullptr) { + TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); + } + + auto it = client_library.compile_only_instances_.find(platform->id()); + if (it != client_library.compile_only_instances_.end()) { + return it->second->client.get(); + } + + auto instance = MakeUnique(); + TF_ASSIGN_OR_RETURN(instance->service, + CompileOnlyService::NewService(platform)); + instance->client = MakeUnique(instance->service.get()); + CompileOnlyClient* cl = instance->client.get(); + + client_library.compile_only_instances_.insert( + std::make_pair(platform->id(), std::move(instance))); + return cl; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h index 2bc319f9333368635690add017ad3d89947e2551..3ddd235d0efeeb78f49eafbf670d7c74a88960dd 100644 --- a/tensorflow/compiler/xla/client/client_library.h +++ b/tensorflow/compiler/xla/client/client_library.h @@ -26,7 +26,9 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/compile_only_client.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/compile_only_service.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/statusor.h" @@ -51,9 +53,14 @@ class LocalClientOptions { LocalClientOptions& set_number_of_replicas(int number_of_replicas); int number_of_replicas() const; + // Sets the thread pool size for parallel execution of an individual operator. + LocalClientOptions& set_intra_op_parallelism_threads(int num_threads); + int intra_op_parallelism_threads() const; + private: perftools::gputools::Platform* platform_ = nullptr; int number_of_replicas_ = -1; + int intra_op_parallelism_threads_ = -1; }; class ClientLibrary { @@ -76,6 +83,13 @@ class ClientLibrary { // access user computations from client. static LocalService* GetXlaService(perftools::gputools::Platform* platform); + // Singleton constructor-or-accessor for compile-only clients. Arguments: + // + // platform : The platform the underlying XLA service should target. If + // null then default platform is used. + static StatusOr GetOrCreateCompileOnlyClient( + perftools::gputools::Platform* platform = nullptr); + private: // Returns the singleton instance of ClientLibrary. static ClientLibrary& Singleton(); @@ -90,10 +104,21 @@ class ClientLibrary { std::unique_ptr client; }; + struct CompileOnlyInstance { + // Service that is wrapped by the singleton client object. + std::unique_ptr service; + // Singleton client object. + std::unique_ptr client; + }; + tensorflow::mutex service_mutex_; // Guards the singleton creation state. std::unordered_map> - instances_ GUARDED_BY(service_mutex_); + local_instances_ GUARDED_BY(service_mutex_); + + std::unordered_map> + compile_only_instances_ GUARDED_BY(service_mutex_); TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary); }; diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..2ff6f0b300f9e2cc776e60bb27a3952356657780 --- /dev/null +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/compile_only_client.h" + +#include "external/llvm/include/llvm/ADT/Triple.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +StatusOr>> +CompileOnlyClient::CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options) { + std::vector service_instances; + service_instances.reserve(computations.size()); + for (const AotComputationInstance& instance : computations) { + service_instances.push_back({}); + CompileOnlyService::AotComputationInstance& service_instance = + service_instances.back(); + TF_RET_CHECK(instance.computation != nullptr); + service_instance.computation = instance.computation->handle(); + service_instance.argument_layouts = instance.argument_layouts; + service_instance.result_layout = instance.result_layout; + } + return compiler_service_->CompileAheadOfTime(service_instances, options); +} + +int64 CompileOnlyClient::PointerSizeForTriple( + tensorflow::StringPiece target_triple) { + llvm::Triple triple( + llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple))); + if (triple.isArch64Bit()) { + return 8; + } else if (triple.isArch32Bit()) { + return 4; + } else { + CHECK(triple.isArch16Bit()); + return 2; + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h new file mode 100644 index 0000000000000000000000000000000000000000..5900048711384e0240a3cd502260eb388eb40f51 --- /dev/null +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/service/compile_only_service.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// An XLA Client specialization for doing ahead-of-time compilation. This does +// not require (or attempt to instantiate) an execution-capable backend for the +// relevant platform. +class CompileOnlyClient : public Client { + public: + explicit CompileOnlyClient(CompileOnlyService* service) + : Client(service), compiler_service_(service) {} + + CompileOnlyClient(const CompileOnlyClient&) = delete; + void operator=(const CompileOnlyClient&) = delete; + + // A description of a computation to compile using CompileAheadOfTime. + struct AotComputationInstance { + const Computation* computation; + // Inform the compiler of the expected layout for arguments. + std::vector argument_layouts; + // Specifies the expected result layout. + const Shape* result_layout; + }; + + // Compiles a list of computations for ahead-of-time execution. This is + // intended for use in static compilation. The |options| parameter describes + // the target for which the compiler should emit code. + StatusOr>> + CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options); + + // Returns the size of a pointer in bytes for a given triple. + static int64 PointerSizeForTriple(tensorflow::StringPiece triple); + + private: + CompileOnlyService* compiler_service_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 88efd87d1cc3efd16750d6dfedb18159114b2cb6..22a70681468f16b12793274bf5ce72613534df42 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -1229,8 +1229,7 @@ StatusOr ComputationBuilder::IsConstant( VLOG(2) << "done with request"; if (!s.ok()) { - NoteError(s); - return first_error_; + return s; } return response.is_constant(); } @@ -1255,8 +1254,7 @@ StatusOr> ComputationBuilder::ComputeConstant( VLOG(2) << "done with request"; if (!s.ok()) { - NoteError(s); - return first_error_; + return s; } TF_RET_CHECK(response.output().handle() != 0); diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 87ceb43d1fe6650e1d160f3099b883ea208d8aac..6af69eeec12dec0ea1303826859d4655cf92932e 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -668,6 +668,14 @@ class ComputationBuilder { // then Build() should be used instead. Computation BuildAndNoteError(); + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // ComputationDataHandle and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + private: using PopulateLiteral = std::function; diff --git a/tensorflow/compiler/xla/client/global_data.h b/tensorflow/compiler/xla/client/global_data.h index eb11d91034ba524f093ff80fa7cd0473e04eac2c..b7929357d06032b55c04bf0391f7fa703ee15f17 100644 --- a/tensorflow/compiler/xla/client/global_data.h +++ b/tensorflow/compiler/xla/client/global_data.h @@ -23,13 +23,15 @@ limitations under the License. namespace xla { -// Wraps a GlobalDataHandle with a lifetime. +// A GlobalData object represents a globally-accessible allocation of +// data in the associated XLA service. class GlobalData { public: // Gives ownership of the global data handle to this object. GlobalData(ServiceInterface* parent, GlobalDataHandle handle); - // Unregisters the wrapped handle. + // Unregisters the wrapped handle, which causes the service to + // deallocate the associated data. ~GlobalData(); const GlobalDataHandle& handle() const { return handle_; } diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index bfd14bc1c010353e3e473f10dd6c030cb0438648..02cf57e7632a2064e646d4dc441e3ec119053564 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -176,17 +176,24 @@ StatusOr> LocalExecutable::Run( TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options, *backend_)); ExecutableRunOptions actual_options = options; - Backend::StreamPtr stream; if (options.stream() == nullptr) { TF_ASSIGN_OR_RETURN( - stream, BorrowStreamForDevice(options.device_ordinal(), backend_)); + Backend::StreamPtr stream, + BorrowStreamForDevice(options.device_ordinal(), backend_)); actual_options.set_stream(stream.get()); } if (options.allocator() == nullptr) { actual_options.set_allocator(backend_->memory_allocator()); } - ServiceExecutableRunOptions service_options(actual_options, - backend_->StreamBorrower()); + + // For local client execution on CPU backends: + // *) The thread pool used for eigen CPU ops is from + // ExecutableRunOptions.eigen_intra_op_thread_pool. + // *) The thread pool used for XLA CPU ops is from + // backend_->eigen_intra_op_thread_pool(). + ServiceExecutableRunOptions service_options( + actual_options, backend_->StreamBorrower(), + backend_->eigen_intra_op_thread_pool()); if (executable_->dumping()) { return ExecuteAndDump(&service_options, arguments); @@ -253,46 +260,6 @@ StatusOr> LocalClient::AllocateBufferOnDevice( return std::unique_ptr(new GlobalData(local_service_, handle)); } -tensorflow::Status LocalClient::ResolveArguments( - const tensorflow::gtl::ArraySlice arguments, - int device_ordinal, - std::vector* argument_ptrs) { - return local_service_->ResolveArguments(arguments, device_ordinal, - argument_ptrs); -} - -StatusOr>> -LocalClient::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice - computations, - const AotCompilationOptions& options) { - std::vector service_instances; - service_instances.reserve(computations.size()); - for (const AheadOfTimeComputationInstance& instance : computations) { - service_instances.push_back({}); - LocalService::AheadOfTimeComputationInstance& service_instance = - service_instances.back(); - TF_RET_CHECK(instance.computation != nullptr); - service_instance.computation = instance.computation->handle(); - service_instance.argument_layouts = instance.argument_layouts; - service_instance.result_layout = instance.result_layout; - } - return local_service_->CompileAheadOfTime(service_instances, options); -} - -int64 LocalClient::PointerSizeForTriple(tensorflow::StringPiece target_triple) { - llvm::Triple triple( - llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple))); - if (triple.isArch64Bit()) { - return 8; - } else if (triple.isArch32Bit()) { - return 4; - } else { - CHECK(triple.isArch16Bit()); - return 2; - } -} - se::Platform* LocalClient::platform() const { return local_service_->backend().platform(); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 2c467efcea119b66ad08e0636eca0f1acec3a3b8..c903cd271125b44677f7bb191f100f6604f40bbc 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -56,7 +56,7 @@ class ExecutableBuildOptions { // If set, this specifies the layout of the result of the computation. If not // set, the service will chose the layout of the result. A Shape is used to - // store the layout to accomodate tuple result shapes. A value of nullptr + // store the layout to accommodate tuple result shapes. A value of nullptr // indicates the option has not been set. ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); const Shape* result_layout() const; @@ -148,7 +148,7 @@ class LocalExecutable { const ExecutableBuildOptions& build_options_; }; -// An XLA service client object for use when the client and service run in +// An XLA Client specialization for use when the client and service run in // the same process. class LocalClient : public Client { public: @@ -158,14 +158,6 @@ class LocalClient : public Client { LocalClient(const LocalClient&) = delete; void operator=(const LocalClient&) = delete; - // For an array of arguments held on the local service, validate - // that each is placed on the specified device_ordinal, and return - // the DeviceMemoryBase corresponding to each argument. - tensorflow::Status ResolveArguments( - const tensorflow::gtl::ArraySlice arguments, - int device_ordinal, - std::vector* argument_ptrs); - // Return a handle to a buffer large enough to hold shape, allocated // on device_ordinal on the local service. If // allocate_space_for_deep_copy, the buffer is large enough to hold @@ -182,30 +174,6 @@ class LocalClient : public Client { const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options); - // A description of a computation to compile using CompileAheadOfTime. - struct AheadOfTimeComputationInstance { - const Computation* computation; - // Inform the compiler of the expected layout for arguments. - std::vector argument_layouts; - // Specifies the expected result layout. - const Shape* result_layout; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. The |options| parameter describes - // the target for which the compiler should emit code. - // - // TODO(b/31222190): This doesn't really belong in LocalClient. Move it to its - // own library. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice - computations, - const AotCompilationOptions& options); - - // Returns the size of a pointer in bytes for a given triple. - static int64 PointerSizeForTriple(tensorflow::StringPiece triple); - // Returns the platform that the underlying service targets. perftools::gputools::Platform* platform() const; diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index e3248d8e908b60c7e6f7224d25b963601c92f24a..76c0168f370ff1f0749759705b7ecff359a80341 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -118,17 +118,36 @@ namespace xla { return multi_index; } -/* static */ bool IndexUtil::BumpIndices(const Shape& shape, - std::vector* indices) { - for (int64 dimno = indices->size() - 1; dimno >= 0; --dimno) { +/* static */ bool IndexUtil::BumpIndices( + const Shape& shape, tensorflow::gtl::MutableArraySlice indices) { + for (int64 dimno = indices.size() - 1; dimno >= 0; --dimno) { int64 limit = shape.dimensions(dimno); - if ((*indices)[dimno] + 1 < limit) { - (*indices)[dimno]++; - std::fill(indices->begin() + dimno + 1, indices->end(), 0); + if (indices[dimno] + 1 < limit) { + indices[dimno]++; + std::fill(indices.begin() + dimno + 1, indices.end(), 0); return true; } } return false; } +/* static */ int64 IndexUtil::GetDimensionStride(const Shape& shape, + int64 dimension) { + const Layout& layout = shape.layout(); + int64 pdim_size = layout.padded_dimensions_size(); + int64 stride = 1; + DCHECK(pdim_size == 0 || pdim_size == shape.dimensions_size()); + for (auto dim : layout.minor_to_major()) { + if (dim == dimension) { + break; + } + if (pdim_size == 0) { + stride *= shape.dimensions(dim); + } else { + stride *= layout.padded_dimensions(dim); + } + } + return stride; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index 2d8753c3fe8fc05bdcdeaa18360ac5fe4a5e587b..c9838966a5b67397eb5fc4afe3ab9d98e82eb2b1 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -58,7 +58,16 @@ class IndexUtil { // // Returns true iff the indices were successfully bumped; false if we've hit // the limit where it can no longer be bumped in-bounds. - static bool BumpIndices(const Shape& shape, std::vector* indices); + static bool BumpIndices(const Shape& shape, + tensorflow::gtl::MutableArraySlice indices); + + // Calculates the stride size (in number of elements, not byte size) of a + // given logical shape dimension (from 0 to rank-1). If available, padded + // dimensions are used. + // Example: + // GetDimensionStride(F32[5,8,10,4]{3,2,1,0}, 1) == + // sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10 + static int64 GetDimensionStride(const Shape& shape, int64 dimension); private: TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil); diff --git a/tensorflow/compiler/xla/index_util_test.cc b/tensorflow/compiler/xla/index_util_test.cc index 85259b33f0beea4b508c0d5c1f3a6294dda76813..7c4efdee484d9530a69b31cbe3a0d69a8a3cffa7 100644 --- a/tensorflow/compiler/xla/index_util_test.cc +++ b/tensorflow/compiler/xla/index_util_test.cc @@ -18,9 +18,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { @@ -144,14 +143,11 @@ TEST(IndexUtilTest, BumpIndices2x2) { auto shape = ShapeUtil::MakeShape(S32, {2, 2}); std::vector indices = {0, 0}; EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); - EXPECT_MATCH(indices, - testing::VectorMatcher(std::vector{0, 1})); + EXPECT_THAT(indices, ::testing::ElementsAre(0, 1)); EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); - EXPECT_MATCH(indices, - testing::VectorMatcher(std::vector{1, 0})); + EXPECT_THAT(indices, ::testing::ElementsAre(1, 0)); EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); - EXPECT_MATCH(indices, - testing::VectorMatcher(std::vector{1, 1})); + EXPECT_THAT(indices, ::testing::ElementsAre(1, 1)); EXPECT_FALSE(IndexUtil::BumpIndices(shape, &indices)); } diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 531a6e03dad4759416f56465a6c582a06e440a5a..d3fcccff654fbbafa0b3c6a3d900123691f059fb 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -14,11 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/shape_util.h" - #include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { @@ -114,8 +113,8 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) { Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); EXPECT_FALSE(status.ok()); - EXPECT_MATCH(status.error_message(), - testing::ContainsRegex("cannot copy layout from shape")); + EXPECT_THAT(status.error_message(), + ::testing::ContainsRegex("cannot copy layout from shape")); } TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) { @@ -133,8 +132,8 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) { auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); EXPECT_FALSE(status.ok()); - EXPECT_MATCH(status.error_message(), - testing::ContainsRegex("cannot copy layout from shape")); + EXPECT_THAT(status.error_message(), + ::testing::ContainsRegex("cannot copy layout from shape")); } TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) { @@ -145,9 +144,10 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) { auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); EXPECT_FALSE(status.ok()); - EXPECT_MATCH(status.error_message(), - testing::ContainsRegex("layout minor_to_major field contains .* " - "elements, but shape is rank")); + EXPECT_THAT( + status.error_message(), + ::testing::ContainsRegex("layout minor_to_major field contains .* " + "elements, but shape is rank")); } TEST_F(LayoutUtilTest, ClearLayoutTuple) { diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc index e79d3635095a0aacf20b37e586d2c9ac799cbe07..7d3ad60aea44bedcd5dccce91f1c4d24576f02b0 100644 --- a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc @@ -38,7 +38,6 @@ static void AllocateFlags() { flags = new GpuCompilerFlags; flags->xla_gpu_embed_ir = false; flags->xla_cuda_data_dir = "./cuda_sdk_lib"; - flags->xla_ptxas_path = "/usr/local/cuda/bin/ptxas"; flag_list = new std::vector({ tensorflow::Flag( "xla_gpu_embed_ir", &flags->xla_gpu_embed_ir, diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc index 8822f6f6107d3d9ff121c04e5904a7367c604be7..ba43a5919522ff783f450481c629d64613e1f8ab 100644 --- a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc @@ -36,10 +36,14 @@ static std::once_flag flags_init; static void AllocateFlags() { flags = new HloGraphDumperFlags; flags->xla_hlo_dump_graph_path = "/tmp/"; + flags->xla_hlo_dump_as_graphdef = false; flag_list = new std::vector({ tensorflow::Flag("xla_hlo_dump_graph_path", &flags->xla_hlo_dump_graph_path, "Path to write dumped HLO graphs to"), + tensorflow::Flag("xla_hlo_dump_as_graphdef", + &flags->xla_hlo_dump_as_graphdef, + "Dumps HLO graphs as tensorflow GraphDefs"), }); ParseFlagsFromEnv(*flag_list); } diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h index b6dfced87cae90c67bd46975a8e36eaef10b19e7..d0b4d092ff1003bc1df90c3d878feacf71a5aa21 100644 --- a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h +++ b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h @@ -34,6 +34,9 @@ void AppendHloGraphDumperFlags(std::vector* flag_list); // The values of flags associated with XLA's hlo_graph_dumper module. typedef struct { string xla_hlo_dump_graph_path; // Path to write dumped HLO graphs to + // If set, dumps HLO graphs as tensorflow GraphDef; otherwise, dumps HLO + // graphs as DOT graph. + bool xla_hlo_dump_as_graphdef; } HloGraphDumperFlags; // Return a pointer to the HloGraphDumperFlags struct; diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 7091c324d14552d8b7603c3872d0ffc59771d8f7..ec4012a7036e19ec0c75e958b29511b2c5aa4713 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -16,12 +16,15 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include +#include +#include #include #include #include #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -33,6 +36,137 @@ limitations under the License. namespace xla { +LiteralUtil::StrideConfig::StrideConfig( + const Shape& source_shape, const Shape& dest_shape, + tensorflow::gtl::ArraySlice dimensions) + : dimensions(dimensions), + base(dimensions.size(), 0), + step(dimensions.size(), 1) { + if (!dimensions.empty()) { + // Selects the shape with the highest minor dimension as the one upon + // where to run the tight stride loop. + if (source_shape.layout().minor_to_major()[0] >= + dest_shape.layout().minor_to_major()[0]) { + minor_dimension = source_shape.layout().minor_to_major()[0]; + dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension); + } else { + minor_dimension = dest_shape.layout().minor_to_major()[0]; + source_stride = + IndexUtil::GetDimensionStride(source_shape, minor_dimension); + } + minor_loop_size = dimensions[minor_dimension]; + step[minor_dimension] = minor_loop_size; + } +} + +/* static */ std::unique_ptr LiteralUtil::CreateFromShape( + const Shape& shape) { + auto literal = MakeUnique(); + *literal->mutable_shape() = shape; + Reserve(ShapeUtil::ElementsIn(literal->shape()), literal.get()); + return literal; +} + +/* static */ std::unique_ptr LiteralUtil::CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice dimensions) { + return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); +} + +template +/* static */ Status LiteralUtil::CopyRange( + const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, + Literal* dest_literal, tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + const Shape& src_shape = src_literal.shape(); + const Shape& dest_shape = dest_literal->shape(); + tensorflow::gtl::ArraySlice src_data = GetArraySlice(src_literal); + tensorflow::gtl::MutableArraySlice dest_data = + GetMutableArraySlice(dest_literal); + + TF_RET_CHECK(ShapeUtil::Rank(src_shape) == src_base.size()); + TF_RET_CHECK(ShapeUtil::Rank(dest_shape) == dest_base.size()); + if (ShapeUtil::Rank(src_shape) == 0 || ShapeUtil::Rank(dest_shape) == 0) { + // If any of the two shapes are scalars, we can just call the StridedCopy() + // directly, and we know we will be copying only one value. + TF_RET_CHECK(copy_size.empty()); + StridedCopy(dest_data, LinearIndex(*dest_literal, dest_base), 0, src_data, + LinearIndex(src_literal, src_base), 0, 1); + } else if (!ShapeUtil::HasZeroElements(dest_shape)) { + TF_RET_CHECK(!ShapeUtil::HasZeroElements(src_shape)); + TF_RET_CHECK(src_base.size() == dest_base.size()); + TF_RET_CHECK(src_base.size() == copy_size.size()); + + // Scan the source from minor, stepping in copy size blocks, then within + // the index enumaration functor, do a strided copy advancing source index + // by one (walking through the minor dimension), and destination index by + // proper stride size at the matching dimension. + DimensionVector src_indexes(src_base.size(), 0); + DimensionVector dest_indexes(dest_base.size(), 0); + StrideConfig stride_config(src_shape, dest_shape, copy_size); + + auto copy_proc = [&](const std::vector& indexes) { + // Map from multi-dimensional index, to source index. + std::transform(indexes.begin(), indexes.end(), src_base.begin(), + src_indexes.begin(), std::plus()); + // Map from multi-dimensional index, to destination index. + std::transform(indexes.begin(), indexes.end(), dest_base.begin(), + dest_indexes.begin(), std::plus()); + + int64 src_index = LinearIndex(src_literal, src_indexes); + int64 dest_index = LinearIndex(*dest_literal, dest_indexes); + + StridedCopy(dest_data, dest_index, stride_config.dest_stride, src_data, + src_index, stride_config.source_stride, + stride_config.minor_loop_size); + return true; + }; + + ShapeUtil::ForEachIndex(src_shape, stride_config.base, + stride_config.dimensions, stride_config.step, + copy_proc); + } + return Status::OK(); +} + +/* static */ Status LiteralUtil::Copy( + const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, + Literal* dest_literal, tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK( + ShapeUtil::SameElementType(src_literal.shape(), dest_literal->shape())); + switch (src_literal.shape().element_type()) { + case U32: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + case U64: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + case S32: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + case S64: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + case F16: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + case F32: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + case F64: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + case PRED: + return CopyRange(src_literal, src_base, dest_literal, dest_base, + copy_size); + default: + break; + } + return Unimplemented("Unhandled primitive type %d", + src_literal.shape().element_type()); +} + /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: @@ -47,6 +181,8 @@ namespace xla { return *LiteralUtil::CreateR0(0); case S64: return *LiteralUtil::CreateR0(0); + case F16: + return *LiteralUtil::CreateR0(static_cast(0.0f)); case F32: return *LiteralUtil::CreateR0(0); case F64: @@ -56,8 +192,6 @@ namespace xla { case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; - case F16: - LOG(FATAL) << "f16 literals not yet implemented"; case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 0"; case OPAQUE: @@ -91,7 +225,7 @@ namespace xla { case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *LiteralUtil::CreateR0(static_cast(1.0f)); case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 1"; case OPAQUE: @@ -127,7 +261,8 @@ namespace xla { case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *LiteralUtil::CreateR0( + static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: @@ -163,7 +298,8 @@ namespace xla { case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - LOG(FATAL) << "f16 literals not yet implemented"; + return *LiteralUtil::CreateR0( + static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -197,37 +333,16 @@ namespace xla { /* static */ std::unique_ptr LiteralUtil::Relayout( const Literal& original, const Layout& layout) { - // Note: if this were a performance bottleneck, we avoid cloning and just make - // an uninitialized array instead, since all values are clobbered below. std::unique_ptr result = CloneToUnique(original); *result->mutable_shape()->mutable_layout() = layout; - const PrimitiveType primitive_type = original.shape().element_type(); - switch (primitive_type) { - case F32: - LiteralUtil::EachCell( - original, - [&](tensorflow::gtl::ArraySlice indices, float value) { - LiteralUtil::Set(result.get(), indices, value); - }); - return result; - case S32: - LiteralUtil::EachCell( - original, - [&](tensorflow::gtl::ArraySlice indices, int32 value) { - LiteralUtil::Set(result.get(), indices, value); - }); - return result; - case U32: - LiteralUtil::EachCell( - original, - [&](tensorflow::gtl::ArraySlice indices, uint32 value) { - LiteralUtil::Set(result.get(), indices, value); - }); - return result; - default: - LOG(FATAL) << "not yet implemented: " - << PrimitiveType_Name(primitive_type); - } + + const Shape& shape = original.shape(); + DimensionVector base(ShapeUtil::Rank(shape), 0); + DimensionVector copy_size(shape.dimensions().begin(), + shape.dimensions().end()); + + TF_CHECK_OK(Copy(original, base, result.get(), base, copy_size)); + return result; } /* static */ StatusOr> LiteralUtil::Reshape( @@ -235,25 +350,19 @@ namespace xla { if (ShapeUtil::IsTuple(input.shape())) { return InvalidArgument("Reshape does not support tuples."); } - + std::unique_ptr output; if (!LayoutUtil::IsMonotonicWithDim0Major(input.shape().layout())) { - return Unimplemented( - "Input shape must have a monotonic layout where dimension 0 is major, " - "was: %s", - LayoutUtil::HumanString(input.shape().layout()).c_str()); + std::vector minor_to_major(ShapeUtil::Rank(input.shape())); + std::iota(minor_to_major.rbegin(), minor_to_major.rend(), + static_cast(0)); + output = Relayout(input, LayoutUtil::MakeLayout(minor_to_major)); + } else { + output = CloneToUnique(input); } - std::vector layout(dimensions.size()); - std::iota(layout.rbegin(), layout.rend(), 0); - // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. - std::unique_ptr output = CloneToUnique(input); - output->clear_shape(); - output->mutable_shape()->set_element_type(input.shape().element_type()); - for (int64 dimension : dimensions) { - output->mutable_shape()->add_dimensions(dimension); - } - *output->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(layout); + *output->mutable_shape() = + ShapeUtil::MakeShape(input.shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(input.shape()); int64 elements_after = ShapeUtil::ElementsIn(output->shape()); @@ -267,73 +376,42 @@ namespace xla { return std::move(output); } -namespace { - -template -void TransposeLiteralInternal(const Literal& original, - tensorflow::gtl::ArraySlice permutation, - Literal* result) { - std::vector new_indices(ShapeUtil::Rank(original.shape())); - LiteralUtil::EachCell( - original, [&](tensorflow::gtl::ArraySlice indices, T value) { - for (int64 i = 0; i < indices.size(); ++i) { - new_indices[i] = indices[permutation[i]]; - } - LiteralUtil::Set(result, new_indices, value); - }); -} -} // namespace - /* static */ std::unique_ptr LiteralUtil::Transpose( const Literal& original, tensorflow::gtl::ArraySlice permutation) { CHECK(!ShapeUtil::IsTuple(original.shape())) - << "tuple is not supported for transpose"; - std::vector dimension_numbers(ShapeUtil::Rank(original.shape())); - std::iota(dimension_numbers.begin(), dimension_numbers.end(), 0); - CHECK(std::is_permutation(permutation.begin(), permutation.end(), - dimension_numbers.begin())) - << "given permutation is not a permutation of dimension numbers"; - std::vector new_dimension_sizes; - for (const int64 dim : permutation) { - new_dimension_sizes.push_back(original.shape().dimensions(dim)); - } - const auto result_shape = ShapeUtil::MakeShape( - original.shape().element_type(), new_dimension_sizes); - std::unique_ptr result = CloneToUnique(original); - *result->mutable_shape() = result_shape; - const PrimitiveType primitive_type = original.shape().element_type(); - switch (primitive_type) { - case F32: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case F64: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case PRED: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case S8: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case U8: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case S32: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case U32: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case S64: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - case U64: - TransposeLiteralInternal(original, permutation, result.get()); - return result; - default: - LOG(FATAL) << "not yet implemented: " - << PrimitiveType_Name(primitive_type); + << "Tuple is not supported for transpose"; + CHECK(IsPermutation(permutation, ShapeUtil::Rank(original.shape()))) + << "Given permutation is not a permutation of dimension numbers"; + // To transpose the array, we just permute the dimensions and layout, and + // do a straight memory copy of the raw data set. + // This is considerably faster than iterating over every array element using + // the EachCell<>() and Set<>() APIs. + std::vector inverse_permutation = InversePermutation(permutation); + Shape shape = + ShapeUtil::PermuteDimensions(inverse_permutation, original.shape()); + // Replace the layout with one affine to the original shape, such that a + // transpose operation can be performed by leaving the flat values + // representation intact. + // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation. + // The shape with affine layout resulting from that operation will be + // F32[8,11]{0,1}, since it leave the original most minor (the 8 sized), the + // most minor. + // Essentially, given MinMaj(Di) the position of the Di dimension within the + // minor to major vector, and given T(Di) the index that the original Di + // dimension has within the transposed array, a layout is affine if + // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major + // vector of the affine layout. + Layout* layout = shape.mutable_layout(); + layout->clear_minor_to_major(); + for (auto index : original.shape().layout().minor_to_major()) { + layout->add_minor_to_major(inverse_permutation[index]); } + std::unique_ptr new_literal = CreateFromShape(shape); + DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), + ShapeUtil::ByteSizeOf(original.shape())); + std::memcpy(MutableInternalData(new_literal.get()), InternalData(original), + ShapeUtil::ByteSizeOf(original.shape())); + return new_literal; } /* static */ std::unique_ptr LiteralUtil::Slice( @@ -342,7 +420,7 @@ void TransposeLiteralInternal(const Literal& original, CHECK(!ShapeUtil::IsTuple(literal.shape())) << "tuple is not supported for reshape"; - std::vector result_dimensions; + DimensionVector result_dimensions; for (int64 dnum = 0; dnum < ShapeUtil::Rank(literal.shape()); ++dnum) { CHECK_GE(start_indices[dnum], 0); CHECK_LE(limit_indices[dnum], literal.shape().dimensions(dnum)); @@ -358,7 +436,7 @@ void TransposeLiteralInternal(const Literal& original, *result_literal->mutable_shape() = result_shape; Reserve(ShapeUtil::ElementsIn(result_shape), result_literal.get()); - std::vector new_indices(ShapeUtil::Rank(result_shape)); + DimensionVector new_indices(ShapeUtil::Rank(result_shape)); switch (result_shape.element_type()) { case F32: LiteralUtil::EachCell( @@ -425,6 +503,8 @@ void TransposeLiteralInternal(const Literal& original, return tensorflow::strings::StrCat(Get(literal, multi_index)); case F64: return tensorflow::strings::StrCat(Get(literal, multi_index)); + case F16: + return tensorflow::strings::StrCat(Get(literal, multi_index)); default: return tensorflow::strings::StrCat( "[", PrimitiveType_Name(literal.shape().element_type()), "]"); @@ -579,6 +659,8 @@ void TransposeLiteralInternal(const Literal& original, return reinterpret_cast(literal.f32s().data()); case F64: return reinterpret_cast(literal.f64s().data()); + case F16: + return reinterpret_cast(literal.f16s().data()); default: LOG(FATAL) << "primitive type not supported in literals: " << PrimitiveType_Name(literal.shape().element_type()); @@ -593,38 +675,34 @@ void TransposeLiteralInternal(const Literal& original, CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); switch (literal->shape().element_type()) { case PRED: - GetMutableRepeatedField(literal)->Resize(num_elements, false); + Resize(num_elements, false, literal); + break; + case S8: + Resize(num_elements, 0, literal); break; case U8: - // u8s is an optional "bytes", rather than a repeated field. Therefore its - // access methods are somewhat different from the others. - literal->mutable_u8s()->resize(num_elements, 0); + Resize(num_elements, 0, literal); break; case S32: - GetMutableRepeatedField(literal)->Resize(num_elements, - /*value=*/0); + Resize(num_elements, 0, literal); break; case S64: - GetMutableRepeatedField(literal)->Resize( - num_elements, - /*value=*/0); + Resize(num_elements, 0, literal); break; case U32: - GetMutableRepeatedField(literal)->Resize(num_elements, - /*value=*/0); + Resize(num_elements, 0, literal); break; case U64: - GetMutableRepeatedField(literal)->Resize( - num_elements, - /*value=*/0); + Resize(num_elements, 0, literal); break; case F32: - GetMutableRepeatedField(literal)->Resize(num_elements, - /*value=*/0.0f); + Resize(num_elements, 0, literal); break; case F64: - GetMutableRepeatedField(literal)->Resize(num_elements, - /*value=*/0.0); + Resize(num_elements, 0, literal); + break; + case F16: + Resize(num_elements, static_cast(0.0f), literal); break; default: LOG(FATAL) << "primitive type not supported in literals: " @@ -662,6 +740,9 @@ void TransposeLiteralInternal(const Literal& original, case F64: actual = literal.f64s_size(); break; + case F16: + actual = literal.f16s().size() / sizeof(half); + break; default: return tensorflow::errors::Unimplemented( "unhandled element type for literal validation: " + @@ -680,50 +761,16 @@ void TransposeLiteralInternal(const Literal& original, /* static */ void LiteralUtil::EachCellAsString( const Literal& literal, - std::function indices, - const string& value)> - per_cell) { - if (ShapeUtil::Rank(literal.shape()) == 1) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - per_cell({i0}, GetAsString(literal, {i0})); - } - return; - } - - if (ShapeUtil::Rank(literal.shape()) == 2) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) { - per_cell({i0, i1}, GetAsString(literal, {i0, i1})); - } - } + const std::function indices, + const string& value)>& per_cell) { + if (ShapeUtil::HasZeroElements(literal.shape())) { return; } - - if (ShapeUtil::Rank(literal.shape()) == 3) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) { - for (int64 i2 = 0; i2 < literal.shape().dimensions(2); ++i2) { - per_cell({i0, i1, i2}, GetAsString(literal, {i0, i1, i2})); - } - } - } - return; - } - - if (ShapeUtil::Rank(literal.shape()) == 4) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) { - for (int64 i2 = 0; i2 < literal.shape().dimensions(2); ++i2) { - for (int64 i3 = 0; i3 < literal.shape().dimensions(3); ++i3) { - per_cell({i0, i1, i2, i3}, GetAsString(literal, {i0, i1, i2, i3})); - } - } - } - } - return; - } - - LOG(FATAL) << "unhandled rank: " << ShapeUtil::Rank(literal.shape()); + std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( + literal.shape(), /*linear_index=*/0); + do { + per_cell(indices, GetAsString(literal, indices)); + } while (IndexUtil::BumpIndices(literal.shape(), &indices)); } namespace { @@ -786,6 +833,8 @@ bool EqualElements(const Literal& literal1, const Literal& literal2, return EqualElements(literal1, literal2, 0, &multi_index); case F64: return EqualElements(literal1, literal2, 0, &multi_index); + case F16: + return EqualElements(literal1, literal2, 0, &multi_index); default: LOG(FATAL) << "Unimplemented: LiteralUtil::Equal for type " << PrimitiveType_Name(literal1.shape().element_type()); @@ -794,96 +843,175 @@ bool EqualElements(const Literal& literal1, const Literal& literal2, } template <> -/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( - const Literal& literal) { - CHECK(literal.shape().element_type() == PRED); - return literal.preds(); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + auto values = literal->mutable_preds(); + return tensorflow::gtl::MutableArraySlice(values->mutable_data(), + values->size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal) { - CHECK(literal->shape().element_type() == PRED); - return literal->mutable_preds(); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + // C++11 standard, basic_string 21.4.1.5, values should be stored + // contiguously. From C++17 a mutable data() member will be provided. + auto values = literal->mutable_u8s(); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(&(*values)[0]), values->size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == U32); - return literal.u32s(); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + // C++11 standard, basic_string 21.4.1.5, values should be stored + // contiguously. From C++17 a mutable data() member will be provided. + auto values = literal->mutable_u8s(); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(&(*values)[0]), values->size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal) { - CHECK(literal->shape().element_type() == U32); - return literal->mutable_u32s(); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + auto values = literal->mutable_s32s(); + return tensorflow::gtl::MutableArraySlice(values->mutable_data(), + values->size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == U64); - return AsUInt64Slice(literal.u64s()); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + auto values = literal->mutable_u32s(); + return tensorflow::gtl::MutableArraySlice(values->mutable_data(), + values->size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField( - Literal* literal) { - CHECK(literal->shape().element_type() == U64); - return literal->mutable_u64s(); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + static_assert(sizeof(int64) == sizeof(tensorflow::protobuf_int64) && + alignof(int64) == alignof(tensorflow::protobuf_int64), + "The int64 and tensorflow::protobuf_int64 types are not " + "compatible"); + auto values = literal->mutable_s64s(); + // Because of the fact that tensorflow::protobuf_int64 is defined as int64_t + // while tensorflow::int64 is defined as long long, a reinterpret_cast<> is + // necessary from the raw data pointer returned by the mutable_data() API. + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(values->mutable_data()), values->size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == S32); - return literal.s32s(); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + static_assert(sizeof(uint64) == sizeof(tensorflow::protobuf_uint64) && + alignof(uint64) == alignof(tensorflow::protobuf_uint64), + "The uint64 and tensorflow::protobuf_uint64 types are not " + "compatible"); + auto values = literal->mutable_u64s(); + // Because of the fact that tensorflow::protobuf_uint64 is defined as uint64_t + // while tensorflow::uint64 is defined as unsigned long long, a + // reinterpret_cast<> is necessary from the raw data pointer returned by the + // mutable_data() API. + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(values->mutable_data()), values->size()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal) { - CHECK(literal->shape().element_type() == S32); - return literal->mutable_s32s(); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + auto values = literal->mutable_f32s(); + return tensorflow::gtl::MutableArraySlice(values->mutable_data(), + values->size()); } template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == S64); - return AsInt64Slice(literal.s64s()); +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + auto values = literal->mutable_f64s(); + return tensorflow::gtl::MutableArraySlice(values->mutable_data(), + values->size()); +} + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal) { + // C++11 standard, basic_string 21.4.1.5, values should be stored + // contiguously. From C++17 a mutable data() member will be provided. + // TODO - there is an endianess problem here. fix it, or wait for uint16 + // support in protobuf + auto values = literal->mutable_f16s(); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(&(*values)[0]), values->size() / sizeof(half)); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( + const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), PRED); + return literal.preds(); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), U8); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(literal.u8s().data()), + literal.u8s().size()); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( + const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), S8); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(literal.u8s().data()), + literal.u8s().size()); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), U32); + return literal.u32s(); +} + +template <> +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), U64); + return AsUInt64Slice(literal.u64s()); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField( - Literal* literal) { - CHECK(literal->shape().element_type() == S64); - return literal->mutable_s64s(); +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), S32); + return literal.s32s(); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal) { - CHECK(literal->shape().element_type() == F32); - return literal->mutable_f32s(); +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), S64); + return AsInt64Slice(literal.s64s()); } template <> /* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice(const Literal& literal) { - CHECK(literal.shape().element_type() == F64); + CHECK_EQ(literal.shape().element_type(), F64); return literal.f64s(); } template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal) { - CHECK(literal->shape().element_type() == F64); - return literal->mutable_f64s(); +/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( + const Literal& literal) { + CHECK_EQ(literal.shape().element_type(), F16); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(literal.f16s().data()), + literal.f16s().size() / sizeof(half)); } template @@ -925,6 +1053,8 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { return AllElementsEqualValue(literal, value); case F64: return AllElementsEqualValue(literal, value); + case F16: + return AllElementsEqualValue(literal, static_cast(value)); case PRED: if (value == 0) { return AllElementsEqualValue(literal, false); @@ -944,6 +1074,8 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { return AllElementsEqualValue(literal, value); case F64: return AllElementsEqualValue(literal, value); + case F16: + return AllElementsEqualValue(literal, static_cast(value)); default: return false; } @@ -968,6 +1100,8 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { return Get(literal, indices) == 0.0f; case F64: return Get(literal, indices) == 0.0; + case F16: + return Get(literal, indices) == static_cast(0.0f); case PRED: return Get(literal, indices) == false; default: @@ -976,51 +1110,77 @@ static bool AllElementsEqualValue(const Literal& literal, NativeT value) { } template <> -/* static */ void LiteralUtil::PopulateWithValue( - int64 value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), dimensions); - tensorflow::protobuf::RepeatedField* - repeated_field = - GetMutableRepeatedField(literal); - for (int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) { - repeated_field->Add(value); - } +/* static */ void LiteralUtil::Resize(int64 num_elements, bool value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_preds()->Resize(num_elements, value); } template <> -/* static */ void LiteralUtil::PopulateWithValue( - uint64 value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal) { - *literal->mutable_shape() = ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), dimensions); - tensorflow::protobuf::RepeatedField* - repeated_field = - GetMutableRepeatedField(literal); - for (int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) { - repeated_field->Add(value); - } +/* static */ void LiteralUtil::Resize(int64 num_elements, int8 value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_u8s()->resize(num_elements, value); +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint8 value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_u8s()->resize(num_elements, value); +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, int32 value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_s32s()->Resize(num_elements, value); +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint32 value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_u32s()->Resize(num_elements, value); +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_s64s()->Resize(num_elements, value); +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_u64s()->Resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, - Literal* literal) { +/* static */ void LiteralUtil::Resize(int64 num_elements, float value, + Literal* literal) { CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - tensorflow::protobuf::RepeatedField* - repeated_field = - GetMutableRepeatedField(literal); - repeated_field->Resize(num_elements, value); + literal->mutable_f32s()->Resize(num_elements, value); } template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, - Literal* literal) { +/* static */ void LiteralUtil::Resize(int64 num_elements, double value, + Literal* literal) { CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - tensorflow::protobuf::RepeatedField* - repeated_field = - GetMutableRepeatedField(literal); - repeated_field->Resize(num_elements, value); + literal->mutable_f64s()->Resize(num_elements, value); +} + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, half value, + Literal* literal) { + CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); + literal->mutable_f16s()->resize(num_elements * sizeof(half)); + auto data = GetMutableArraySlice(literal); + for (int i = 0; i < num_elements; i++) { + data[i] = value; + } } } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 21bb2e46cf2ebcd72bcce393a1e5526f41757544..8e06f35b33d132ba92ce6309db916940362e5a7b 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -100,9 +101,34 @@ class LiteralUtil { values, const Layout& layout); + // Create a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromShape(const Shape& shape); + + // Create a new Literal object with its values havings the primitive_type + // type, and with dimensions defined by the dimensions parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice dimensions); + + // Copies the values from src_literal, starting at src_base shape indexes, + // to dest_literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // The src_literal and dest_literal must have the same primitive type, + // src_base+copy_size must fit the source literal dimensions, as well as + // dest_base+copy_size must fit the destination literal dimensions. + static Status Copy(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + Literal* dest_literal, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + // Creates a new value that has the equivalent value as literal, but conforms // to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major - // dimension layout can be re-layed-out as {1, 0} minor-to-major dimension + // dimension layout can be re-laid-out as {1, 0} minor-to-major dimension // layout and the value in the cell at any given logical index (i0, i1) will // be the same. // @@ -213,6 +239,11 @@ class LiteralUtil { // Clones literal into an owned unique_ptr version. static std::unique_ptr CloneToUnique(const Literal& literal); + // Returns the linear index of the given index within the literal's + // element_type repeated field. + static int64 LinearIndex(const Literal& literal, + tensorflow::gtl::ArraySlice multi_index); + // Gets or sets an element in the literal at the given index. The index is // CHECKed against the dimension sizes. template @@ -223,6 +254,12 @@ class LiteralUtil { tensorflow::gtl::ArraySlice multi_index, NativeT value); + // Retrieves the mutable array slice interface which can be used to manipulate + // pre-allocated literal values. + template + static tensorflow::gtl::MutableArraySlice GetMutableArraySlice( + Literal* literal); + // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. template @@ -257,9 +294,8 @@ class LiteralUtil { // like representation in a protobuf). static void EachCellAsString( const Literal& literal, - std::function indices, - const string& value)> - per_cell); + const std::function indices, + const string& value)>& per_cell); template static void EachCell( const Literal& literal, @@ -315,6 +351,14 @@ class LiteralUtil { const Layout& layout, Literal* literal); + // Populates literal values by calling the generator function for every cell + // in the literal object. + template + static Status Populate( + Literal* literal, + const std::function indexes)>& + generator); + // Creates a Literal of the given dimensions with all elements set to the // given value. template @@ -383,70 +427,73 @@ class LiteralUtil { static_assert(!std::is_same::value, "Cannot map native type to primitive type."); } - template - static tensorflow::protobuf::RepeatedField* GetMutableRepeatedField( - Literal* literal) { - // Make the expression depend on the template parameter NativeT so - // that this compile-time error only apperas if this function is - // instantiated with some concrete type that is not specialized - // below. - static_assert(!std::is_same::value, - "Cannot map native type to primitive type."); - } - // Returns the linear index of the given index within the literal's - // element_type repeated field. - static int64 LinearIndex(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index); + // Internal template helper for the Copy() API, matching its arguments one by + // one. + template + static Status CopyRange(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + Literal* dest_literal, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + + // Utility structure which is used to create the optimal configuration for + // a ShapeUtil::ForEachIndex() scan across two literals. + struct StrideConfig { + StrideConfig(const Shape& source_shape, const Shape& dest_shape, + tensorflow::gtl::ArraySlice dimensions); + + // The dimensions of the stride operation. Essentially every dimension + // will be iterated from base[i] to base[i]+dimensions[i], in step[i] + // steps. + tensorflow::gtl::ArraySlice dimensions; + DimensionVector base; + DimensionVector step; + int64 minor_dimension = 0; + // The size of the strides for source and destination. One of the two + // (the one looping through its most minor dimension) will be 1, while + // the other will be the stride size at the dimension matching the other + // shape most minor dimension being scanned. + int64 dest_stride = 1; + int64 source_stride = 1; + // The size of the inner loop on the most minor dimension. + int64 minor_loop_size = 1; + }; TF_DISALLOW_COPY_AND_ASSIGN(LiteralUtil); }; // Declarations of template specializations for GetArraySlice and -// GetMutableRepeatedField. The specializations map native type to XLA primitive +// GetMutableArraySlice. The specializations map native type to XLA primitive // type. template <> /* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( const Literal& literal); template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal); template <> -/* static */ tensorflow::gtl::ArraySlice -LiteralUtil::GetArraySlice(const Literal& literal); +/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( + const Literal& literal); template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); +/* static */ tensorflow::gtl::ArraySlice +LiteralUtil::GetArraySlice(const Literal& literal); template <> /* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice(const Literal& literal); -template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField( - Literal* literal); - template <> /* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice(const Literal& literal); -template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); - template <> /* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice(const Literal& literal); -template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField( - Literal* literal); - template <> /* static */ inline tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice(const Literal& literal) { @@ -454,22 +501,98 @@ LiteralUtil::GetArraySlice(const Literal& literal) { return literal.f32s(); } -template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); - template <> /* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice(const Literal& literal); template <> -/* static */ tensorflow::protobuf::RepeatedField* -LiteralUtil::GetMutableRepeatedField(Literal* literal); +/* static */ tensorflow::gtl::ArraySlice LiteralUtil::GetArraySlice( + const Literal& literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ tensorflow::gtl::MutableArraySlice +LiteralUtil::GetMutableArraySlice(Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, bool value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, int8 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint8 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, int32 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint32 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, float value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, double value, + Literal* literal); + +template <> +/* static */ void LiteralUtil::Resize(int64 num_elements, half value, + Literal* literal); template /* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { auto literal = MakeUnique(); - PopulateR0(value, literal.get()); + PopulateR0(value, literal.get()); return literal; } @@ -695,12 +818,20 @@ template <> return literal.u8s()[linear_index]; } +template <> +/* static */ inline half LiteralUtil::Get( + const Literal& literal, tensorflow::gtl::ArraySlice multi_index) { + CHECK(literal.shape().element_type() == F16); + int64 linear_index = LinearIndex(literal, multi_index); + return GetArraySlice(literal)[linear_index]; +} + template /* static */ void LiteralUtil::Set( Literal* literal, tensorflow::gtl::ArraySlice multi_index, NativeT value) { int64 linear_index = LinearIndex(*literal, multi_index); - GetMutableRepeatedField(literal)->Set(linear_index, value); + GetMutableArraySlice(literal).at(linear_index) = value; } template <> @@ -760,44 +891,11 @@ template } template -/* static */ void LiteralUtil::PopulateR0(NativeT value, Literal* literal) { +/* static */ inline void LiteralUtil::PopulateR0(NativeT value, + Literal* literal) { *literal->mutable_shape() = ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {}); - tensorflow::protobuf::RepeatedField* repeated_field = - GetMutableRepeatedField(literal); - repeated_field->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(uint8 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_u8s()->push_back(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(int8 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_u8s()->push_back(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(uint64 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_u64s()->Add(value); -} - -template <> -/* static */ inline void LiteralUtil::PopulateR0(int64 value, - Literal* literal) { - *literal->mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {}); - literal->mutable_s64s()->Add(value); + Resize(1, value, literal); } template @@ -944,65 +1042,72 @@ template literal); } +template +/* static */ Status LiteralUtil::Populate( + Literal* literal, + const std::function indexes)>& + generator) { + const Shape& shape = literal->shape(); + int64 rank = ShapeUtil::Rank(shape); + TF_RET_CHECK(shape.element_type() == + primitive_util::NativeToPrimitiveType()); + tensorflow::gtl::MutableArraySlice data = + GetMutableArraySlice(literal); + if (rank > 0) { + StrideConfig stride_config(shape, shape, AsInt64Slice(shape.dimensions())); + DimensionVector minor_scan_indexes(rank, 0); + int64 minor_dimension_size = + ShapeUtil::GetDimension(shape, stride_config.minor_dimension); + + auto init_function = [&](const std::vector& indexes) { + int64 index = LinearIndex(*literal, indexes); + std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); + for (int64 i = 0; i < minor_dimension_size; ++i) { + minor_scan_indexes[stride_config.minor_dimension] = i; + data.at(index + i) = generator(minor_scan_indexes); + } + return true; + }; + ShapeUtil::ForEachIndex(shape, stride_config.base, stride_config.dimensions, + stride_config.step, init_function); + } else { + data.at(0) = generator({}); + } + return Status::OK(); +} + template /* static */ void LiteralUtil::PopulateWithValue( NativeT value, tensorflow::gtl::ArraySlice dimensions, Literal* literal) { *literal->mutable_shape() = ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), dimensions); - tensorflow::protobuf::RepeatedField* repeated_field = - GetMutableRepeatedField(literal); - for (int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) { - repeated_field->Add(value); - } + Resize(ShapeUtil::ElementsIn(literal->shape()), value, literal); } -template <> -/* static */ void LiteralUtil::PopulateWithValue( - int64 value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal); - -template <> -/* static */ void LiteralUtil::PopulateWithValue( - uint64 value, tensorflow::gtl::ArraySlice dimensions, - Literal* literal); - template /* static */ std::unique_ptr LiteralUtil::Convert( const Literal& literal) { + const Shape& shape = literal.shape(); auto result_literal = MakeUnique(); - Shape result_shape = literal.shape(); - result_shape.set_element_type( + Shape* result_shape = result_literal->mutable_shape(); + *result_shape = shape; + result_shape->set_element_type( primitive_util::NativeToPrimitiveType()); - *result_literal->mutable_shape() = result_shape; - LiteralUtil::Reserve(ShapeUtil::ElementsIn(result_shape), + LiteralUtil::Reserve(ShapeUtil::ElementsIn(*result_shape), result_literal.get()); - LiteralUtil::EachCell( - literal, - [&](tensorflow::gtl::ArraySlice indices, NativeSrcT value) { - LiteralUtil::Set(result_literal.get(), indices, - static_cast(value)); - }); + tensorflow::gtl::ArraySlice src_data = + GetArraySlice(literal); + tensorflow::gtl::MutableArraySlice dest_data = + GetMutableArraySlice(result_literal.get()); + int64 num_elements = ShapeUtil::ElementsIn(shape); + + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = static_cast(src_data[i]); + } return result_literal; } -template -/* static */ void LiteralUtil::Resize(int64 num_elements, NativeT value, - Literal* literal) { - CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements); - tensorflow::protobuf::RepeatedField* repeated_field = - GetMutableRepeatedField(literal); - repeated_field->Resize(num_elements, value); -} - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value, - Literal* literal); - -template <> -/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value, - Literal* literal); - template /* static */ std::unique_ptr LiteralUtil::CreateFullWithMonotonicDim0MajorLayout( @@ -1022,10 +1127,7 @@ LiteralUtil::CreateFullWithMonotonicDim0MajorLayout( template /* static */ std::unique_ptr LiteralUtil::Replicate( const Literal& input, int64 times) { - // Ranks greater than 8 are very rare, so use InlinedVector to store - // the bounds and indices. - static constexpr int kInlineRank = 8; - tensorflow::gtl::InlinedVector bounds = {times}; + DimensionVector bounds = {times}; bounds.reserve(input.shape().dimensions_size() + 1); for (int64 bound : input.shape().dimensions()) { bounds.push_back(bound); @@ -1039,8 +1141,7 @@ template } Reserve(elements, literal.get()); - tensorflow::gtl::InlinedVector output_indices( - bounds.size(), 0); + DimensionVector output_indices(bounds.size(), 0); tensorflow::gtl::ArraySlice input_indices = output_indices; input_indices.remove_prefix(1); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index e53763376bfe58b7c5a811987161cac966d14222..9a09822174d9c93c8195af193f34017268bbc503 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -21,14 +21,18 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + class LiteralUtilTest : public ::testing::Test { protected: LiteralUtilTest() { @@ -101,6 +105,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f32_lit = LiteralUtil::CreateR0(3.14f); ASSERT_EQ("3.14", LiteralUtil::ToString(*f32_lit)); + + auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); + ASSERT_EQ("0.5", LiteralUtil::ToString(*f16_lit)); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -159,9 +166,7 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { // clang-format on auto literal = LiteralUtil::CreateR3FromArray3D(array_3d); - EXPECT_MATCH(testing::PBToVec( - literal->shape().dimensions()), - testing::VectorMatcher({2, 3, 2})); + EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2)); string result = LiteralUtil::ToString(*literal); const string expected = R"(f32[2,3,2] { { { 1, 2 }, @@ -182,9 +187,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { {2001, 2002}, }, /*projection_p=*/1, /*projection_z=*/2); // clang-format on - EXPECT_MATCH( - testing::PBToVec(literal->shape().dimensions()), - testing::VectorMatcher({1, 2, 3, 2})); + EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2)); string result = LiteralUtil::ToString(*literal); const string expected = R"(f32[1,2,3,2] { { // i0=0 @@ -204,10 +207,8 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { } TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { - EXPECT_MATCH( - testing::PBToVec( - literal_r4_2x2x3x3_dim0major_->shape().dimensions()), - testing::VectorMatcher({2, 2, 3, 3})); + EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(), + ElementsAre(2, 2, 3, 3)); string result = LiteralUtil::ToString(*literal_r4_2x2x3x3_dim0major_); const string expected = R"(f32[2,2,3,3] { { // i0=0 @@ -375,6 +376,15 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE( LiteralUtil::IsAll(*LiteralUtil::CreateR2({{9, 8}, {8, 8}}), 8)); + half h8(8.0f); + half h9(9.0f); + EXPECT_TRUE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h8}, {h8}}), 8)); + EXPECT_FALSE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h8}, {h9}}), 8)); + EXPECT_FALSE( + LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h9}, {h8}}), 8)); + auto uint64_max = std::numeric_limits::max(); EXPECT_FALSE(LiteralUtil::IsAll( *LiteralUtil::CreateR2( @@ -471,6 +481,26 @@ TEST_F(LiteralUtilTest, ReshapeR4) { EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape)); } +TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { + // clang-format off + // F32[1x3x2x4] + auto original = LiteralUtil::CreateR4WithLayout({{ + {{10, 11, 12, 13}, {14, 15, 16, 17}}, + {{18, 19, 20, 21}, {22, 23, 24, 25}}, + {{26, 27, 28, 29}, {30, 31, 32, 33}}, + }}, layout_r4_dim0minor_); + // F32[1x3x4x2] + auto expected = LiteralUtil::CreateR3WithLayout({ + {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, + {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, + {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, + }, layout_r3_dim0major_); + // clang-format on + auto reshape = LiteralUtil::Reshape(*original, {3, 4, 2}).ConsumeValueOrDie(); + + EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape)); +} + TEST_F(LiteralUtilTest, TransposeR0) { auto original = LiteralUtil::CreateR0(1.7f); auto reshape = LiteralUtil::Transpose(*original, /*permutation=*/{}); @@ -516,27 +546,23 @@ TEST_F(LiteralUtilTest, TestR2LinearLayout) { auto mat_dim0minor = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); EXPECT_EQ(mat_dim0minor->s32s_size(), 6); - EXPECT_MATCH(testing::PBToVec(mat_dim0minor->s32s()), - testing::VectorMatcher({1, 4, 2, 5, 3, 6})); + EXPECT_THAT(mat_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6)); // Test expected memory layout when using Relayout to row major. auto relaid_mat_to_dim0major = LiteralUtil::Relayout(*mat_dim0minor, layout_r2_dim0major_); - EXPECT_MATCH(testing::PBToVec(relaid_mat_to_dim0major->s32s()), - testing::VectorMatcher({1, 2, 3, 4, 5, 6})); + EXPECT_THAT(relaid_mat_to_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). auto mat_dim0major = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); EXPECT_EQ(mat_dim0major->s32s_size(), 6); - EXPECT_MATCH(testing::PBToVec(mat_dim0major->s32s()), - testing::VectorMatcher({1, 2, 3, 4, 5, 6})); + EXPECT_THAT(mat_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout when using Relayout to column major. auto relaid_mat_to_dim0minor = LiteralUtil::Relayout(*mat_dim0major, layout_r2_dim0minor_); - EXPECT_MATCH(testing::PBToVec(relaid_mat_to_dim0minor->s32s()), - testing::VectorMatcher({1, 4, 2, 5, 3, 6})); + EXPECT_THAT(relaid_mat_to_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6)); } TEST_F(LiteralUtilTest, TestR3LinearLayout) { @@ -558,28 +584,28 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { EXPECT_EQ(lit_dim0minor->s32s_size(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; - EXPECT_MATCH(testing::PBToVec(lit_dim0minor->s32s()), - testing::VectorMatcher(expected_dim0minor)); + EXPECT_THAT(lit_dim0minor->s32s(), + testing::ElementsAreArray(expected_dim0minor)); // Test expected memory layout when using Relayout to row major. auto relaid_lit_to_dim0major = LiteralUtil::Relayout(*lit_dim0minor, layout_r3_dim0major_); std::vector expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - EXPECT_MATCH(testing::PBToVec(relaid_lit_to_dim0major->s32s()), - testing::VectorMatcher(expected_dim0major)); + EXPECT_THAT(relaid_lit_to_dim0major->s32s(), + testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout( arr3d, layout_r3_dim0major_); EXPECT_EQ(lit_dim0major->s32s_size(), 12); - EXPECT_MATCH(testing::PBToVec(lit_dim0major->s32s()), - testing::VectorMatcher(expected_dim0major)); + EXPECT_THAT(lit_dim0major->s32s(), + testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout when using Relayout to column major. auto relaid_lit_to_dim0minor = LiteralUtil::Relayout(*lit_dim0major, layout_r3_dim0minor_); - EXPECT_MATCH(testing::PBToVec(relaid_lit_to_dim0minor->s32s()), - testing::VectorMatcher(expected_dim0minor)); + EXPECT_THAT(relaid_lit_to_dim0minor->s32s(), + testing::ElementsAreArray(expected_dim0minor)); } TEST_F(LiteralUtilTest, SliceR0S32) { @@ -645,6 +671,30 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); } +TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { + Literal output; + half h(0.25f); + LiteralUtil::PopulateWithValue(h, {}, &output); + auto expected = LiteralUtil::CreateR0(h); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { + Literal output; + half h(0.5f); + LiteralUtil::PopulateWithValue(h, {3}, &output); + auto expected = LiteralUtil::CreateR1({h, h, h}); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { + Literal output; + half h(2.0f); + LiteralUtil::PopulateWithValue(h, {2, 2}, &output); + auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); + EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); +} + TEST_F(LiteralUtilTest, ReplicateR2U32) { auto input = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); @@ -656,5 +706,156 @@ TEST_F(LiteralUtilTest, ReplicateR2U32) { EXPECT_TRUE(LiteralUtil::Equal(*output, *expected)); } +TEST_F(LiteralUtilTest, Copy) { + const int64 dimensions[] = {17, 15, 34, 21}; + const int64 layouts[][4] = { + {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}}; + for (const auto& layout : layouts) { + Shape shape = ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), dimensions, layout); + auto blank = LiteralUtil::CreateFromShape(shape); + auto source = LiteralUtil::CreateFromShape(shape); + const int64 zero_base[] = {0, 0, 0, 0}; + const int64 step[] = {1, 1, 1, 1}; + uint32 seqnr = 0; + auto init_proc = [&](const std::vector& indexes) { + LiteralUtil::Set(source.get(), indexes, ++seqnr); + return true; + }; + + ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step, + init_proc); + + const int64 src_base[] = {3, 1, 5, 7}; + const int64 dest_base[] = {6, 4, 12, 2}; + const int64 copy_size[] = {7, 8, 11, 9}; + + TF_EXPECT_OK(LiteralUtil::Copy(*source, src_base, blank.get(), dest_base, + copy_size)); + std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); + std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); + bool matched = true; + auto check_proc = [&](const std::vector& indexes) { + std::copy(indexes.begin(), indexes.end(), source_indexes.begin()); + std::transform(source_indexes.begin(), source_indexes.end(), src_base, + source_indexes.begin(), std::plus()); + std::copy(indexes.begin(), indexes.end(), blank_indexes.begin()); + std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base, + blank_indexes.begin(), std::plus()); + auto bval = LiteralUtil::Get(*blank, blank_indexes); + matched = (bval != 0 && + bval == LiteralUtil::Get(*source, source_indexes)); + return matched; + }; + ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step, + check_proc); + EXPECT_TRUE(matched); + } +} + +TEST_F(LiteralUtilTest, CopyScalars) { + auto zero = LiteralUtil::CreateR0(0); + auto nine = LiteralUtil::CreateR0(9); + TF_EXPECT_OK(LiteralUtil::Copy(*nine, {}, zero.get(), {}, {})); + EXPECT_TRUE(LiteralUtil::Equal(*zero, *nine)); + + auto vect = LiteralUtil::CreateR1({3, 4, 9, 12, 5, 17, 21}); + TF_EXPECT_OK(LiteralUtil::Copy(*vect, {5}, zero.get(), {}, {})); + EXPECT_EQ(LiteralUtil::Get(*zero, {}), 17); + TF_EXPECT_OK(LiteralUtil::Copy(*zero, {}, vect.get(), {4}, {})); + EXPECT_EQ(LiteralUtil::Get(*vect, {4}), 17); +} + +TEST_F(LiteralUtilTest, F16) { + // Verify that the internal data views are consistent and that they + // are in little endian format + // TODO - modify if we make the data format machine endianess dependent + auto m1 = LiteralUtil::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); + Literal* l1 = m1.get(); + const char* d1 = (const char*)LiteralUtil::InternalData(*l1); + EXPECT_EQ(d1[0], 0); + EXPECT_EQ(d1[1], 0); + EXPECT_EQ(d1[2], 0); + EXPECT_EQ(d1[3], 0); + EXPECT_EQ(d1[4], 0); + EXPECT_EQ(d1[5], 0); + EXPECT_EQ(d1[6], 0); + EXPECT_EQ(d1[7], 0); + EXPECT_EQ(LiteralUtil::InternalData(*l1), + LiteralUtil::MutableInternalData(l1)); + + half h1(1.0f); + half h2(2.0f); + auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); + Literal* l2 = m2.get(); + const char* d2 = (const char*)LiteralUtil::InternalData(*l2); + EXPECT_EQ(d2[0], 0); + EXPECT_EQ(d2[1], 0x3C); + EXPECT_EQ(d2[2], 0); + EXPECT_EQ(d2[3], 0x40); + EXPECT_EQ(d2[4], 0); + EXPECT_EQ(d2[5], 0x40); + EXPECT_EQ(d2[6], 0); + EXPECT_EQ(d2[7], 0x3C); + EXPECT_EQ(LiteralUtil::InternalData(*l2), + LiteralUtil::MutableInternalData(l2)); +} + +TEST_F(LiteralUtilTest, Populate) { + struct PopulateData { + std::vector dimensions; + std::vector layout; + } populate_data[] = { + {{}, {}}, + {{16}, {0}}, + {{4, 16}, {1, 0}}, + {{21, 12}, {0, 1}}, + {{6, 11, 17}, {2, 0, 1}}, + {{6, 11, 5, 17}, {3, 2, 0, 1}}, + }; + for (const auto& data : populate_data) { + Shape shape = ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), data.dimensions, + data.layout); + auto literal = LiteralUtil::CreateFromShape(shape); + auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> uint32 { + // Offsets from linear index just to avoid R0 literals to be initialized + // with zero. + return LiteralUtil::LinearIndex(*literal, indexes) + 17; + }; + TF_EXPECT_OK(LiteralUtil::Populate(literal.get(), generator)); + + std::vector zero_base(data.dimensions.size(), 0); + std::vector step(data.dimensions.size(), 1); + bool matched = true; + auto check_function = [&](const std::vector& indexes) { + auto value = LiteralUtil::Get(*literal, indexes); + matched = matched && (value == generator(indexes)); + return matched; + }; + ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + check_function); + EXPECT_TRUE(matched); + } +} + +TEST_F(LiteralUtilTest, ConvertR4) { + // clang-format off + auto original = LiteralUtil::CreateR4WithLayout({{ + {{10, 11, 12, 13}, {14, 15, 16, 17}}, + {{18, 19, 20, 21}, {22, 23, 24, 25}}, + {{26, 27, 28, 29}, {30, 31, 32, 33}}, + }}, layout_r4_dim0major_); + auto expected = LiteralUtil::CreateR4WithLayout({{ + {{10, 11, 12, 13}, {14, 15, 16, 17}}, + {{18, 19, 20, 21}, {22, 23, 24, 25}}, + {{26, 27, 28, 29}, {30, 31, 32, 33}}, + }}, layout_r4_dim0major_); + // clang-format on + auto converted = LiteralUtil::Convert(*original); + + EXPECT_TRUE(LiteralUtil::Equal(*expected, *converted)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index cd7c42f6e17e15b5e1c6ebfa1f24a40a9003a63e..0d4ddc239243b79d47b6a1672b65abe9b23e7b52 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -38,7 +38,8 @@ void MetricTableReport::SetEntryName(string entry_name) { void MetricTableReport::SetShowAllEntries() { max_entries_to_show_ = std::numeric_limits::max(); - max_metric_proportion_to_show = 1.1; // more than 100% + max_entries_per_category_to_show_ = std::numeric_limits::max(); + max_metric_proportion_to_show_ = 1.1; // more than 100% } void MetricTableReport::SetShowCategoryTable() { show_category_table_ = true; } @@ -141,7 +142,7 @@ void MetricTableReport::AppendCategoryTable() { int64 categories_shown = 0; for (const auto& category : categories) { if (categories_shown >= max_entries_to_show_ || - metric_sum / expected_metric_sum_ > max_metric_proportion_to_show) { + metric_sum / expected_metric_sum_ > max_metric_proportion_to_show_) { break; } ++categories_shown; @@ -156,15 +157,14 @@ void MetricTableReport::AppendCategoryTable() { entry_name_, ")"); AppendTableRow(text, category.metric_sum, metric_sum); - // Show the top few entries in the category. - const int64 kMaxToShow = 5; + // Show the top entries in the category. const char* const kIndentPrefix = " * "; - int64 entries_to_show = - std::min(kMaxToShow, category.entries.size()); - if (category.entries.size() == kMaxToShow + 1) { + int64 entries_to_show = std::min(max_entries_per_category_to_show_, + category.entries.size()); + if (category.entries.size() == entries_to_show + 1) { // May as well show the last entry on the line that would otherwise say // that there is a single entry not shown. - entries_to_show = category.entries.size(); + ++entries_to_show; } for (int64 i = 0; i < entries_to_show; ++i) { AppendLine(kIndentPrefix, MetricPercent(category.entries[i]->metric), " ", @@ -193,7 +193,7 @@ void MetricTableReport::AppendEntryTable() { int64 entries_shown = 0; for (const auto& entry : entries_) { if (entries_shown >= max_entries_to_show_ || - metric_sum / expected_metric_sum_ > max_metric_proportion_to_show) { + metric_sum / expected_metric_sum_ > max_metric_proportion_to_show_) { break; } ++entries_shown; diff --git a/tensorflow/compiler/xla/metric_table_report.h b/tensorflow/compiler/xla/metric_table_report.h index e967627bff4446a695bfae514faac4b1acca4968..818fb1d3fe0b8bbe1a8eba363ff6445e2f3df9d2 100644 --- a/tensorflow/compiler/xla/metric_table_report.h +++ b/tensorflow/compiler/xla/metric_table_report.h @@ -103,6 +103,7 @@ class MetricTableReport { private: static constexpr double kDefaultMaxMetricProportionToShow = 0.99; static constexpr int64 kDefaultMaxEntriesToShow = 100; + static constexpr int64 kDefaultMaxEntriesPerCategoryToShow = 5; // Append all parameters to the report. template @@ -162,7 +163,8 @@ class MetricTableReport { // These members control how many categories and entries to show in tables. int64 max_entries_to_show_ = kDefaultMaxEntriesToShow; - double max_metric_proportion_to_show = kDefaultMaxMetricProportionToShow; + int64 max_entries_per_category_to_show_ = kDefaultMaxEntriesPerCategoryToShow; + double max_metric_proportion_to_show_ = kDefaultMaxMetricProportionToShow; // The report that is being created. string report_; diff --git a/tensorflow/compiler/xla/port/BUILD b/tensorflow/compiler/xla/port/BUILD deleted file mode 100644 index 6fc5f1185c9d56075f18928e4b2c8e3819cf9ddd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/port/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -# Filegroup used to collect source files for dependency checking. -filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), - visibility = ["//tensorflow/compiler/xla:internal"], -) - -cc_library( - name = "initialize", - hdrs = ["initialize.h"], - visibility = [ - "//tensorflow/compiler/xla:__subpackages__", - ], -) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index e3909ae8e9736351d3ee91332572b5db62727289..e4e37177a2d74e6da20300f1439942a146ad8d49 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -78,6 +78,11 @@ PrimitiveType NativeToPrimitiveType() { return F64; } +template <> +PrimitiveType NativeToPrimitiveType() { + return F16; +} + bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64; } diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 78f0ee6f592d9b9ec2ed85f23297634c5e2e4d41..162a11c7d2966346979b98c804917203f82c806c 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -75,6 +75,8 @@ template <> PrimitiveType NativeToPrimitiveType(); template <> PrimitiveType NativeToPrimitiveType(); +template <> +PrimitiveType NativeToPrimitiveType(); bool IsFloatingPointType(PrimitiveType type); @@ -150,6 +152,10 @@ template <> struct PrimitiveTypeToNative { using type = double; }; +template <> +struct PrimitiveTypeToNative { + using type = half; +}; } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 86c9c3b1ac38d755effad733590f78aafa9571db..4194d5fc6be0ad552e9fe6dd14b51fa0a67f2eca 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -180,14 +180,28 @@ ReferenceUtil::ReduceWindow4DGeneric( const tensorflow::gtl::ArraySlice& stride, Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; - auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + return ReduceWindow4DGeneric( + operand, init, reduce_func, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); +} + +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow4DGeneric( + const Array4D& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding) { + std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), + operand.n4()}; std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); for (int64 i = 0; i < window.size(); ++i) { + int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; window_counts[i] = - WindowCount(dim_lengths[i], window[i], stride[i], padding); - pad_low[i] = padding_both[i].first; + window_util::StridedBound(padded_width, window[i], stride[i]); + pad_low[i] = padding[i].first; } auto result = MakeUnique>(window_counts[0], window_counts[1], window_counts[2], window_counts[3]); @@ -649,4 +663,39 @@ ReferenceUtil::ReduceToRowArray2D( return result; } +/* static */ Array4D ReferenceUtil::PadArray4D( + const Array4D& operand, const PaddingConfig& padding, + const float pad) { + CHECK_EQ(padding.dimensions_size(), 4); + + const std::vector input_bounds = {operand.n1(), operand.n2(), + operand.n3(), operand.n4()}; + std::vector pad_low(4); + std::vector pad_high(4); + std::vector output_bounds(4); + for (int64 i = 0; i < 4; ++i) { + pad_low[i] = padding.dimensions(i).edge_padding_low(); + pad_high[i] = padding.dimensions(i).edge_padding_high(); + CHECK_EQ(padding.dimensions(i).interior_padding(), 0) << "not implemented"; + + output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i]; + } + + Array4D result(output_bounds[0], output_bounds[1], output_bounds[2], + output_bounds[3]); + result.Each([&](tensorflow::gtl::ArraySlice indices, float* value) { + for (int i = 0; i < 4; ++i) { + bool in_low_padding = indices[i] < pad_low[i]; + bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; + if (in_low_padding || in_high_padding) { + *value = pad; + return; + } + } + *value = operand(indices[0] - pad_low[0], indices[1] - pad_low[1], + indices[2] - pad_low[2], indices[3] - pad_low[3]); + }); + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 9e0f247203866d544595a877fabd33af148cc307..f58f0bdc9f51dff62c10dda4aba7aac03e689ce7 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -162,6 +162,12 @@ class ReferenceUtil { const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); + static std::unique_ptr> ReduceWindow4DGeneric( + const Array4D& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding); // Performs select and scatter with Greater Than or equal as the select, plus // as the scatter, and Same Padding. @@ -395,7 +401,51 @@ class ReferenceUtil { const Array2D& operand, const PaddingConfig& padding, const float pad); + // Returns the result of a 4D pad on an input array. + static Array4D PadArray4D(const Array4D& operand, + const PaddingConfig& padding, + const float pad); + + // ApplyElementwise2D(f, x, y, ...) returns the Array2D formed by running + // f(x[i], y[i], ...) for each array element in the Array2Ds x, y, .... + // + // The given arrays must have the same size and element type, and the return + // type of f must be implicitly convertible to the arrays' element type. + // + // Example usage: + // + // Array2D x, y, z = ...; + // std::unique_ptr result = ReferenceUtil::ApplyElementwise2D( + // [](float a, float b, float c) { return a * b + c; }, x, y, z); + // + template + static std::unique_ptr> ApplyElementwise2D( + F&& f, const Array2D& array1, const Array2D&... arrays) { + AssertSameSize2D(array1, arrays...); + auto result = MakeUnique>(array1.n1(), array1.n2()); + for (int64 i = 0; i < array1.n1(); ++i) { + for (int64 j = 0; j < array1.n2(); ++j) { + (*result)(i, j) = f(array1(i, j), arrays(i, j)...); + } + } + return result; + } + private: + template + static void AssertSameSize2D(const Array2D& array1, + const Array2D& array2, + const Array2D&... arrays) { + static_assert(std::is_same::value, "Args must be same type."); + CHECK_EQ(array1.n1(), array2.n1()); + CHECK_EQ(array1.n2(), array2.n2()); + AssertSameSize2D(array2, arrays...); + } + + // Recursive base case for AssertSameSize2D. + template + static void AssertSameSize2D(const Array1& array1) {} + TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil); }; diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index c53351ca93e81f70920291019798f16f0f1c6a57..f839ac019df07c5c5e07eed856ea55463bb3efae 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -23,9 +23,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { @@ -52,9 +52,9 @@ class ReferenceUtilTest : public ::testing::Test { TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, - *result_literal, ErrorSpec(0.0001)); + *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MatmulArray2D) { @@ -62,32 +62,32 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { {7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f}, }); auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, - *result_literal, ErrorSpec(0.0001)); + *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ReduceToColArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); - auto result_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *result_literal, + auto actual_literal = LiteralUtil::CreateR1(*result); + LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); - auto result_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *result_literal, + auto actual_literal = LiteralUtil::CreateR1(*result); + LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *result_literal, + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, ErrorSpec(0.0001)); } @@ -96,9 +96,9 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { return value + row + col; }; auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); - auto result_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, - *result_literal, ErrorSpec(0.0001)); + *actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MapArray4D) { @@ -107,11 +107,11 @@ TEST_F(ReferenceUtilTest, MapArray4D) { input->FillWithMultiples(1.0f); auto multiply_by_two = [](float value) { return 2 * value; }; auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); - auto result_literal = LiteralUtil::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.FillWithMultiples(2.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *result_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); } @@ -124,11 +124,11 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width); }; auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index); - auto result_literal = LiteralUtil::CreateR4FromArray4D(*result); + auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *result_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); } @@ -302,5 +302,17 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { ErrorSpec(0.0001)); } +TEST_F(ReferenceUtilTest, ApplyElementwise2D) { + Array2D a({{1, 2}, {3, 4}}); + Array2D b({{10, 20}, {30, 40}}); + Array2D c({{100, 200}, {300, 400}}); + + auto actual = ReferenceUtil::ApplyElementwise2D( + [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); + auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); + LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, + *actual_literal, ErrorSpec(0.0001)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index b9118fab2549689d045a1caf826b9d3937019e1c..3c53cf4dd3c5d663cf703cce3d479cb3a2cea2eb 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -52,6 +52,7 @@ cc_test( deps = [ ":shape_inference", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -64,8 +65,42 @@ cc_test( srcs = ["hlo_opcode_test.cc"], deps = [ ":hlo", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", - "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "hlo_evaluator", + srcs = ["hlo_evaluator.cc"], + hdrs = ["hlo_evaluator.h"], + deps = [ + ":hlo", + ":hlo_query", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "hlo_evaluator_test", + srcs = ["hlo_evaluator_test.cc"], + deps = [ + ":hlo", + ":hlo_evaluator", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) @@ -88,6 +123,7 @@ cc_library( "hlo_opcode.h", ], deps = [ + ":hlo_module_config", ":name_uniquer", ":versioned_computation_handle", "//tensorflow/compiler/xla:literal_util", @@ -105,6 +141,27 @@ cc_library( ], ) +cc_library( + name = "hlo_matchers", + testonly = 1, + srcs = ["hlo_matchers.cc"], + hdrs = ["hlo_matchers.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:test_main", + ], +) + +cc_test( + name = "hlo_matchers_test", + srcs = ["hlo_matchers_test.cc"], + deps = [ + ":hlo_matchers", + "//tensorflow/compiler/xla:shape_util", + ], +) + cc_library( name = "versioned_computation_handle", srcs = ["versioned_computation_handle.cc"], @@ -122,7 +179,9 @@ cc_test( deps = [ ":hlo", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -137,7 +196,6 @@ cc_library( deps = [ ":hlo", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", ], @@ -151,6 +209,42 @@ cc_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "flatten_call_graph", + srcs = ["flatten_call_graph.cc"], + hdrs = ["flatten_call_graph.h"], + deps = [ + ":call_graph", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "flatten_call_graph_test", + srcs = ["flatten_call_graph_test.cc"], + deps = [ + ":call_graph", + ":flatten_call_graph", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -185,10 +279,12 @@ cc_test( name = "user_computation_test", srcs = ["user_computation_test.cc"], deps = [ + ":hlo_matchers", ":user_computation", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", @@ -311,6 +407,27 @@ cc_library( ], ) +cc_library( + name = "compile_only_service", + srcs = ["compile_only_service.cc"], + hdrs = ["compile_only_service.h"], + deps = [ + ":backend", + ":compiler", + ":computation_layout", + ":computation_tracker", + ":platform_util", + ":service", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + cc_library( name = "cpu_plugin", deps = [ @@ -451,6 +568,7 @@ cc_library( hdrs = ["computation_tracker.h"], deps = [ ":hlo", + ":hlo_module_config", ":session_proto", ":user_computation", ":versioned_computation_handle", @@ -504,7 +622,6 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", ], ) @@ -515,9 +632,6 @@ cc_test( ":hlo", ":liveness_util", ":tuple_points_to_analysis", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test_main", ], @@ -532,6 +646,7 @@ cc_library( "buffer_liveness.h", ], deps = [ + ":call_graph", ":hlo", ":hlo_ordering", ":liveness_util", @@ -572,8 +687,8 @@ cc_library( ], deps = [ ":buffer_liveness", - ":heap_simulator", ":hlo", + ":hlo_ordering", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -593,11 +708,17 @@ cc_test( srcs = ["buffer_assignment_test.cc"], deps = [ ":buffer_assignment", + ":call_graph", ":computation_tracker", + ":copy_insertion", ":cpu_plugin", + ":flatten_call_graph", ":hlo", + ":hlo_ordering", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -607,56 +728,38 @@ cc_test( ], ) -cc_library( - name = "heap_simulator", - srcs = [ - "heap_simulator.cc", - ], - hdrs = [ - "heap_simulator.h", - ], - deps = [ - ":hlo", - ":liveness_util", - ":logical_buffer", - ":tuple_points_to_analysis", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", - ], -) - cc_test( name = "heap_simulator_test", srcs = ["heap_simulator_test.cc"], deps = [ - ":heap_simulator", ":hlo", + ":hlo_ordering", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) +# The hlo_ordering library contains both hlo_ordering and heap_simulator because +# they are mutually dependent. cc_library( name = "hlo_ordering", srcs = [ + "heap_simulator.cc", "hlo_ordering.cc", ], hdrs = [ + "heap_simulator.h", "hlo_ordering.h", ], deps = [ - ":heap_simulator", + ":call_graph", ":hlo", + ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -710,6 +813,7 @@ cc_test( name = "instruction_fusion_test", srcs = ["instruction_fusion_test.cc"], deps = [ + ":hlo_matchers", ":instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test_main", @@ -743,10 +847,11 @@ cc_test( ":algebraic_simplifier", ":cpu_plugin", ":hlo", + ":hlo_matchers", ":hlo_pass", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -764,7 +869,9 @@ cc_library( ":hlo_pass", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", ], ) @@ -773,9 +880,11 @@ cc_test( srcs = ["reshape_mover_test.cc"], deps = [ ":hlo", + ":hlo_matchers", ":reshape_mover", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -806,10 +915,11 @@ cc_test( deps = [ ":cpu_plugin", ":hlo", + ":hlo_matchers", ":inliner", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -941,8 +1051,10 @@ cc_test( deps = [ ":cpu_plugin", ":hlo", + ":hlo_matchers", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test_main", @@ -972,7 +1084,7 @@ cc_test( ":hlo", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1024,10 +1136,12 @@ cc_test( srcs = ["tuple_points_to_analysis_test.cc"], deps = [ ":hlo", + ":hlo_matchers", ":instruction_fusion", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1089,6 +1203,7 @@ cc_library( ":buffer_liveness", ":hlo", ":hlo_pass", + ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:status_macros", @@ -1103,13 +1218,14 @@ cc_test( name = "copy_insertion_test", srcs = ["copy_insertion_test.cc"], deps = [ - ":buffer_liveness", ":copy_insertion", ":cpu_plugin", ":hlo", + ":hlo_matchers", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1137,16 +1253,7 @@ cc_library( name = "hlo_verifier", srcs = ["hlo_verifier.cc"], hdrs = ["hlo_verifier.h"], - deps = [ - ":hlo", - ":hlo_pass", - "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", - ], + deps = [":hlo_pass"], ) cc_library( @@ -1156,10 +1263,12 @@ cc_library( deps = [ ":buffer_liveness", ":call_graph", + ":flatten_call_graph", ":hlo", ":hlo_cost_analysis", ":hlo_dce", ":hlo_ordering", + ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1177,6 +1286,7 @@ cc_test( deps = [ ":cpu_plugin", ":hlo", + ":hlo_matchers", ":hlo_ordering", ":hlo_rematerialization", "//tensorflow/compiler/xla:shape_util", @@ -1203,6 +1313,7 @@ cc_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", + "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) @@ -1215,10 +1326,12 @@ cc_test( ":computation_layout", ":cpu_plugin", ":hlo", + ":hlo_matchers", ":layout_assignment", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -1274,7 +1387,6 @@ cc_library( ":hlo_pass", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", @@ -1288,6 +1400,7 @@ cc_test( ":cpu_plugin", ":hlo", ":hlo_cse", + ":hlo_matchers", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1310,13 +1423,34 @@ cc_library( ":hlo_pass", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], ) +cc_test( + name = "hlo_constant_folding_test", + srcs = ["hlo_constant_folding_test.cc"], + deps = [ + ":cpu_plugin", + ":hlo", + ":hlo_constant_folding", + ":hlo_matchers", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "device_memory_allocator", srcs = ["device_memory_allocator.cc"], @@ -1403,6 +1537,33 @@ cc_test( ], ) +cc_library( + name = "hlo_tfgraph_builder", + srcs = ["hlo_tfgraph_builder.cc"], + hdrs = ["hlo_tfgraph_builder.h"], + visibility = ["//tensorflow/compiler/xla/tools:__pkg__"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "hlo_tfgraph_builder_test", + srcs = ["hlo_tfgraph_builder_test.cc"], + deps = [ + ":hlo_tfgraph_builder", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "hlo_graph_dumper", srcs = [ @@ -1412,6 +1573,7 @@ cc_library( deps = [ ":hlo", ":hlo_execution_profile", + ":hlo_tfgraph_builder", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1429,7 +1591,9 @@ cc_library( deps = [ ":hlo", ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/core:lib", ], @@ -1440,11 +1604,15 @@ cc_test( srcs = ["transpose_folding_test.cc"], deps = [ ":hlo", + ":hlo_matchers", + ":shape_inference", ":transpose_folding", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/core:lib", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 4c058484b9fbdeeabd5c240cc85b7439181896df..6e6da38f9e33bd8bd7723a3a96608a2eacf2a5a2 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -51,6 +51,16 @@ bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { LiteralUtil::IsAll(operand->literal(), value); } +bool IsAll(const HloInstruction* op, int8 value) { + if (IsLiteralWithValue(op, value)) { + return true; + } + if (op->opcode() == HloOpcode::kBroadcast && IsAll(op->operand(0), value)) { + return true; + } + return false; +} + // Returns whether the given transpose produces a result which is bit-wise // identical to its operand and thus may be replaced with a bitcast. bool TransposeIsBitcast(const HloInstruction* transpose) { @@ -112,6 +122,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleBroadcast(HloInstruction* broadcast) override; + Status HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) override; + Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; Status HandleConvert(HloInstruction* convert, @@ -146,9 +160,19 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { tensorflow::gtl::ArraySlice dimensions, HloComputation* function) override; + Status HandleReduceWindow(HloInstruction* reduce_window, + HloInstruction* operand, const Window& window, + HloComputation* function) override; + Status HandleReverse(HloInstruction* reverse, HloInstruction* operand) override; Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; + Status HandleDynamicSlice(HloInstruction* slice, HloInstruction* operand, + HloInstruction* start_indices) override; + Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, + HloInstruction* operand, + HloInstruction* update, + HloInstruction* start_indices) override; Status HandleTranspose(HloInstruction* transpose) override; @@ -210,6 +234,29 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { StatusOr TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( HloInstruction* reshape_or_broadcast); + // Replaces the existing HLO instruction old_instruction, with + // new_instruction, and marks the optimizer status as changed. + // Returns the Status representing the result of the replace operation. + Status ReplaceWithNewInstruction( + HloInstruction* old_instruction, + std::unique_ptr new_instruction) { + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + old_instruction, std::move(new_instruction))); + changed_ = true; + return Status::OK(); + } + + // Replaces the existing HLO instruction old_instruction, with + // new_instruction, and marks the optimizer status as changed. + // Returns the Status representing the result of the replace operation. + Status ReplaceInstruction(HloInstruction* old_instruction, + HloInstruction* new_instruction) { + TF_RETURN_IF_ERROR( + computation_->ReplaceInstruction(old_instruction, new_instruction)); + changed_ = true; + return Status::OK(); + } + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -258,8 +305,7 @@ void AlgebraicSimplifierVisitor::ReplaceWithBitcast( auto bitcast = computation_->AddInstruction( HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kBitcast, instruction->mutable_operand(0))); - TF_CHECK_OK(computation_->ReplaceInstruction(instruction, bitcast)); - changed_ = true; + TF_CHECK_OK(ReplaceInstruction(instruction, bitcast)); } bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( @@ -267,9 +313,7 @@ bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( if (!SameShape(old_instruction, new_instruction)) { return false; } - TF_CHECK_OK( - computation_->ReplaceInstruction(old_instruction, new_instruction)); - changed_ = true; + TF_CHECK_OK(ReplaceInstruction(old_instruction, new_instruction)); return true; } @@ -278,12 +322,12 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add, HloInstruction* rhs) { // A + 0 => A VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString(); - if (IsLiteralWithValue(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { return Status::OK(); } // 0 + A => A VLOG(10) << "trying transform [0 + A => A]: " << add->ToString(); - if (IsLiteralWithValue(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) { + if (IsAll(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) { return Status::OK(); } @@ -297,12 +341,45 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy, return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) { + if (operands.size() == 1) { + // Unary concatenates are useless. + ReplaceInstructionIfSameShape(concatenate, operands[0]); + return Status::OK(); + } + // Filter out and remove empty operands. + std::vector nonempty_operands; + for (HloInstruction* operand : operands) { + if (!ShapeUtil::HasZeroElements(operand->shape())) { + nonempty_operands.push_back(operand); + } + } + if (nonempty_operands.size() < operands.size()) { + HloInstruction* replacement; + if (nonempty_operands.empty()) { + replacement = operands[0]; + } else if (nonempty_operands.size() == 1) { + replacement = nonempty_operands[0]; + } else { + replacement = + computation_->AddInstruction(concatenate->CloneWithNewOperands( + concatenate->shape(), nonempty_operands)); + } + VLOG(10) << "trying to replace " << concatenate->ToString() << " with " + << replacement->ToString(); + ReplaceInstructionIfSameShape(concatenate, replacement); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub, HloInstruction* lhs, HloInstruction* rhs) { // A - 0 => A VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString(); - if (IsLiteralWithValue(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) { + if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) { return Status::OK(); } @@ -314,8 +391,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, HloInstruction* rhs) { // A/1 => A VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString(); - if (IsLiteralWithValue(rhs, 1) && - ReplaceInstructionIfSameShape(divide, lhs)) { + if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) { return Status::OK(); } @@ -326,8 +402,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, computation_->AddInstruction(HloInstruction::CreateBinary( divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0), rhs->mutable_operand(0))); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, subtract)); } @@ -354,8 +429,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, ShapeUtil::HasZeroElements(rhs->shape())) { auto zero = computation_->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } @@ -364,8 +438,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, auto new_dot = computation_->AddInstruction(HloInstruction::CreateBinary( ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), HloOpcode::kDot, rhs->mutable_operand(0), lhs->mutable_operand(0))); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); } @@ -373,8 +446,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, // // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) if (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(0) == 1) { - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply, lhs, rhs)); } @@ -398,8 +470,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, {0}, add_reduce_computation)); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( dot, HloInstruction::CreateReshape(dot->shape(), reduce)); } @@ -438,8 +509,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, {rhs->shape().dimensions(1)}), multiply, zero, {0}, add_reduce_computation)); } - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( dot, HloInstruction::CreateReshape(dot->shape(), reduce)); } @@ -465,8 +535,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, ShapeUtil::MakeShape(dot->shape().element_type(), {lhs->shape().dimensions(0)}), multiply, zero, {1}, add_reduce_computation)); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( dot, HloInstruction::CreateReshape(dot->shape(), reduce)); } return Status::OK(); @@ -477,14 +546,12 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply, HloInstruction* rhs) { // A*1 => A VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString(); - if (IsLiteralWithValue(rhs, 1) && - ReplaceInstructionIfSameShape(multiply, lhs)) { + if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) { return Status::OK(); } // 1*A => A VLOG(10) << "trying transform [1*A => A]: " << multiply->ToString(); - if (IsLiteralWithValue(lhs, 1) && - ReplaceInstructionIfSameShape(multiply, rhs)) { + if (IsAll(lhs, 1) && ReplaceInstructionIfSameShape(multiply, rhs)) { return Status::OK(); } return Status::OK(); @@ -605,8 +672,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> reshape(X) where " "n(broadcast(X)) == n(X)"; - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand)); } @@ -618,8 +684,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> transpose(X) where " "n(broadcast(X)) == n(X)"; - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand, broadcast->dimensions())); } @@ -639,8 +704,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { for (auto inserted_index : inserted_indices) { dims.erase(dims.begin() + inserted_index); } - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( broadcast, HloInstruction::CreateBroadcast(broadcast->shape(), operand->mutable_operand(0), dims)); @@ -683,65 +747,6 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } -template -static std::unique_ptr ConvertIfTypesMatch( - const Literal& src_literal) { - CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - - return HloInstruction::CreateConstant( - LiteralUtil::Convert::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal)); -} - -template -static std::unique_ptr ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (primitive_dest_type) { -#define CONVERT_IF_TYPES_MATCH(type) \ - case (type): \ - return ConvertIfTypesMatch(src_literal); - CONVERT_IF_TYPES_MATCH(PRED) - CONVERT_IF_TYPES_MATCH(S8) - CONVERT_IF_TYPES_MATCH(S32) - CONVERT_IF_TYPES_MATCH(S64) - CONVERT_IF_TYPES_MATCH(U8) - CONVERT_IF_TYPES_MATCH(U32) - CONVERT_IF_TYPES_MATCH(U64) - CONVERT_IF_TYPES_MATCH(F32) - CONVERT_IF_TYPES_MATCH(F64) -#undef CONVERT_IF_TYPES_MATCH - // Other types are not yet supported. - default: - LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type " - << PrimitiveType_Name(src_literal.shape().element_type()); - } -} - -static std::unique_ptr ConvertIfSrcTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (src_literal.shape().element_type()) { -#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ - case (type): \ - return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); - CONVERT_IF_DEST_TYPE_MATCHES(PRED) - CONVERT_IF_DEST_TYPE_MATCHES(S8) - CONVERT_IF_DEST_TYPE_MATCHES(S32) - CONVERT_IF_DEST_TYPE_MATCHES(S64) - CONVERT_IF_DEST_TYPE_MATCHES(U8) - CONVERT_IF_DEST_TYPE_MATCHES(U32) - CONVERT_IF_DEST_TYPE_MATCHES(U64) - CONVERT_IF_DEST_TYPE_MATCHES(F32) - CONVERT_IF_DEST_TYPE_MATCHES(F64) -#undef CONVERT_IF_DEST_TYPE_MATCHES - // Other types are not yet supported. - default: - LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type " - << PrimitiveType_Name(src_literal.shape().element_type()); - } -} - // A conversion to the same element type as the operand is a nop and can be // removed. A conversion of a constant can be simplified by making a new // constant. @@ -750,16 +755,7 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert, PrimitiveType src_type = operand->shape().element_type(); PrimitiveType dest_type = convert->shape().element_type(); if (src_type == dest_type) { - changed_ = true; - return computation_->ReplaceInstruction(convert, operand); - } - if (operand->opcode() == HloOpcode::kConstant) { - const Literal& src_literal = operand->literal(); - std::unique_ptr new_constant = - ConvertIfSrcTypeMatches(src_literal, dest_type); - changed_ = true; - return computation_->ReplaceWithNewInstruction(convert, - std::move(new_constant)); + return ReplaceInstruction(convert, operand); } return Status::OK(); } @@ -845,8 +841,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { std::unique_ptr slice = HloInstruction::CreateSlice( pad->shape(), nonzero_pad, start_indices, end_indices); - changed_ = true; - return computation_->ReplaceWithNewInstruction(pad, std::move(slice)); + return ReplaceWithNewInstruction(pad, std::move(slice)); } return Status::OK(); @@ -856,7 +851,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, HloInstruction* lhs, HloInstruction* rhs) { VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); - if (IsLiteralWithValue(rhs, 0)) { + if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant(LiteralUtil::CloneToUnique( LiteralUtil::One(power->shape().element_type()))); std::unique_ptr ones; @@ -866,30 +861,27 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, ones = HloInstruction::CreateBroadcast( power->shape(), computation_->AddInstruction(std::move(one)), {}); } - changed_ = true; - return computation_->ReplaceWithNewInstruction(power, std::move(ones)); + return ReplaceWithNewInstruction(power, std::move(ones)); } VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString(); - if (IsLiteralWithValue(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) { + if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) { return Status::OK(); } VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString(); - if (IsLiteralWithValue(rhs, 2)) { - changed_ = true; - return computation_->ReplaceWithNewInstruction( + if (IsAll(rhs, 2)) { + return ReplaceWithNewInstruction( power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kMultiply, lhs, lhs)); } VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); - if (IsLiteralWithValue(rhs, -1)) { + if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CloneToUnique( LiteralUtil::One(rhs->shape().element_type())))); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, one, lhs)); } @@ -967,17 +959,24 @@ StatusOr AlgebraicSimplifierVisitor:: Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { auto operand = reshape->mutable_operand(0); + // Reshape directly to empty constant if the shape contains zero-element + // dimension. + if (ShapeUtil::HasZeroElements(reshape->shape())) { + auto empty_constant = HloInstruction::CreateConstant( + LiteralUtil::CreateFromShape(reshape->shape())); + + return ReplaceWithNewInstruction(reshape, std::move(empty_constant)); + } + // Delete no-op reshapes, i.e. where shape = operand shape. if (SameShape(reshape, operand)) { VLOG(10) << "deleting no-op reshape"; - changed_ = true; - return computation_->ReplaceInstruction(reshape, operand); + return ReplaceInstruction(reshape, operand); } // Merge reshapes. if (HloOpcode::kReshape == operand->opcode()) { - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( reshape, HloInstruction::CreateReshape(reshape->shape(), operand->mutable_operand(0))); } @@ -986,8 +985,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { auto opt_dims = ReshapeLeavesDimensionsUnmodified( reshape, reshape->operand(0)->dimensions()); if (opt_dims.first) { - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( reshape, HloInstruction::CreateBroadcast( reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0), @@ -1023,8 +1021,7 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse, }; if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(), dim_is_one)) { - changed_ = true; - return computation_->ReplaceInstruction(reverse, operand); + return ReplaceInstruction(reverse, operand); } return Status::OK(); } @@ -1038,12 +1035,31 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice, return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleDynamicSlice( + HloInstruction* dynamic_slice, HloInstruction* operand, + HloInstruction* start_indices) { + if (ShapeUtil::IsScalar(dynamic_slice->shape())) { + return ReplaceInstruction(dynamic_slice, operand); + } + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice, HloInstruction* operand, + HloInstruction* update, HloInstruction* start_indices) { + // DynamicUpdateSlice on a scalar just passes through the update argument. + if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { + return ReplaceInstruction(dynamic_update_slice, update); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions, HloComputation* function) { if (ShapeUtil::HasZeroElements(arg->shape()) || ShapeUtil::HasZeroElements(reduce->shape())) { - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( reduce, HloInstruction::CreateBroadcast(reduce->shape(), init_value, {})); return Status::OK(); @@ -1056,7 +1072,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce( for (auto dim : dimensions) { new_reduce_dimensions.push_back(transpose_dimensions[dim]); } - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( reduce, HloInstruction::CreateReduce( reduce->shape(), arg->mutable_operand(0), init_value, new_reduce_dimensions, function)); @@ -1100,7 +1116,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce( new_reduce_dimensions.push_back(i); } } - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( reduce, HloInstruction::CreateReduce( reduce->shape(), arg->mutable_operand(0), init_value, new_reduce_dimensions, function)); @@ -1111,27 +1127,84 @@ Status AlgebraicSimplifierVisitor::HandleReduce( ShapeUtil::HasZeroElements(arg->shape())) { auto reshape = computation_->AddInstruction( HloInstruction::CreateReshape(reduce->shape(), arg)); - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( reduce, HloInstruction::CreateMap(reduce->shape(), {reshape, init_value}, function)); } return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleReduceWindow( + HloInstruction* reduce_window, HloInstruction* operand, + const Window& window, HloComputation* function) { + VLOG(10) << "Considering folding Pad: " << operand->ToString() + << "\ninto reduce-window: " << reduce_window->ToString(); + + // This optimization folds a pad op into reduce_window. + if (operand->opcode() != HloOpcode::kPad) { + VLOG(10) << "Not folding pad into reduce-window as there is no pad."; + return Status::OK(); + } + + // Do not fold interior padding into ReduceWindow since the backends do not + // support it. + const PaddingConfig& pad_config = operand->padding_config(); + if (HasInteriorPadding(pad_config)) { + VLOG(10) << "Not folding pad into reduce-window due to interior padding."; + return Status::OK(); + } + + // If reduce_window already has padding, the pad value of the pad op and the + // init value of reduce_window must match to allow folding the pad. + const HloInstruction* pad_value = operand->operand(1); + const HloInstruction* reduce_init_value = reduce_window->operand(1); + if (pad_value != reduce_init_value) { + // The pad value is usually a constant, so we handle that case and do not + // try to get more fancy about proving equivalence in cases beyond that. + if (pad_value->opcode() != HloOpcode::kConstant || + reduce_init_value->opcode() != HloOpcode::kConstant || + !LiteralUtil::Equal(pad_value->literal(), + reduce_init_value->literal())) { + VLOG(10) + << "Not folding pad into reduce-window due to different pad values."; + return Status::OK(); + } + } + + // Carry out the folding of the pad into reduce_window. + VLOG(10) << "Folding pad into reduce-window."; + Window new_window = window; + const int64 rank = ShapeUtil::Rank(reduce_window->shape()); + TF_RET_CHECK(pad_config.dimensions_size() == rank); + TF_RET_CHECK(window.dimensions_size() == rank); + for (int64 i = 0; i < rank; ++i) { + const auto& pad_dim = pad_config.dimensions(i); + auto& window_dim = *new_window.mutable_dimensions(i); + window_dim.set_padding_low(window_dim.padding_low() + + pad_dim.edge_padding_low()); + window_dim.set_padding_high(window_dim.padding_high() + + pad_dim.edge_padding_high()); + } + return ReplaceWithNewInstruction( + reduce_window, HloInstruction::CreateReduceWindow( + /*shape=*/reduce_window->shape(), + /*operand=*/operand->mutable_operand(0), + /*init_value=*/reduce_window->mutable_operand(1), + /*window=*/new_window, + /*reduce_computation=*/function)); +} + Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto operand = transpose->mutable_operand(0); if (std::is_sorted(transpose->dimensions().begin(), transpose->dimensions().end())) { VLOG(10) << "deleting no-op transpose"; - changed_ = true; - return computation_->ReplaceInstruction(transpose, operand); + return ReplaceInstruction(transpose, operand); } if (HloOpcode::kTranspose == operand->opcode()) { - changed_ = true; - return computation_->ReplaceWithNewInstruction( + return ReplaceWithNewInstruction( transpose, HloInstruction::CreateTranspose( transpose->shape(), operand->mutable_operand(0), ComposePermutations(operand->dimensions(), @@ -1258,9 +1331,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( auto new_rhs = add_bitcast(new_filter_shape, rhs); auto dot = computation_->AddInstruction(HloInstruction::CreateBinary( dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs)); - changed_ = true; - return computation_->ReplaceInstruction(convolution, - add_bitcast(convolution_shape, dot)); + return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( @@ -1274,8 +1345,7 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp, max_operand, operand, min_operand); - TF_CHECK_OK(computation_->ReplaceWithNewInstruction(root, std::move(clamp))); - changed_ = true; + TF_CHECK_OK(ReplaceWithNewInstruction(root, std::move(clamp))); return true; } @@ -1348,13 +1418,20 @@ Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum, StatusOr AlgebraicSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(2, "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); - bool changed = - std::any_of(module->computations().begin(), module->computations().end(), - [=](const std::unique_ptr& computation) { - return AlgebraicSimplifierVisitor::Run( - computation.get(), is_layout_sensitive_, - valid_bitcast_callback_, enable_dot_simplification_); - }); + bool changed = false; + // Make a copy of the computations because we may add computations to the + // module, invalidating iteration. + std::vector computations; + for (auto& comp : module->computations()) { + computations.push_back(comp.get()); + } + for (auto& comp : computations) { + if (AlgebraicSimplifierVisitor::Run(comp, is_layout_sensitive_, + valid_bitcast_callback_, + enable_dot_simplification_)) { + changed = true; + } + } XLA_VLOG_LINES(2, "AlgebraicSimplifier::Run(), after:\n" + module->ToString()); return changed; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 27a1c0fec8855810cd016b36b1706a17c0204d63..87d8a7165ccfad587474a0c89e9387597e341d8f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -23,21 +23,25 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/str_util.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { return [](const Shape&, const Shape&) { return true; }; } + AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } @@ -66,6 +70,52 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction* bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, zero, {0, 1})); + builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0, 0}))); + HloInstruction* bcast = + builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1})); + builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + // Test that A - 0 is simplified to A TEST_F(AlgebraicSimplifierTest, SubZero) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -157,9 +207,7 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); - EXPECT_EQ(root, add); - EXPECT_EQ(root->operand(0), param1); - EXPECT_EQ(root->operand(1), param2); + EXPECT_THAT(root, op::Add(param1, param2)); } // Test that exp(A)/exp(B) is simplified to exp(A-B) @@ -179,17 +227,16 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kDivide); + + EXPECT_THAT(computation->root_instruction(), + op::Divide(op::Exp(param0), op::Exp(param1))); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kExp); - EXPECT_EQ(root->operand_count(), 1); - EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kSubtract); - EXPECT_EQ(root->operand(0)->operand(0), param0); - EXPECT_EQ(root->operand(0)->operand(1), param1); + + EXPECT_THAT(computation->root_instruction(), + op::Exp(op::Subtract(param0, param1))); } // Test that ln(exp(A)) is simplified to A @@ -205,14 +252,14 @@ TEST_F(AlgebraicSimplifierTest, LnExp) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kLog); + + EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kParameter); - EXPECT_EQ(root, param0); + + EXPECT_EQ(computation->root_instruction(), param0); } // Test that ln(exp(A)/exp(B)) is simplified to A-B @@ -234,15 +281,15 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kLog); + + EXPECT_THAT(computation->root_instruction(), + op::Log(op::Divide(op::Exp(param0), op::Exp(param1)))); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); - EXPECT_EQ(root->operand(0), param0); - EXPECT_EQ(root->operand(1), param1); + + EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1)); } // Test that pow(A, 0) where A is a scalar is simplified to the scalar @@ -259,11 +306,15 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kConstant); + EXPECT_THAT(root, op::Constant()); EXPECT_EQ(LiteralUtil::GetFirstElement(root->literal()), 1); } @@ -280,11 +331,15 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + EXPECT_THAT(root, op::Broadcast()); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32)) << ShapeUtil::HumanString(root->shape()); EXPECT_EQ(root->dimensions().size(), 0); @@ -306,12 +361,14 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kParameter); - EXPECT_EQ(root, param0); + + EXPECT_EQ(computation->root_instruction(), param0); } // Test that pow(A, 2) is simplified to A*A. @@ -327,13 +384,14 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kMultiply); - EXPECT_EQ(root->operand(0), param0); - EXPECT_EQ(root->operand(1), param0); + + EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0)); } // Test that pow(A, -1) is simplified to 1/A. @@ -349,15 +407,17 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kDivide); - EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kConstant); + EXPECT_THAT(root, op::Divide(op::Constant(), param0)); EXPECT_EQ(LiteralUtil::GetFirstElement(root->operand(0)->literal()), 1); - EXPECT_EQ(root->operand(1), param0); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { @@ -376,12 +436,15 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { auto computation = builder.Build(); auto module = MakeUnique(TestName()); module->AddEntryComputation(std::move(computation)); - HloInstruction* root = module->entry_computation()->root_instruction(); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Reshape(op::Broadcast(op::Reshape(op)))); + HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - root = module->entry_computation()->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kParameter); + + EXPECT_THAT(module->entry_computation()->root_instruction(), op); } // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE. @@ -395,103 +458,117 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode()); + EXPECT_THAT(computation->root_instruction(), op::Convert(input)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode()); + EXPECT_THAT(computation->root_instruction(), input); } -TEST_F(AlgebraicSimplifierTest, ConvertF32ToS64) { +// Test that copies are removed. +TEST_F(AlgebraicSimplifierTest, RemoveCopy) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); builder.AddInstruction( - HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); + HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode()); + EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode()); - EXPECT_EQ(LiteralUtil::GetFirstElement( - computation->root_instruction()->literal()), - 42); + EXPECT_THAT(computation->root_instruction(), param0); } -TEST_F(AlgebraicSimplifierTest, ConvertS64ToF32) { +// Test that unary concatenates are removed. +TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { + Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); builder.AddInstruction( - HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); + HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode()); + EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode()); - EXPECT_EQ(LiteralUtil::GetFirstElement( - computation->root_instruction()->literal()), - 42.0f); + EXPECT_THAT(computation->root_instruction(), param0); } -TEST_F(AlgebraicSimplifierTest, ConvertF32ArrayToS64Array) { +// Test that empty operands of concatenates are removed. +TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { + const int kParamLength = 100; + Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({42.0f, 19.0f}))); - builder.AddInstruction( - HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r1f32, "param1")); + HloInstruction* empty_literal = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + HloInstruction* empty_slice = + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42})); + Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength}); + builder.AddInstruction(HloInstruction::CreateConcatenate( + result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode()); + EXPECT_THAT( + computation->root_instruction(), + op::Concatenate(empty_literal, param0, param0, empty_slice, param1)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode()); - EXPECT_EQ( - LiteralUtil::Get(computation->root_instruction()->literal(), {0}), - 42); - EXPECT_EQ( - LiteralUtil::Get(computation->root_instruction()->literal(), {1}), - 19); + EXPECT_THAT(computation->root_instruction(), + op::Concatenate(param0, param0, param1)); } -// Test that copies are removed. -TEST_F(AlgebraicSimplifierTest, RemoveCopy) { - Shape r0f32 = ShapeUtil::MakeShape(F32, {}); +// Test a concatenate with only empty operands is removed. +TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { + const int kParamLength = 100; + Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, r0f32, "param0")); - HloInstruction* copy = builder.AddInstruction( - HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* empty_literal = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + HloInstruction* empty_slice = + builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42})); + Shape result_shape = ShapeUtil::MakeShape(F32, {0}); + builder.AddInstruction(HloInstruction::CreateConcatenate( + result_shape, {empty_literal, empty_slice}, 0)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(copy, computation->root_instruction()); + EXPECT_THAT(computation->root_instruction(), + op::Concatenate(empty_literal, empty_slice)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_EQ(param0, computation->root_instruction()); + EXPECT_EQ(computation->root_instruction(), empty_literal); } // Test that a simplification which changes layouts is not performed if layout @@ -511,14 +588,14 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); - EXPECT_EQ(copy, computation->root_instruction()); + EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); // Copy has not been removed. - EXPECT_EQ(copy, computation->root_instruction()); + EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); } // Test that a simplification which preserves layouts is performed if layout @@ -538,14 +615,14 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - EXPECT_EQ(copy, computation->root_instruction()); + EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Copy has been removed. - EXPECT_EQ(param0, computation->root_instruction()); + EXPECT_THAT(computation->root_instruction(), param0); } // Test that a reshape which could be replaced with a bitcast is not if @@ -566,14 +643,14 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(reshape, computation->root_instruction()); + EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); // Reshape is not replaced with a bitcast. - EXPECT_EQ(reshape, computation->root_instruction()); + EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); } // Test transforming reshapes to bitcasts under various conditions. @@ -612,22 +689,18 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(transformable_reshape, computation->root_instruction()->operand(0)); - EXPECT_EQ(dimensions_wrong_reshape, - computation->root_instruction()->operand(1)); - EXPECT_EQ(layout_wrong_reshape, computation->root_instruction()->operand(2)); + EXPECT_THAT(computation->root_instruction(), + op::Tuple(transformable_reshape, dimensions_wrong_reshape, + layout_wrong_reshape)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); simplifier.Run(module.get()).ValueOrDie(); // Verify that only the first reshape is replaced. - EXPECT_NE(transformable_reshape, computation->root_instruction()->operand(0)); - EXPECT_EQ(HloOpcode::kBitcast, - computation->root_instruction()->operand(0)->opcode()); - EXPECT_EQ(dimensions_wrong_reshape, - computation->root_instruction()->operand(1)); - EXPECT_EQ(layout_wrong_reshape, computation->root_instruction()->operand(2)); + EXPECT_THAT( + computation->root_instruction(), + op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); } TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { @@ -645,14 +718,16 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { HloOpcode::kMaximum, movable_reshape, zero)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kMaximum); + + EXPECT_THAT(computation->root_instruction(), + op::Maximum(op::Reshape(param), zero)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); + simplifier.Run(module.get()).ValueOrDie(); - EXPECT_EQ(HloOpcode::kReshape, computation->root_instruction()->opcode()); - EXPECT_EQ(HloOpcode::kMaximum, - computation->root_instruction()->operand(0)->opcode()); + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Maximum(param, zero))); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { @@ -672,13 +747,14 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Verify that the reshape is replaced. - EXPECT_EQ(2, computation->instruction_count()); - EXPECT_EQ(HloOpcode::kBitcast, computation->root_instruction()->opcode()); + EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { @@ -698,13 +774,14 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Verify that the reshape is replaced. - EXPECT_EQ(2, computation->instruction_count()); - EXPECT_EQ(HloOpcode::kBitcast, computation->root_instruction()->opcode()); + EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); } TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { @@ -717,23 +794,20 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {2, 1, 2}), param0)); - HloInstruction* reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(reshape2, computation->root_instruction()); - EXPECT_EQ(reshape1, computation->root_instruction()->operand(0)); + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Reshape(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_EQ(HloOpcode::kReshape, computation->root_instruction()->opcode()); - EXPECT_EQ(HloOpcode::kParameter, - computation->root_instruction()->operand(0)->opcode()); + EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); } TEST_F(AlgebraicSimplifierTest, TransposesMerged) { @@ -746,25 +820,21 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 4, 2}), param0, {1, 2, 0})); - HloInstruction* transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2})); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(transpose2, computation->root_instruction()); - EXPECT_EQ(transpose1, computation->root_instruction()->operand(0)); + EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_EQ(HloOpcode::kTranspose, computation->root_instruction()->opcode()); + EXPECT_THAT(computation->root_instruction(), op::Transpose(param0)); EXPECT_EQ(std::vector({2, 1, 0}), computation->root_instruction()->dimensions()); - EXPECT_EQ(HloOpcode::kParameter, - computation->root_instruction()->operand(0)->opcode()); } // Test merging reshape and broadcast. @@ -780,13 +850,14 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + op::Broadcast(op::Reshape(param0))); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode()); - EXPECT_EQ(HloOpcode::kParameter, - computation->root_instruction()->operand(0)->opcode()); + EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); } // Test merging broadcast and reshape. @@ -802,13 +873,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Broadcast(param0))); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode()); - EXPECT_EQ(HloOpcode::kParameter, - computation->root_instruction()->operand(0)->opcode()); + EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { @@ -821,11 +893,17 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast)); auto module = MakeUnique(TestName()); - module->AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Broadcast(param))); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { @@ -840,12 +918,16 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { auto module = MakeUnique(TestName()); HloComputation* computation = module->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Broadcast(param))); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode()); - EXPECT_MATCH(computation->root_instruction()->dimensions(), - testing::VectorMatcher({3})); + + EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); + EXPECT_THAT(computation->root_instruction()->dimensions(), + ::testing::ElementsAre(3)); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { @@ -860,15 +942,18 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { auto module = MakeUnique(TestName()); HloComputation* computation = module->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Broadcast(param))); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); - EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode()); + + EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); const std::vector broadcast_dims = computation->root_instruction()->dimensions(); EXPECT_EQ(1, broadcast_dims.size()); - EXPECT_TRUE(broadcast_dims[0] == 1 || broadcast_dims[0] == 2 || - broadcast_dims[3] == 3); + EXPECT_THAT(broadcast_dims[0], ::testing::AnyOf(1, 2, 3)); } TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { @@ -881,11 +966,17 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { ShapeUtil::MakeShape(F32, {6, 8}), broadcast)); auto module = MakeUnique(TestName()); - module->AddEntryComputation(builder.Build()); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Broadcast(param))); } TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { @@ -908,10 +999,13 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { HloModule module(TestName()); HloComputation* computation = module.AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); - EXPECT_EQ(1, computation->instruction_count()); + + EXPECT_THAT(computation->root_instruction(), param); } TEST_F(AlgebraicSimplifierTest, NegativePadding) { @@ -951,18 +1045,14 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { return false; }; - EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(computation->root_instruction(), pad); + EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); - EXPECT_EQ(4, computation->instruction_count()); - EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kSlice); - const HloInstruction* root_operand = - computation->root_instruction()->operand(0); - EXPECT_EQ(root_operand->opcode(), HloOpcode::kPad); - EXPECT_FALSE(has_negative_padding(root_operand)); + EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); + EXPECT_FALSE( + has_negative_padding(computation->root_instruction()->operand(0))); } TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { @@ -976,10 +1066,13 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { HloModule module(TestName()); HloComputation* computation = module.AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), op::Reshape(param)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); - EXPECT_EQ(1, computation->instruction_count()); + + EXPECT_THAT(computation->root_instruction(), param); } TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { @@ -996,10 +1089,13 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { HloModule module(TestName()); HloComputation* computation = module.AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), op::Slice(param)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); - EXPECT_EQ(1, computation->instruction_count()); + + EXPECT_THAT(computation->root_instruction(), param); } TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { @@ -1235,21 +1331,21 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMinimum, param0, min_value)); - HloInstruction* max = builder.AddInstruction( + builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value)); HloModule module(TestName()); auto computation = module.AddEntryComputation(builder.Build()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root, max); + + EXPECT_THAT(computation->root_instruction(), + op::Maximum(op::Minimum(param0, min_value), max_value)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); - root = computation->root_instruction(); - ASSERT_EQ(root->opcode(), HloOpcode::kClamp); - EXPECT_EQ(root->operand(0), max_value); - EXPECT_EQ(root->operand(1), param0); - EXPECT_EQ(root->operand(2), min_value); + + EXPECT_THAT(computation->root_instruction(), + op::Clamp(max_value, param0, min_value)); } // Test that min(max(A, x), y) is transformed to clamp(x, A, y) for scalar @@ -1265,21 +1361,21 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMaximum, param0, max_value)); - HloInstruction* min = builder.AddInstruction( + builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); HloModule module(TestName()); auto computation = module.AddEntryComputation(builder.Build()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root, min); + + EXPECT_THAT(computation->root_instruction(), + op::Minimum(op::Maximum(param0, max_value), min_value)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); - root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kClamp); - EXPECT_EQ(root->operand(0), max_value); - EXPECT_EQ(root->operand(1), param0); - EXPECT_EQ(root->operand(2), min_value); + + EXPECT_THAT(computation->root_instruction(), + op::Clamp(max_value, param0, min_value)); } // Test that min(max(A, x), y) is transformed to clamp(x, A, y) for @@ -1296,21 +1392,21 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kMaximum, param0, max_value)); - HloInstruction* min = builder.AddInstruction( + builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value)); HloModule module(TestName()); auto computation = module.AddEntryComputation(builder.Build()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root, min); + + EXPECT_THAT(computation->root_instruction(), + op::Minimum(op::Maximum(param0, max_value), min_value)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); - root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kClamp); - EXPECT_EQ(root->operand(0), max_value); - EXPECT_EQ(root->operand(1), param0); - EXPECT_EQ(root->operand(2), min_value); + + EXPECT_THAT(computation->root_instruction(), + op::Clamp(max_value, param0, min_value)); } // Test that min(max(A, non-constant1), non-constant2) is not canonicalized to @@ -1326,17 +1422,21 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { HloInstruction::CreateParameter(2, r0f32, "param2")); HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMaximum, param0, max_value)); - HloInstruction* min = builder.AddInstruction( + builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); HloModule module(TestName()); auto computation = module.AddEntryComputation(builder.Build()); - HloInstruction* root = computation->root_instruction(); + + EXPECT_THAT(computation->root_instruction(), + op::Minimum(op::Maximum(param0, max_value), min_value)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); - root = computation->root_instruction(); - EXPECT_EQ(root, min); + + EXPECT_THAT(computation->root_instruction(), + op::Minimum(op::Maximum(param0, max_value), min_value)); } // Test that min(f(max(A, constant1)), constant2) is not transformed to @@ -1354,18 +1454,23 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { r0f32, HloOpcode::kMaximum, param0, max_value)); HloInstruction* fmax = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, max, max_value)); - HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary( + builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMinimum, fmax, min_value)); HloModule module(TestName()); auto computation = module.AddEntryComputation(builder.Build()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root, min); + + EXPECT_THAT(computation->root_instruction(), + op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), + min_value)); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); - root = computation->root_instruction(); - EXPECT_EQ(root, min); + + EXPECT_THAT(computation->root_instruction(), + op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), + min_value)); } // Test that slice(broadcast(/*scalar value*/)) simplifies to a single @@ -1402,8 +1507,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - EXPECT_EQ(scalar_param, root->operand(0)); + EXPECT_THAT(root, op::Broadcast(scalar_param)); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape)); } @@ -1440,11 +1544,90 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); - EXPECT_EQ(forty_two, root->operand(0)); + EXPECT_THAT(root, op::Broadcast(forty_two)); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); } +// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). +TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { + HloModule module(TestName()); + HloComputation::Builder builder(TestName()); + + // Create operand to the pad. + HloInstruction* operand = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "p0")); + + // Create the pad. + PaddingConfig padding = MakeNoPaddingConfig(4); + padding.mutable_dimensions(1)->set_edge_padding_low(1); + padding.mutable_dimensions(3)->set_edge_padding_high(2); + + HloInstruction* pad_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); + HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding)); + + // Create add computation. + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module.AddEmbeddedComputation(builder.Build()); + } + + // Create the reduce-window. + Window window; + for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + auto* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_padding_low(10); + dim->set_padding_high(100); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + const Shape reduce_window_shape = + ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); + HloInstruction* reduce_init_value = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); + HloInstruction* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + reduce_window_shape, pad, reduce_init_value, window, + add_computation)); + + // Build the computation and run the simplifier. + auto computation = module.AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, reduce_window); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + + // Running simplification again should not result in any further changes. + ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + + // Verify the result + root = computation->root_instruction(); + EXPECT_THAT(root, op::ReduceWindow(operand, op::Constant())); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) + << ShapeUtil::HumanString(root->shape()) << " vs " + << ShapeUtil::HumanString(reduce_window_shape); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(1).padding_low(), 11); + EXPECT_EQ(root->window().dimensions(2).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(3).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(1).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(2).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(3).padding_high(), 102); +} + TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { HloComputation::Builder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1}); @@ -1461,10 +1644,39 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kParameter); EXPECT_EQ(a, root); EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); } +TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { + // Dots add computations to the parent module. Test that, when the HloModule's + // computations are updated, then iterator invalidation doesn't occur + // when running on subsequent computations. + Shape r1f32 = ShapeUtil::MakeShape(F32, {1}); + HloComputation::Builder builder(TestName() + ".Dot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); + builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kDot, x, y)); + std::unique_ptr dot_computation(builder.Build()); + + HloComputation::Builder call_builder(TestName() + ".Call"); + HloInstruction* zero = call_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({0.0f}))); + HloInstruction* one = call_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.0f}))); + builder.AddInstruction( + HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); + + auto module = MakeUnique(TestName()); + module->AddEmbeddedComputation(std::move(dot_computation)); + module->AddEntryComputation(call_builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index e59fad4e05252ebd54b3a7cecbdf990127a5264c..83759a7a0c62222b81b82b8a0f8e0396a8f17eff 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -64,8 +64,9 @@ GlobalDataHandle AllocationTracker::RegisterInternal( auto& allocation = FindOrDie(handle_to_allocation_, handle); int ref_count = allocation->ref_count(); CHECK_GT(ref_count, 0); - VLOG(2) << "ref_count: " << ref_count << " -> " << ref_count + 1; - allocation->increment_ref_count(); + VLOG(2) << "ref_count: " << ref_count << " -> " << + (ref_count + initial_ref_count); + allocation->increment_ref_count(initial_ref_count); } else { handle = next_handle_++; VLOG(2) << "ref_count: " << initial_ref_count; @@ -125,9 +126,7 @@ tensorflow::Status AllocationTracker::DeallocateShape( handle_map.erase(device_memory->opaque()); } - // TODO(b/36256956) Ideally tuple elements could always be distinct buffers. - if (ShapeUtil::IsTuple(shape) && - backend->transfer_manager()->TupleElementsAreDistinctBuffers()) { + if (ShapeUtil::IsTuple(shape)) { // Traverse into tuple recursively deallocating buffers. TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend->stream_executor(device_ordinal)); diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index e00768001620275d702c2f96a89d981526ea81a7..ebbf35b6fe87bc7322ccb99cfe8f8eed56de06b3 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -63,10 +63,10 @@ class Allocation { CHECK_GE(ref_count_, 0); return ref_count_; } - void increment_ref_count() { + void increment_ref_count(int inc) { CHECK_GT(ref_count_, 0); - CHECK_LT(ref_count_, INT_MAX); - ++ref_count_; + CHECK_LE(ref_count_, INT_MAX - inc); + ref_count_ += inc; } void decrement_ref_count() { CHECK_GT(ref_count_, 0); diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 5c05417c6dcb887b5352d1270c24a4eae62149e3..1913617fecf757a529bbdc803b4227a560c6e1cf 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -41,13 +41,39 @@ namespace se = ::perftools::gputools; namespace xla { +BackendOptions& BackendOptions::set_platform( + perftools::gputools::Platform* platform) { + platform_ = platform; + return *this; +} + +perftools::gputools::Platform* BackendOptions::platform() const { + return platform_; +} + +BackendOptions& BackendOptions::set_number_of_replicas(int number_of_replicas) { + number_of_replicas_ = number_of_replicas; + return *this; +} + +int BackendOptions::number_of_replicas() const { return number_of_replicas_; } + +BackendOptions& BackendOptions::set_intra_op_parallelism_threads( + int num_threads) { + intra_op_parallelism_threads_ = num_threads; + return *this; +} + +int BackendOptions::intra_op_parallelism_threads() const { + return intra_op_parallelism_threads_; +} + // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. struct Backend::EigenThreadPoolWrapper { - explicit EigenThreadPoolWrapper() - : pool(new tensorflow::thread::ThreadPool( - tensorflow::Env::Default(), "XLAEigen", - tensorflow::port::NumSchedulableCPUs())), + explicit EigenThreadPoolWrapper(const int num_threads) + : pool(new tensorflow::thread::ThreadPool(tensorflow::Env::Default(), + "XLAEigen", num_threads)), wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())), device(new Eigen::ThreadPoolDevice(wrapper.get(), wrapper->NumThreads())) {} @@ -58,18 +84,21 @@ struct Backend::EigenThreadPoolWrapper { }; /* static */ StatusOr> Backend::CreateBackend( - perftools::gputools::Platform* platform, int64 replica_count) { + const BackendOptions& options) { + int64 replica_count = options.number_of_replicas(); if (replica_count == -1) { legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags(); replica_count = flags->xla_replicas; } + perftools::gputools::Platform* platform = options.platform(); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); TF_ASSIGN_OR_RETURN(auto stream_executors, PlatformUtil::GetStreamExecutors(platform)); TF_ASSIGN_OR_RETURN(auto transfer_manager, TransferManager::GetForPlatform(platform)); - std::unique_ptr backend(new Backend( - replica_count, platform, compiler, stream_executors, transfer_manager)); + std::unique_ptr backend( + new Backend(replica_count, platform, compiler, stream_executors, + transfer_manager, options.intra_op_parallelism_threads())); TF_RETURN_IF_ERROR(backend->PoolStreams(kInitialStreamsToPool, backend->default_stream_executor())); return std::move(backend); @@ -79,7 +108,9 @@ struct Backend::EigenThreadPoolWrapper { Backend::CreateDefaultBackend() { TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetDefaultPlatform()); - return CreateBackend(platform); + BackendOptions backend_options; + backend_options.set_platform(platform); + return CreateBackend(backend_options); } tensorflow::Status Backend::PoolStreams(int n, se::StreamExecutor* executor) { @@ -114,7 +145,7 @@ Backend::Backend( int64 replica_count, perftools::gputools::Platform* platform, Compiler* compiler, tensorflow::gtl::ArraySlice stream_executors, - TransferManager* transfer_manager) + TransferManager* transfer_manager, int intra_op_parallelism_threads) : platform_(platform), compiler_(compiler), transfer_manager_(transfer_manager), @@ -144,7 +175,11 @@ Backend::Backend( inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool( tensorflow::Env::Default(), "xla_inter_op", tensorflow::port::NumSchedulableCPUs())); - intra_op_thread_pool_wrapper_.reset(new EigenThreadPoolWrapper()); + const int num_threads = intra_op_parallelism_threads > 0 + ? intra_op_parallelism_threads + : tensorflow::port::NumSchedulableCPUs(); + intra_op_thread_pool_wrapper_.reset( + new EigenThreadPoolWrapper(num_threads)); } } @@ -190,10 +225,17 @@ tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const { const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device() const { - if (intra_op_thread_pool_wrapper_ == nullptr) return nullptr; + if (intra_op_thread_pool_wrapper_ == nullptr) { + return nullptr; + } return intra_op_thread_pool_wrapper_->device.get(); } +tensorflow::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const { + if (intra_op_thread_pool_wrapper_ == nullptr) return nullptr; + return intra_op_thread_pool_wrapper_->pool.get(); +} + StatusOr Backend::stream_executor( int device_ordinal) const { if (device_ordinal < 0 || diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 9f6829b7d937cec6a67d4016a40506de5df8572d..1068bac2779e9a3dc6c23c0b9fbcc5403fcc2815 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -39,6 +39,31 @@ struct ThreadPoolDevice; namespace xla { +// Options to configure the backend when it is created. +class BackendOptions { + public: + // Set the platform backing the backend, or nullptr for the default platform. + BackendOptions& set_platform(perftools::gputools::Platform* platform); + perftools::gputools::Platform* platform() const; + + // Set the number of replicas to use when compiling replicated + // programs. The default is -1 meaning that the value is read from + // the xla_replicas flag. + BackendOptions& set_number_of_replicas(int number_of_replicas); + int number_of_replicas() const; + + // Sets the thread pool size for parallel execution of an individual operator. + // The default value of -1 will result in initializing the thread pool with + // the number of threads equal to the number of cores in the system. + BackendOptions& set_intra_op_parallelism_threads(int num_threads); + int intra_op_parallelism_threads() const; + + private: + perftools::gputools::Platform* platform_ = nullptr; + int number_of_replicas_ = -1; + int intra_op_parallelism_threads_ = -1; +}; + // Class which encapsulates an XLA backend. It includes everything necessary // to compile and execute computations on a particular platform. // @@ -53,9 +78,9 @@ class Backend { static constexpr int kInitialStreamsToPool = 8; // Creates a new backend for the given platform with the given number of - // replicas. A value of -1 means to use the flag value. + // replicas. static StatusOr> CreateBackend( - perftools::gputools::Platform* platform, int64 replica_count = -1); + const BackendOptions& options); // Creates a backend for the default platform. The default platform is defined // in PlatformUtil. @@ -150,6 +175,7 @@ class Backend { // For the host platform, returns the configured eigen threadpool device to be // used for scheduling work. For other platforms, returns NULL. const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const; + tensorflow::thread::ThreadPool* eigen_intra_op_thread_pool() const; // Resets the devices associated with this backend. Status ResetDevices(); @@ -160,7 +186,7 @@ class Backend { Compiler* compiler, tensorflow::gtl::ArraySlice stream_executors, - TransferManager* transfer_manager); + TransferManager* transfer_manager, int intra_op_parallelism_threads); Backend(const Backend&) = delete; Backend& operator=(const Backend&) = delete; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index e2b550fc022610c72aa312281727c9c2aea66388..ccb84b026e8782bdf76006a484ac5077a616fb5f 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -41,6 +41,8 @@ limitations under the License. namespace xla { +using ::tensorflow::gtl::FlatMap; +using ::tensorflow::gtl::FlatSet; using ::tensorflow::strings::Appendf; using ::tensorflow::strings::HumanReadableNumBytes; @@ -394,8 +396,8 @@ Status GatherComputationsByAllocationType( // Sets for quickly checking membership. Computations are returned in vectors // for stable iteration. - tensorflow::gtl::FlatSet thread_local_set; - tensorflow::gtl::FlatSet global_set; + FlatSet thread_local_set; + FlatSet global_set; while (!worklist.empty()) { auto worklist_front = worklist.front(); @@ -487,21 +489,10 @@ Status GatherComputationsByAllocationType( StatusOr> BufferAssigner::Run( const HloModule* module, std::unique_ptr hlo_ordering, LogicalBuffer::SizeFunction buffer_size, int64 alignment, - bool colocate_related_buffers, - const std::vector* hlos_to_allocate) { + bool allow_input_output_aliasing) { BufferAssigner assigner(std::move(buffer_size), alignment, - colocate_related_buffers); - return assigner.CreateAssignment(module, std::move(hlo_ordering), - hlos_to_allocate); -} - -/* static */ -StatusOr> BufferAssigner::Run( - const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size, int64 alignment) { - return BufferAssigner::Run(module, std::move(hlo_ordering), - std::move(buffer_size), alignment, - /*colocate_related_buffers=*/true); + allow_input_output_aliasing); + return assigner.CreateAssignment(module, std::move(hlo_ordering)); } bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, @@ -535,6 +526,28 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, << " may interfere with " << buffer; return false; } + // Copy instruction don't share a buffer with their input operand. + if (buffer.instruction()->IsUserOf(assigned_buffer.instruction()) && + buffer.instruction()->opcode() == HloOpcode::kCopy) { + VLOG(4) << "Can't assign: assignee " << assigned_buffer + << " is used at copy instruction " << buffer; + return false; + } + } + + if (allow_input_output_aliasing_ && allocation->maybe_live_out()) { + HloComputation* entry_computation = + assignment->module_->entry_computation(); + for (auto param : entry_computation->parameter_instructions()) { + for (auto& param_buffer : + assignment->points_to_analysis().GetBuffersDefinedByInstruction( + param)) { + if (assignment->liveness().MayInterfere(*param_buffer, buffer)) { + VLOG(4) << "Can't assign: Parameter interference with result"; + return false; + } + } + } } // If the buffer is live out of the computation then it should only be @@ -554,31 +567,28 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, Status BufferAssigner::AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const tensorflow::gtl::FlatSet* hlos_to_allocate, - const tensorflow::gtl::FlatSet& colocated_buffers, - const tensorflow::gtl::FlatSet& - colocated_allocations, + const FlatSet& colocated_buffers, + const FlatSet& colocated_allocations, + FlatMap>* + buffers_to_assign_sequentially, BufferAssignment* assignment) { // Buffers are sorted and assigned to BufferAllocations in decreasing order of // size. std::vector sorted_buffers; for (auto& instruction : computation->instructions()) { - if (hlos_to_allocate == nullptr || - hlos_to_allocate->count(instruction.get()) > 0) { - // Add all buffers which this instruction defines. Instruction which don't - // define buffers (eg, bitcast which just forwards a pointer) don't need - // any allocations. - for (const LogicalBuffer* buffer : - assignment->points_to_analysis().GetBuffersDefinedByInstruction( - instruction.get())) { - sorted_buffers.push_back(buffer); - } + // Add all buffers which this instruction defines. Instruction which don't + // define buffers (eg, bitcast which just forwards a pointer) don't need + // any allocations. + for (const LogicalBuffer* buffer : + assignment->points_to_analysis().GetBuffersDefinedByInstruction( + instruction.get())) { + sorted_buffers.push_back(buffer); } } // Generate a post order sort of instructions for sorting of the // LogicalBuffers. - tensorflow::gtl::FlatMap post_order_position; + FlatMap post_order_position; int position = 0; for (auto* instruction : computation->MakeInstructionPostOrder()) { post_order_position.emplace(instruction, position); @@ -588,9 +598,16 @@ Status BufferAssigner::AssignBuffersForComputation( // If there is a sequential instruction ordering, we'll delay assignment of // temp buffers until after the main assignment loop. const BufferLiveness& liveness = assignment->liveness(); - const std::vector* sequential_order = - liveness.hlo_ordering().SequentialOrder(*computation); - tensorflow::gtl::FlatSet unassigned_temp_buffers; + const bool has_sequential_order = + liveness.hlo_ordering().SequentialOrder(*computation) != nullptr; + if (has_sequential_order && buffers_to_assign_sequentially != nullptr) { + // Every sequential computation must get an entry in the + // buffers_to_assign_sequentially map, even if we end up with an empty set + // of buffers. This ensures we can correctly determine whether to run + // whole-module heap simulation. + buffers_to_assign_sequentially->emplace(computation, + FlatSet()); + } // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers // first for simplicity. This means any previously created BufferAllocation is @@ -609,7 +626,7 @@ Status BufferAssigner::AssignBuffersForComputation( // important reuse case where an elementwise instruction reuses one of its // operand's buffer. This improves locality. std::sort(sorted_buffers.begin(), sorted_buffers.end(), - [this, sequential_order, &liveness, &post_order_position]( + [this, has_sequential_order, &liveness, &post_order_position]( const LogicalBuffer* a, const LogicalBuffer* b) { // Primary sort is by decreasing buffer size. const int64 a_size = buffer_size_(*a); @@ -619,7 +636,7 @@ Status BufferAssigner::AssignBuffersForComputation( } // Otherwise live out buffers come before others, if the // instructions are sequentially ordered. - if (sequential_order != nullptr) { + if (has_sequential_order) { const bool a_live_out = liveness.MaybeLiveOut(*a); const bool b_live_out = liveness.MaybeLiveOut(*b); if (a_live_out != b_live_out) { @@ -756,7 +773,7 @@ Status BufferAssigner::AssignBuffersForComputation( } } - if (!assignment->HasAllocation(*buffer) && sequential_order != nullptr && + if (!assignment->HasAllocation(*buffer) && has_sequential_order && !liveness.MaybeLiveOut(*buffer)) { // There is a sequential instruction ordering, so we delay assignment of // temp buffers until after the loop. We do this right before we decide to @@ -768,7 +785,7 @@ Status BufferAssigner::AssignBuffersForComputation( // for the definition of temp buffers. CHECK(!is_entry_parameter) << *buffer; CHECK(!is_thread_local) << *buffer; - unassigned_temp_buffers.insert(buffer); + (*buffers_to_assign_sequentially)[computation].insert(buffer); VLOG(3) << "Delaying assignment of temp buffer: " << *buffer; continue; } @@ -782,27 +799,68 @@ Status BufferAssigner::AssignBuffersForComputation( } } - if (!unassigned_temp_buffers.empty()) { - TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering( - *sequential_order, unassigned_temp_buffers, *computation, assignment)); - } return Status::OK(); } Status BufferAssigner::AssignBuffersWithSequentialOrdering( - const std::vector& sequence, - const tensorflow::gtl::FlatSet& buffers_to_assign, - const HloComputation& computation, BufferAssignment* assignment) { + const FlatMap>& + buffers_to_assign_sequentially, + bool run_whole_module_heap_simulation, BufferAssignment* assignment) { // Run the sequence of instructions through the heap simulator. The heuristic // that seems to give the best results is lazy-best-fit, with all runs of // alloc / free calls sorted in decreasing size order. - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique( - MakeUnique(alignment_)), - sequence, computation, - assignment->points_to_analysis(), buffer_size_, - &buffers_to_assign)); + const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering(); + if (run_whole_module_heap_simulation) { + // Run the heap simulation over the whole module. This reduces memory usage, + // since buffers for kCall and kWhile sub-computations are only live for the + // duration of their calling instructions. + VLOG(1) << "Running whole-module heap simulation"; + SequentialHloOrdering::HloModuleSequence module_sequence; + FlatSet all_buffers_to_assign; + for (const auto& pair : buffers_to_assign_sequentially) { + const HloComputation* computation = pair.first; + const FlatSet& buffers_to_assign = pair.second; + const std::vector* instruction_sequence = + hlo_ordering.SequentialOrder(*computation); + CHECK(instruction_sequence != nullptr) << computation->name(); + module_sequence[computation] = *instruction_sequence; + all_buffers_to_assign.insert(buffers_to_assign.begin(), + buffers_to_assign.end()); + } + TF_ASSIGN_OR_RETURN( + const HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique( + MakeUnique(alignment_)), + assignment->module(), module_sequence, + assignment->points_to_analysis(), buffer_size_, + &all_buffers_to_assign)); + AssignBuffersFromHeapSimulator(result, assignment); + } else { + // Run the heap-simulation on a per-computation basis. Buffers for + // sub-computations are assigned disjoint BufferAllocations, assuming the + // worst-case that they may all be live concurrently. + VLOG(1) << "Running per-computation heap simulation"; + for (const auto& pair : buffers_to_assign_sequentially) { + const HloComputation* computation = pair.first; + const FlatSet& buffers_to_assign = pair.second; + const std::vector* instruction_sequence = + hlo_ordering.SequentialOrder(*computation); + CHECK(instruction_sequence != nullptr) << computation->name(); + TF_ASSIGN_OR_RETURN( + const HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique( + MakeUnique(alignment_)), + *computation, *instruction_sequence, + assignment->points_to_analysis(), buffer_size_, + &buffers_to_assign)); + AssignBuffersFromHeapSimulator(result, assignment); + } + } + return Status::OK(); +} + +void BufferAssigner::AssignBuffersFromHeapSimulator( + const HeapSimulator::Result& result, BufferAssignment* assignment) { if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) { assignment->stats_.preallocated_temp_fragmentation_bytes = result.fragmentation_size; @@ -811,8 +869,6 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( result.fragmentation_size; } - // Use the results of the heap simulator to create one allocation per - // computation, with LogicalBuffers packed to specific offsets. BufferAllocation* allocation = assignment->NewEmptyAllocation( result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true); for (const auto& buffer_chunk : result.chunk_map) { @@ -820,7 +876,6 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HeapSimulator::Chunk& chunk = buffer_chunk.second; assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size); } - return Status::OK(); } // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining @@ -881,40 +936,152 @@ void BufferAssigner::AddSetToColocatedBufferSets( } } +// Conceptually the same as AddSetToColocatedBufferSets, but specific to the +// colocated buffers for while instructions. 'colocated_set' contains the +// buffers for a single while instruction that must be colocated. The idea here +// is to apply a memory-saving heuristic for separate while instructions whose +// buffers are disjoint in liveness, by using the colocation mechanism to force +// buffer sharing. This often reduces memory for multi-layer RNNs. +// +// TODO(b/32491382): We should be able to remove this heuristic after we +// implement module-level liveness analysis, which would let us directly detect +// buffer sharing opportunities between the while instruction buffer and the +// buffers from the predicate and body computation, as well as sharing across +// different while instructions. +void BufferAssigner::AddWhileSetToColocatedBufferSets( + const std::vector& colocated_set, + const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo, + const HloComputation& computation, const BufferLiveness& buffer_liveness, + std::vector* colocated_buffer_sets) { + CHECK(!colocated_set.empty()); + const TuplePointsToAnalysis& points_to_analysis = + buffer_liveness.points_to_analysis(); + + // Parallel while loops cannot safely share colocated buffer sets. + if (buffer_liveness.hlo_ordering().SequentialOrder(computation) == nullptr) { + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + return; + } + + // Scan 'colocated_buffer_sets' in reverse order for locality; colocated sets + // are added in postorder over computations and instructions. + const int64 init_buffer_size = buffer_size_(*while_init_buffer); + for (int i = colocated_buffer_sets->size() - 1; i >= 0; --i) { + const ColocatedBufferSet& predecessor_set = (*colocated_buffer_sets)[i]; + + // Skip predecessor sets not associated with while loops. + if (std::all_of(predecessor_set.begin(), predecessor_set.end(), + [](const LogicalBuffer* buffer) { + return buffer->instruction()->opcode() != + HloOpcode::kWhile; + })) { + continue; + } + + // Skip predecessor sets already associated with 'while_hlo'. + if (std::any_of(predecessor_set.begin(), predecessor_set.end(), + [&while_hlo](const LogicalBuffer* buffer) { + return buffer->instruction() == while_hlo; + })) { + continue; + } + + // Build vector of predecessor while result and init buffers, which are + // checked for liveness interference below. We must check both the result + // and init buffers because they're aliased together, but + // TuplePointsToAnalysis is unaware of this aliasing. + std::vector predecessor_while_buffers; + for (const LogicalBuffer* buffer : predecessor_set) { + const HloInstruction* instruction = buffer->instruction(); + if (instruction->opcode() == HloOpcode::kWhile && + buffer_size_(*buffer) == init_buffer_size && + instruction->parent() == &computation) { + predecessor_while_buffers.push_back(buffer); + // Add the init buffer at the same index, which must also exist in the + // predecessor set, and must be unambiguous. + const PointsToSet& init_points_to = + points_to_analysis.GetPointsToSet(instruction->operand(0)); + const std::vector& init_buffers = + init_points_to.element(buffer->index()); + CHECK_EQ(init_buffers.size(), 1); + CHECK_GT(predecessor_set.count(init_buffers[0]), 0); + predecessor_while_buffers.push_back(init_buffers[0]); + } + } + if (predecessor_while_buffers.empty()) { + continue; + } + + // Skip predecessor set if the live range of any predecessor buffers + // overlaps with 'while_init_buffer'. Note that tuple element buffer + // forwarding can cause the same buffer to appear on both sides of the + // interference comparison below. + if (std::any_of( + predecessor_while_buffers.begin(), predecessor_while_buffers.end(), + [while_init_buffer, &buffer_liveness](const LogicalBuffer* buffer) { + return while_init_buffer->id() != buffer->id() && + buffer_liveness.MayInterfere(*while_init_buffer, *buffer); + })) { + continue; + } + + // All our checks have passed; merge 'predecessor_set' with 'colocated_set', + // and add the merged set to 'colocated_buffer_sets'. This forces the + // colocation of buffers across different while instructions. + FlatSet unique; + unique.insert(predecessor_set.begin(), predecessor_set.end()); + unique.insert(colocated_set.begin(), colocated_set.end()); + std::vector merged_set(unique.begin(), unique.end()); + AddSetToColocatedBufferSets(merged_set, colocated_buffer_sets); + return; + } + + // Failed to merge into predecessor set; add 'colocated_set' as-is. + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); +} + namespace { + // Checks that points-to set of 'instruction' is unambiguous and distinct // (ensured by CopyInsertion), then adds the buffer from the points-to set at // 'index' to 'colocated_set'. -void AddBufferToColocatedSet(const HloInstruction* instruction, - const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis, - std::vector* colocated_set) { +const LogicalBuffer* AddBufferToColocatedSet( + const HloInstruction* instruction, const ShapeIndex& index, + const TuplePointsToAnalysis& points_to_analysis, + std::vector* colocated_set) { // CopyInsertion ensures root points-to set is unambiguous and distinct. const auto& points_to = points_to_analysis.GetPointsToSet(instruction); CHECK(!points_to.IsAmbiguous()); CHECK(points_to.IsDistinct()); colocated_set->push_back(points_to.element(index)[0]); + return colocated_set->back(); } + } // namespace // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated // in the same allocation (currently just supports kWhile and kCall). void BufferAssigner::BuildColocatedBufferSets( - const HloModule* module, const TuplePointsToAnalysis& points_to_analysis, + const HloModule* module, const BufferLiveness& buffer_liveness, std::vector* colocated_buffer_sets) { - for (auto& computation : module->computations()) { - for (auto& instruction : computation->instructions()) { + const TuplePointsToAnalysis& points_to_analysis = + buffer_liveness.points_to_analysis(); + for (const HloComputation* computation : module->MakeComputationPostOrder()) { + for (const HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { const HloOpcode opcode = instruction->opcode(); if (opcode == HloOpcode::kWhile) { - HloInstruction* while_hlo = instruction.get(); + const HloInstruction* while_hlo = instruction; TF_CHECK_OK(ShapeUtil::ForEachSubshape( while_hlo->shape(), - [this, while_hlo, &points_to_analysis, colocated_buffer_sets]( - const Shape& /*subshape*/, const ShapeIndex& index) { + [this, while_hlo, &points_to_analysis, &buffer_liveness, + computation, colocated_buffer_sets](const Shape& /*subshape*/, + const ShapeIndex& index) { std::vector colocated_set; // Add while.init. - AddBufferToColocatedSet(while_hlo->operand(0), index, - points_to_analysis, &colocated_set); + auto* init_buffer = + AddBufferToColocatedSet(while_hlo->operand(0), index, + points_to_analysis, &colocated_set); // Add while.result. AddBufferToColocatedSet(while_hlo, index, points_to_analysis, &colocated_set); @@ -930,12 +1097,15 @@ void BufferAssigner::BuildColocatedBufferSets( AddBufferToColocatedSet( while_hlo->while_body()->root_instruction(), index, points_to_analysis, &colocated_set); - AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + AddWhileSetToColocatedBufferSets( + colocated_set, init_buffer, while_hlo, *computation, + buffer_liveness, colocated_buffer_sets); return Status::OK(); })); } else if (opcode == HloOpcode::kCall) { - HloInstruction* call_hlo = instruction.get(); - HloInstruction* root_hlo = call_hlo->to_apply()->root_instruction(); + const HloInstruction* call_hlo = instruction; + const HloInstruction* root_hlo = + call_hlo->to_apply()->root_instruction(); TF_CHECK_OK(ShapeUtil::ForEachSubshape( call_hlo->shape(), [this, call_hlo, root_hlo, &points_to_analysis, @@ -961,8 +1131,8 @@ void BufferAssigner::BuildColocatedBufferSets( void BufferAssigner::AssignColocatedBufferSets( const std::vector& colocated_buffer_sets, BufferAssignment* assignment, - tensorflow::gtl::FlatSet* colocated_buffers, - tensorflow::gtl::FlatSet* colocated_allocations) { + FlatSet* colocated_buffers, + FlatSet* colocated_allocations) { for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) { BufferAllocation* allocation = nullptr; for (const LogicalBuffer* buffer : colocated_buffer_set) { @@ -980,40 +1150,33 @@ void BufferAssigner::AssignColocatedBufferSets( buffer_size_(*buffer)); } colocated_buffers->insert(buffer); + + // Each entry parameter must reside in its own BufferAllocation. As a + // result, it doesn't make sense for entry parameters to appear in a + // colocated buffer set, since the only correct scenario would be a + // degenerate colocated set that only contains the entry parameter. + const HloInstruction* instruction = buffer->instruction(); + const HloComputation* computation = instruction->parent(); + const bool is_entry_parameter = + instruction->opcode() == HloOpcode::kParameter && + computation == computation->parent()->entry_computation(); + CHECK(!is_entry_parameter) + << "allocation: " << *allocation << " instruction: " << *buffer << " " + << instruction->ToString(); } } } StatusOr> BufferAssigner::CreateAssignment( - const HloModule* module, std::unique_ptr hlo_ordering, - const std::vector* hlos_to_allocate) { + const HloModule* module, std::unique_ptr hlo_ordering) { TF_ASSIGN_OR_RETURN(std::unique_ptr liveness, BufferLiveness::Run(module, std::move(hlo_ordering))); - std::vector thread_local_computations; - std::vector global_computations; VLOG(1) << "Assigning buffers to module " << module->name(); - if (hlos_to_allocate != nullptr) { - VLOG(3) << "LogicalBuffer assignment restricted to hlos: "; - for (auto hlo : *hlos_to_allocate) { - VLOG(3) << " " << hlo->parent()->name() << "::" << hlo->name(); - } - } - XLA_VLOG_LINES(3, module->ToString()); + XLA_VLOG_LINES(2, module->ToString()); XLA_VLOG_LINES(3, liveness->ToString()); XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString()); - TF_RETURN_IF_ERROR(GatherComputationsByAllocationType( - module, &thread_local_computations, &global_computations)); - - // Set of HLO's to allocate if hlos_to_allocate is given. Passed as a set to - // AssignBuffersForComputation for fast membership testing. - std::unique_ptr> hlo_set; - if (hlos_to_allocate != nullptr) { - hlo_set = MakeUnique>( - hlos_to_allocate->begin(), hlos_to_allocate->end()); - } - // Can't use MakeUnique because BufferAssignment constructor is private. std::unique_ptr assignment( new BufferAssignment(module, std::move(liveness), alignment_)); @@ -1022,26 +1185,46 @@ StatusOr> BufferAssigner::CreateAssignment( // Once b/32491382 enables module-level liveness analysis, we may be able // to assign colocated buffers (or at least reuse their allocation for // buffers outside of the set) in AssignBuffersForComputation. - tensorflow::gtl::FlatSet colocated_buffers; - tensorflow::gtl::FlatSet colocated_allocations; - if (colocate_related_buffers_) { - std::vector colocated_buffer_sets; - BuildColocatedBufferSets(module, assignment->points_to_analysis(), - &colocated_buffer_sets); - AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(), - &colocated_buffers, &colocated_allocations); - } + FlatSet colocated_buffers; + FlatSet colocated_allocations; + std::vector colocated_buffer_sets; + BuildColocatedBufferSets(module, assignment->liveness(), + &colocated_buffer_sets); + AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(), + &colocated_buffers, &colocated_allocations); + + std::vector thread_local_computations; + std::vector global_computations; + TF_RETURN_IF_ERROR(GatherComputationsByAllocationType( + module, &thread_local_computations, &global_computations)); + // First assign buffers for global computatations. Temporary buffers for + // sequential computations are collected in 'buffers_to_assign_sequentially'. + FlatMap> + buffers_to_assign_sequentially; for (auto* computation : global_computations) { TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, /*is_thread_local=*/false, hlo_set.get(), - colocated_buffers, colocated_allocations, assignment.get())); + computation, /*is_thread_local=*/false, colocated_buffers, + colocated_allocations, &buffers_to_assign_sequentially, + assignment.get())); } + // Assign buffers with sequential ordering, if any. If all global computations + // are sequential, we can run heap simuation on the whole module, which + // reduces memory usage. + const bool run_whole_module_heap_simulation = + buffers_to_assign_sequentially.size() == global_computations.size(); + TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering( + buffers_to_assign_sequentially, run_whole_module_heap_simulation, + assignment.get())); + + // Now assign buffers for thread-local computations. All LogicalBuffers get + // their own BufferAllocation. for (auto* computation : thread_local_computations) { TF_RET_CHECK(computation != module->entry_computation()); TF_RETURN_IF_ERROR(AssignBuffersForComputation( - computation, /*is_thread_local=*/true, hlo_set.get(), colocated_buffers, - colocated_allocations, assignment.get())); + computation, /*is_thread_local=*/true, colocated_buffers, + colocated_allocations, /*buffers_to_assign_sequentially=*/nullptr, + assignment.get())); } // Mark all buffers which may be live out of the entry computation as diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index b82acb19b3488884bdc8d2d5c4a1524ac165676a..4b8b2cb9c4b5c27ad48e00e7f73635ffd6207882 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/buffer_liveness.h" +#include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -308,6 +309,9 @@ class BufferAssignment { return liveness_->points_to_analysis(); } + // Returns the BufferLiveness object used to construct this assignment. + const BufferLiveness& liveness() const { return *liveness_; } + string ToString() const; // Statistics for the assignment. Values initialized to -1 are not always @@ -354,8 +358,8 @@ class BufferAssignment { void AddAssignment(BufferAllocation* allocation, const LogicalBuffer& buffer, int64 offset, int64 size); - // Returns the BufferLiveness object used to construct this assignment. - const BufferLiveness& liveness() { return *liveness_; } + // Returns the HloModule used to construct this assignment. + const HloModule& module() { return *module_; } // Convenience function which returns the PointsToSet for the given // instruction. Extracted from the liveness object. @@ -396,58 +400,55 @@ class BufferAssigner { // Build and return a BufferAssignment for the given module. The given // HloOrdering is used to determine buffer liveness. buffer_size is a function // which returns the size of a LogicalBuffer. Alignment is the the minimum - // alignment of any buffer. If hlos_to_allocate is not null then only - // instructions in this vector are considered for buffer assignment. If - // hlos_to_allocate is null then all instructions are considered. If - // 'colocate_related_buffers' is true, related LogicalBuffers will be - // colocated in the same allocation (i.e buffers for while result will share - // an allocation with buffers related to that same while instruction: init - // operand, condition/body parameter and body result). + // alignment of any buffer. allow_input_output_aliasing specifies whether + // input buffer are allowed to be reused as outbut buffers by the client code. static StatusOr> Run( const HloModule* module, std::unique_ptr hlo_ordering, LogicalBuffer::SizeFunction buffer_size, int64 alignment, - bool colocate_related_buffers, - const std::vector* hlos_to_allocate = nullptr); - - // Overload of Run which uses ShapeUtil::ByteSizeOf to determine buffer size - // and assigns buffers to all HLO instructions in the module. - static StatusOr> Run( - const HloModule* module, std::unique_ptr hlo_ordering, - LogicalBuffer::SizeFunction buffer_size, int64 alignment); + bool allow_input_output_aliasing = false); private: - explicit BufferAssigner(LogicalBuffer::SizeFunction buffer_size, - int64 alignment, bool colocate_related_buffers) + BufferAssigner(LogicalBuffer::SizeFunction buffer_size, int64 alignment, + bool allow_input_output_aliasing) : buffer_size_(std::move(buffer_size)), alignment_(alignment), - colocate_related_buffers_(colocate_related_buffers) {} + allow_input_output_aliasing_(allow_input_output_aliasing) {} virtual ~BufferAssigner() = default; // Create a buffer assignment. StatusOr> CreateAssignment( - const HloModule* module, std::unique_ptr hlo_ordering, - const std::vector* hlos_to_allocate = nullptr); + const HloModule* module, std::unique_ptr hlo_ordering); // Assigns buffers to the instructions in the given computation. "assignment" // is modified to reflect the new buffer assignments. If is_thread_local is // true, then all assigned buffers have the is_thread_local flag set to - // true. If hlos_to_allocate is not null it indicates which HLOs to include in - // buffer assignment. If null, all instructions in the computation are - // included. + // true. Status AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const tensorflow::gtl::FlatSet* hlos_to_allocate, const tensorflow::gtl::FlatSet& colocated_buffers, const tensorflow::gtl::FlatSet& colocated_allocations, + tensorflow::gtl::FlatMap>* + buffers_to_assign_sequentially, BufferAssignment* assignment); - // Assigns 'buffers_to_assign' assuming the HLO instructions will be executed - // in the given 'sequential_order'. + // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming + // the HLO instructions will be executed in the sequential order given by + // assignment->liveness().hlo_ordering().SequentialOrder. If + // 'run_whole_module_heap_simulation' is true, the heap simulation will be run + // assuming all global computations are sequentially ordered. Status AssignBuffersWithSequentialOrdering( - const std::vector& sequential_order, - const tensorflow::gtl::FlatSet& buffers_to_assign, - const HloComputation& computation, BufferAssignment* assignment); + const tensorflow::gtl::FlatMap< + const HloComputation*, + tensorflow::gtl::FlatSet>& + buffers_to_assign_sequentially, + bool run_whole_module_heap_simulation, BufferAssignment* assignment); + + // Uses the results of the heap simulator to create a single allocation, with + // LogicalBuffers packed to specific offsets. + void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result, + BufferAssignment* assignment); // Tries to assign the given instruction to the given buffer. Returns if the // assignment was successful. @@ -465,7 +466,7 @@ class BufferAssigner { // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module' // which should be colocated in the same buffer allocation. void BuildColocatedBufferSets( - const HloModule* module, const TuplePointsToAnalysis& points_to_analysis, + const HloModule* module, const BufferLiveness& buffer_liveness, std::vector* colocated_buffer_sets); // For each buffer set in 'colocated_buffer_sets', assigns all buffers in the @@ -482,7 +483,13 @@ class BufferAssigner { const std::vector& colocated_set, std::vector* colocated_buffer_sets); - const HloModule* module_; + // Conceptually the same as AddSetToColocatedBufferSets, but specific to the + // colocated buffers for while instructions. + void AddWhileSetToColocatedBufferSets( + const std::vector& colocated_set, + const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo, + const HloComputation& computation, const BufferLiveness& buffer_liveness, + std::vector* colocated_buffer_sets); // Function which returns the buffer size for a given logical buffer (shape). LogicalBuffer::SizeFunction buffer_size_; @@ -490,8 +497,9 @@ class BufferAssigner { // Minimum alignment of any buffer. int64 alignment_; - // Indicates whether related buffers should share the same buffer allocation. - const bool colocate_related_buffers_; + // If true, buffer assignments assumes that input parameter buffers and output + // buffers can be shared if their sizes match. + bool allow_input_output_aliasing_; TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner); }; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index bb7342d5081af32c9882311af8dddf08c115becc..ac1d769010c55ee4430554abe3205391bee5ebf1 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -22,12 +22,18 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -208,30 +214,6 @@ class BufferAssignmentTest : public HloTestBase { return total_size; } - // Returns true if the buffers assigned to instructions in "a" are distinct - // from the buffers assigned to those in "b" (ie, intersection is empty). - bool BuffersDistinct(const std::vector& a, - const std::vector& b, - const BufferAssignment& assignment) { - std::set a_slices; - for (const HloInstruction* instruction : a) { - if (assignment.HasTopLevelAllocation(instruction)) { - a_slices.insert( - assignment.GetUniqueTopLevelSlice(instruction).ConsumeValueOrDie()); - } - } - - for (const HloInstruction* instruction : b) { - if (assignment.HasTopLevelAllocation(instruction)) { - if (a_slices.count(assignment.GetUniqueTopLevelSlice(instruction) - .ConsumeValueOrDie())) { - return false; - } - } - } - return true; - } - // Computation tracker for nested computations. ComputationTracker computation_tracker_; @@ -246,6 +228,30 @@ class BufferAssignmentTest : public HloTestBase { Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_}); }; +// Returns true if the buffers assigned to instructions in "a" are distinct +// from the buffers assigned to those in "b" (ie, intersection is empty). +static bool BuffersDistinct(const std::vector& a, + const std::vector& b, + const BufferAssignment& assignment) { + std::set a_slices; + for (const HloInstruction* instruction : a) { + if (assignment.HasTopLevelAllocation(instruction)) { + a_slices.insert( + assignment.GetUniqueTopLevelSlice(instruction).ConsumeValueOrDie()); + } + } + + for (const HloInstruction* instruction : b) { + if (assignment.HasTopLevelAllocation(instruction)) { + if (a_slices.count(assignment.GetUniqueTopLevelSlice(instruction) + .ConsumeValueOrDie())) { + return false; + } + } + } + return true; +} + // Tests a computation consisting of a single scalar constant node. TEST_F(BufferAssignmentTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); @@ -850,8 +856,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { EXPECT_FALSE(map_root_alloc.maybe_live_out()); EXPECT_TRUE(map_root_alloc.is_thread_local()); - // Allocations for the call computation should not be thread-local and not - // live-out. + // Allocations for the call computation should not be thread-local. auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param); EXPECT_FALSE(call_param_alloc.is_entry_computation_parameter()); EXPECT_FALSE(call_param_alloc.maybe_live_out()); @@ -859,7 +864,6 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root); EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter()); - EXPECT_FALSE(call_root_alloc.maybe_live_out()); EXPECT_FALSE(call_root_alloc.is_thread_local()); // Entry computation allocations can be marked liveout and @@ -1144,12 +1148,12 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { // should include the slices of both of the elements in the parameters. auto element_slices = assignment->GetAllSlices(select, /*index=*/{0}); EXPECT_EQ(2, element_slices.size()); - EXPECT_MATCH(testing::SetToVec(element_slices), - testing::UnorderedMatcher( - assignment->GetUniqueSlice(tuple_param0, /*index=*/{0}) - .ConsumeValueOrDie(), - assignment->GetUniqueSlice(tuple_param1, /*index=*/{0}) - .ConsumeValueOrDie())); + EXPECT_THAT(element_slices, + ::testing::UnorderedElementsAre( + assignment->GetUniqueSlice(tuple_param0, /*index=*/{0}) + .ConsumeValueOrDie(), + assignment->GetUniqueSlice(tuple_param1, /*index=*/{0}) + .ConsumeValueOrDie())); } // TODO(b/34669761): Remove this test when buffers are allowed to share @@ -1245,6 +1249,257 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { } } -} // namespace +class WhileBufferAssignmentTest : public HloTestBase { + protected: + std::unique_ptr BuildWhileConditionComputation( + const string& name) { + auto builder = HloComputation::Builder(name); + builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto ten = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten)); + return builder.Build(); + } + + std::unique_ptr BuildWhileBodyComputation( + const string& name) { + auto builder = HloComputation::Builder(name); + auto loop_state = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + auto input = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 0)); + auto weights = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + auto output = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape_, HloOpcode::kMultiply, input, weights)); + builder.AddInstruction( + HloInstruction::CreateTuple({input, weights, output})); + return builder.Build(); + } + + std::unique_ptr RunBufferAssignment(HloModule* module, + int64 alignment = 1) { + auto sequence = + CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + return BufferAssigner::Run( + module, MakeUnique(module, sequence), + ByteSizeOf, alignment) + .ConsumeValueOrDie(); + } + + static int64 ByteSizeOf(const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*)); + } + + Shape data_shape_ = ShapeUtil::MakeShape(F32, {4}); + Shape loop_state_shape_ = + ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_}); +}; +static void RunCopyInsertion(HloModule* module) { + CopyInsertion copy_insertion; + EXPECT_IS_OK(copy_insertion.Run(module).status()); +} + +TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { + auto module = MakeUnique(TestName()); + auto builder = HloComputation::Builder("entry"); + + auto input0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape_, "input0")); + auto weights0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "weights0")); + auto weights1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, data_shape_, "weights1")); + + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + auto output0 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + auto output1 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + + auto cond0 = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body0 = + module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + + auto tuple0 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output0})); + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0)); + + auto cond1 = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body1 = + module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + auto input1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while0, 2)); + auto tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({input1, weights1, output1})); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); + + module->AddEntryComputation(builder.Build()); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); + + // While instruction 'while0' has no predecessor while instructions with + // which to share allocations. + + // While instruction 'while1' can share allocations with the following + // buffers: + // *) while0[2], while1[0] + // *) while0[1], while1[1] + EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(), + assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie()); + EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie(), + assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie()); +} + +TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { + auto module = MakeUnique(TestName()); + auto builder = HloComputation::Builder("entry"); + + auto input0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape_, "input0")); + auto weights0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "weights0")); + + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + auto output0 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + auto output1 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + + auto cond0 = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body0 = + module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + + auto tuple0 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output0})); + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0)); + + auto cond1 = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body1 = + module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + + auto tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output1})); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); + + module->AddEntryComputation(builder.Build()); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); + + // while0 and while1 buffers should be completely aligned. + EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(), + assignment->GetUniqueSlice(while1, {0}).ConsumeValueOrDie()); + EXPECT_EQ(assignment->GetUniqueSlice(while0, {1}).ConsumeValueOrDie(), + assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie()); + EXPECT_EQ(assignment->GetUniqueSlice(while0, {2}).ConsumeValueOrDie(), + assignment->GetUniqueSlice(while1, {2}).ConsumeValueOrDie()); +} + +TEST_F(BufferAssignmentTest, TwoCalls) { + auto module = MakeUnique(TestName()); + Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {}); + HloComputation* sub_computation; + { + auto builder = HloComputation::Builder(TestName() + "_sub_comp"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param")); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1)); + sub_computation = module->AddEmbeddedComputation(builder.Build(add)); + } + auto builder = HloComputation::Builder(TestName()); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + auto call1 = builder.AddInstruction( + HloInstruction::CreateCall(r0f32, {constant2}, sub_computation)); + auto call2 = builder.AddInstruction( + HloInstruction::CreateCall(r0f32, {constant3}, sub_computation)); + auto add1 = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call1, constant2)); + auto add2 = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, call2, add1)); + module->AddEntryComputation(builder.Build(add2)); + + { + FlattenCallGraph flatten; + TF_ASSIGN_OR_ASSERT_OK(bool result, flatten.Run(module.get())); + EXPECT_TRUE(result); + std::unique_ptr call_graph = CallGraph::Build(module.get()); + } + + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); + + EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment)); +} + +// Test buffer assignment for while nodes with multiple uses. +// TODO(b/37245345): Fix buffer assignment for this case. +TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { + auto module = MakeUnique(TestName()); + auto builder = HloComputation::Builder(TestName()); + + auto input0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape_, "input0")); + auto weights0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "weights0")); + + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + auto output0 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + + auto cond0 = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body0 = + module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + + auto tuple0 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output0})); + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0)); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, while0)); + + auto get0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while0, 2)); + auto get1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, get0, get1)); + module->AddEntryComputation(builder.Build()); + + RunCopyInsertion(module.get()); + + { + FlattenCallGraph flatten; + TF_ASSIGN_OR_ASSERT_OK(bool result, flatten.Run(module.get())); + EXPECT_TRUE(result); + } + + auto assignment = RunBufferAssignment(module.get()); + + EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); +} + +} // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 736f227aa423120ecb4a5e82824defac2d345b2e..d69a84cd0e3ffffad32b89afc726a31b175e47c5 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -45,9 +45,7 @@ StatusOr> BufferLiveness::Run( } tensorflow::Status BufferLiveness::Analyze() { - TF_ASSIGN_OR_RETURN(points_to_analysis_, - TuplePointsToAnalysis::Run( - module_, /*include_loop_fusion_instructions=*/true)); + TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_)); for (auto& computation : module_->computations()) { // Gather all instructions whose buffers might alias other instructions into // the set aliased_buffers_. This includes those contained as a tuple @@ -117,11 +115,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, // If 'b' is a user of 'a' then the buffers interfere unless 'a.instruction' // and 'b.instruction' emit the same shape/layout, and 'b.instruction' meets - // one of following qualifications: - // *) Is element-wise. - // *) Is a loop fusion instruction (with DynamicUpdateSlice fused root) where - // the singleton use of 'a' at 'a.index' is the fused root at operand 0. - // *) Use of 'operand' is DynamicUpdateSlice at operand index 0. + // the qualifications specified in CanShareOperandBufferWithUser. for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { if (b.instruction()->IsUserOf(alias.instruction()) && !CanShareOperandBufferWithUser(alias.instruction(), alias.index(), @@ -133,10 +127,30 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, return true; } +namespace { +bool IsEntryParameter(const HloInstruction* instruction) { + const HloComputation* computation = instruction->parent(); + return instruction->opcode() == HloOpcode::kParameter && + computation == computation->parent()->entry_computation(); +} +} // namespace + bool BufferLiveness::MayInterfere(const LogicalBuffer& a, const LogicalBuffer& b) const { - return (!live_range_strictly_before(a, b) && - !live_range_strictly_before(b, a)); + // Entry parameters live at the entry of the execution, thus always interfere + // with all other instructions executing before them in the ordering. + const HloInstruction* a_instruction = a.instruction(); + const HloInstruction* b_instruction = b.instruction(); + if (IsEntryParameter(a_instruction) && + hlo_ordering_->ExecutesBefore(b_instruction, a_instruction)) { + return true; + } + if (IsEntryParameter(b_instruction) && + hlo_ordering_->ExecutesBefore(a_instruction, b_instruction)) { + return true; + } + // Buffers without disjoint liveness may interfere. + return !live_range_strictly_before(a, b) && !live_range_strictly_before(b, a); } bool BufferLiveness::MaybeLiveOut(const LogicalBuffer& buffer) const { diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index e71b98298b344b5689785bfa67a8bea54e0248e3..bee9a351f5df00aea6178fab4fd0e222ff9e9a99 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -92,6 +92,12 @@ class BufferLivenessTest : public HloTestBase { GetBuffer(liveness, instruction, /*index=*/{})); } + std::unique_ptr BuildDummyComputation() { + auto builder = HloComputation::Builder(TestName() + "_dummy"); + builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); + return builder.Build(); + } + const Shape vec_ = ShapeUtil::MakeShape(xla::F32, {42}); }; @@ -118,12 +124,17 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { MakeUnique(module.get())) .ConsumeValueOrDie(); - // No buffers should interfere. EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, log)); + + // No buffers should interfere. EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, exp)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, log)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, negate)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, log)); - EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, log)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, log, negate)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, log, exp)); // Buffers should interfere with itself. EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, exp)); @@ -135,18 +146,69 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, log)); } +TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { + // Two entry params, which interfere with each other. + // + // param0 --> negate ---------------\ + // param1 --> exp --> add + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, vec_, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, vec_, "param1")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param0)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param1)); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp)); + + auto module = MakeUnique(TestName()); + HloComputation* entry = module->AddEntryComputation(builder.Build()); + + SequentialHloOrdering::HloModuleSequence sequence; + sequence.insert({entry, {param0, negate, param1, exp, add}}); + auto liveness = BufferLiveness::Run( + module.get(), + MakeUnique(module.get(), sequence)) + .ConsumeValueOrDie(); + + // Entry parameters interfere as if they are defined simultaneously at + // the very beginning. + EXPECT_TRUE(InstructionsMayInterfere(*liveness, param0, param1)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, negate)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, exp)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param0, add)); + EXPECT_TRUE(InstructionsMayInterfere(*liveness, param1, param0)); + EXPECT_TRUE(InstructionsMayInterfere(*liveness, param1, negate)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param1, exp)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param1, add)); + + // Negate and exp still interfere. + EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp)); + EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate)); + + // But {negate, add} and {exp, add} don't interfere. + EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp)); +} + TEST_F(BufferLivenessTest, NonElementwiseOperand) { - // A chain of operations with one elementwise and one non-elementwise. The + // A chain of operations with two elementwise and one non-elementwise. The // elementwise op should not interfere with its operand, while the - // non-elementwise op should interfere. + // non-elementwise op should interfere. Entry params always interfere. // - // param --> negate -> reverse + // param --> exp -> negate -> reverse // auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param)); auto negate = builder.AddInstruction( - HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param)); + HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, exp)); auto reverse = builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0})); @@ -158,10 +220,14 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { MakeUnique(module.get())) .ConsumeValueOrDie(); - // No buffers should interfere. + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, reverse)); + + // Negate is elementwise, so doesn't interfere with its operand. + // Reverse is non-elementwise, so does interfere with its operand. + EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, negate)); EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, reverse)); - EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); } TEST_F(BufferLivenessTest, OverlappedBuffers) { @@ -190,8 +256,15 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, exp)); - EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); + + // Negate and exp interfere with each other, but not with add. + EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp)); + EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp)); } TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { @@ -204,8 +277,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { // Sequential order: // param, negate, exp, add // - // Liveness is identical to the DependencyHloOrdering except that 'param' and - // exp no longer interfere. + // Liveness is identical to the DependencyHloOrdering. auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); @@ -229,8 +301,15 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); - EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); + + // Negate and exp interfere with each other, but not with add. + EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp)); + EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, negate)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, add)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, negate)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, add)); + EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp)); } TEST_F(BufferLivenessTest, TupleLiveOut) { @@ -392,7 +471,8 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); auto module = MakeUnique(TestName()); - module->AddEntryComputation(builder.Build()); + module->AddEntryComputation(BuildDummyComputation()); + module->AddEmbeddedComputation(builder.Build()); auto liveness = BufferLiveness::Run(module.get(), @@ -452,7 +532,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); auto module = MakeUnique(TestName()); - module->AddEntryComputation(builder.Build()); + module->AddEntryComputation(BuildDummyComputation()); + module->AddEmbeddedComputation(builder.Build()); auto liveness = BufferLiveness::Run(module.get(), @@ -524,7 +605,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. auto module = MakeUnique(TestName()); - auto* computation = module->AddEntryComputation(builder.Build()); + module->AddEntryComputation(BuildDummyComputation()); + auto* computation = module->AddEmbeddedComputation(builder.Build()); // Create fusion instruction based on number of tuple element 1 users. if (update_uses_tuple_element1) { computation->CreateFusionInstruction( @@ -546,7 +628,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { BufferLiveness::Run(module.get(), MakeUnique(module.get())) .ConsumeValueOrDie(); - // Return whether or not buffers interfernce is detected between + // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); } @@ -651,13 +733,14 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. auto module = MakeUnique(TestName()); - module->AddEntryComputation(builder.Build()); + module->AddEntryComputation(BuildDummyComputation()); + module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. auto liveness = BufferLiveness::Run(module.get(), MakeUnique(module.get())) .ConsumeValueOrDie(); - // Return whether or not buffers interfernce is detected between + // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); } diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index ab3eae2495ec55e8667db86b025f980157517ccc..fa7b2a309525dd80d655e10474c5d49f9da14ea8 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -51,6 +51,22 @@ std::ostream& operator<<(std::ostream& out, const CallContext& context) { return out; } +CallContext GetInstructionCallContext(const HloInstruction* instruction) { + switch (instruction->opcode()) { + case HloOpcode::kCall: + case HloOpcode::kWhile: + return CallContext::kSequential; + case HloOpcode::kMap: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kFusion: + return CallContext::kParallel; + default: + return CallContext::kNone; + } +} + string CallSite::ToString() const { return StrCat(instruction()->name(), " calls in context ", CallContextToString(context()), ": ", @@ -82,32 +98,12 @@ void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) { } } -namespace { - -CallContext GetInstructionCallContext(const HloInstruction* instruction) { - switch (instruction->opcode()) { - case HloOpcode::kCall: - case HloOpcode::kWhile: - return CallContext::kSequential; - case HloOpcode::kMap: - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: - case HloOpcode::kFusion: - return CallContext::kParallel; - default: - return CallContext::kNone; - } -} - -} // namespace - -Status CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) { - TF_RET_CHECK(instruction->parent() == computation()); +void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) { + CHECK_EQ(instruction->parent(), computation()); const CallContext context = GetInstructionCallContext(instruction); if (!instruction->called_computations().empty()) { - TF_RET_CHECK(context == CallContext::kSequential || - context == CallContext::kParallel); + CHECK(context == CallContext::kSequential || + context == CallContext::kParallel); callsite_instructions_.insert({instruction, callsites_.size()}); callsites_.push_back( CallSite(instruction, instruction->called_computations(), context)); @@ -120,22 +116,21 @@ Status CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) { } } } - return Status::OK(); } CallGraph::CallGraph(const HloModule* module) : module_(module) {} -StatusOr CallGraph::GetNode( +const CallGraphNode& CallGraph::GetNode( const HloComputation* computation) const { auto it = node_indices_.find(computation); - TF_RET_CHECK(it != node_indices_.end()); - return &nodes_[it->second]; + CHECK(it != node_indices_.end()); + return nodes_[it->second]; } -StatusOr CallGraph::GetNode(const HloComputation* computation) { +CallGraphNode& CallGraph::GetNode(const HloComputation* computation) { auto it = node_indices_.find(computation); - TF_RET_CHECK(it != node_indices_.end()); - return &nodes_[it->second]; + CHECK(it != node_indices_.end()); + return nodes_[it->second]; } namespace { @@ -158,17 +153,17 @@ CallContext UnionContexts(CallContext a, CallContext b) { } // namespace -Status CallGraph::SetCallContexts() { +void CallGraph::SetCallContexts() { std::queue worklist; // Initialize worklist with all roots of the call graph (computations without // callers). for (const std::unique_ptr& computation : module_->computations()) { - TF_ASSIGN_OR_RETURN(CallGraphNode * node, GetNode(computation.get())); - if (node->callers().empty()) { - node->set_context(CallContext::kSequential); - worklist.push(node); + CallGraphNode& node = GetNode(computation.get()); + if (node.callers().empty()) { + node.set_context(CallContext::kSequential); + worklist.push(&node); } } @@ -178,7 +173,7 @@ Status CallGraph::SetCallContexts() { for (const CallSite& callsite : node->callsites()) { for (const HloComputation* callee : callsite.called_computations()) { - TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node, GetNode(callee)); + CallGraphNode& callee_node = GetNode(callee); // Update context of callee computation based on the callsite and its // current context. @@ -186,16 +181,16 @@ Status CallGraph::SetCallContexts() { if (callsite.context() == CallContext::kParallel) { context_to_add = CallContext::kParallel; } else { - TF_RET_CHECK(callsite.context() == CallContext::kSequential); + CHECK_EQ(callsite.context(), CallContext::kSequential); context_to_add = node->context(); } CallContext new_context = - UnionContexts(context_to_add, callee_node->context()); + UnionContexts(context_to_add, callee_node.context()); - if (new_context != callee_node->context()) { + if (new_context != callee_node.context()) { // Context of computation has been changed so add node to worklist. - callee_node->set_context(new_context); - worklist.push(callee_node); + callee_node.set_context(new_context); + worklist.push(&callee_node); } } } @@ -204,14 +199,12 @@ Status CallGraph::SetCallContexts() { // No node should have a kNone calling context. for (const std::unique_ptr& computation : module_->computations()) { - TF_ASSIGN_OR_RETURN(CallGraphNode * node, GetNode(computation.get())); - TF_RET_CHECK(node->context() != CallContext::kNone); + CHECK_NE(GetNode(computation.get()).context(), CallContext::kNone); } - return Status::OK(); } /* static */ -StatusOr> CallGraph::Build(const HloModule* module) { +std::unique_ptr CallGraph::Build(const HloModule* module) { // Constructor for CallGraph is private so MakeUnique can't be used. auto call_graph = WrapUnique(new CallGraph(module)); @@ -223,56 +216,51 @@ StatusOr> CallGraph::Build(const HloModule* module) { module->computations()) { auto it_added = call_graph->node_indices_.insert( {computation.get(), call_graph->nodes_.size()}); - // All computation should be unique, so the computation should not already + // All computations should be unique, so the computation should not already // exist in the map. - TF_RET_CHECK(it_added.second); + CHECK(it_added.second); call_graph->nodes_.emplace_back(computation.get()); // Add all callsites in this computation. for (const std::unique_ptr& instruction : computation->instructions()) { - TF_RETURN_IF_ERROR(call_graph->nodes_.back().AddCallSiteForInstruction( - instruction.get())); + call_graph->nodes_.back().AddCallSiteForInstruction(instruction.get()); } } // Add caller callsites to each node. for (const std::unique_ptr& computation : module->computations()) { - TF_ASSIGN_OR_RETURN(CallGraphNode * caller_node, - call_graph->GetNode(computation.get())); - for (const CallSite& callsite : caller_node->callsites()) { + for (const CallSite& callsite : + call_graph->GetNode(computation.get()).callsites()) { for (auto* callee : callsite.called_computations()) { // Add caller callsites. - TF_ASSIGN_OR_RETURN(CallGraphNode * callee_node, - call_graph->GetNode(callee)); - callee_node->AddCallerCallSite(callsite); + call_graph->GetNode(callee).AddCallerCallSite(callsite); } } } - TF_RETURN_IF_ERROR(call_graph->SetCallContexts()); - + call_graph->SetCallContexts(); XLA_VLOG_LINES(1, call_graph->ToString()); - return std::move(call_graph); + return call_graph; } Status CallGraph::VisitNodesInternal( - const VisitorFunction& visitor_func, const CallGraphNode* node, + const VisitorFunction& visitor_func, const CallGraphNode& node, tensorflow::gtl::FlatSet* visited) const { - auto pair = visited->insert(node); + auto pair = visited->insert(&node); if (!pair.second) { // Node was not inserted. Node has already been visited. return Status::OK(); } - for (const HloComputation* computation : node->callees()) { - TF_ASSIGN_OR_RETURN(const CallGraphNode* callee_node, GetNode(computation)); - TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, callee_node, visited)); + for (const HloComputation* computation : node.callees()) { + TF_RETURN_IF_ERROR( + VisitNodesInternal(visitor_func, GetNode(computation), visited)); } - return visitor_func(*node); + return visitor_func(node); } Status CallGraph::VisitNodes(const VisitorFunction& visitor_func, @@ -282,14 +270,13 @@ Status CallGraph::VisitNodes(const VisitorFunction& visitor_func, // Traverse from all roots in the call graph. for (const CallGraphNode& node : nodes()) { if (node.callers().empty()) { - TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, &node, &visited)); + TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited)); } } } else { // Traverse only from the entry computation. - TF_ASSIGN_OR_RETURN(const CallGraphNode* entry_node, - GetNode(module_->entry_computation())); - TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, entry_node, &visited)); + TF_RETURN_IF_ERROR(VisitNodesInternal( + visitor_func, GetNode(module_->entry_computation()), &visited)); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index e2fed044c88008d0a7e43f0166d397627ed72267..7f9990f06d4fee4c52fa516fc2f6031f5dab2bb9 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -54,6 +53,8 @@ enum class CallContext { string CallContextToString(CallContext context); std::ostream& operator<<(std::ostream& out, const CallContext& context); +CallContext GetInstructionCallContext(const HloInstruction* instruction); + // Represents an HLO instruction which calls one or more computations. class CallSite { public: @@ -136,7 +137,7 @@ class CallGraphNode { // If instruction calls any computations adds a call site for this instruction // to the call graph node. If the instruction calls no computations then no // call site is added. - Status AddCallSiteForInstruction(HloInstruction* instruction); + void AddCallSiteForInstruction(HloInstruction* instruction); // Computation represented by this call graph node. HloComputation* computation_; @@ -172,12 +173,11 @@ class CallGraph { using VisitorFunction = std::function; // Builds and returns a call graph for the given HLO module. - static StatusOr> Build(const HloModule* module); + static std::unique_ptr Build(const HloModule* module); // Returns the node associated with the given computation. - StatusOr GetNode( - const HloComputation* computation) const; - StatusOr GetNode(const HloComputation* computation); + const CallGraphNode& GetNode(const HloComputation* computation) const; + CallGraphNode& GetNode(const HloComputation* computation); // Returns the vector of all nodes in the call graph. const std::vector& nodes() const { return nodes_; } @@ -195,14 +195,14 @@ class CallGraph { CallGraph(const HloModule* module); // Sets the call contexts for every node in the graph. - Status SetCallContexts(); + void SetCallContexts(); // Helper method for VisitNodes(). Traverses the call graph from 'node' in DFS // post order (callee before caller) calling visitor_func on each node. Adds // nodes to 'visited' as each node is visited. Skips nodes already in // 'visited'. Status VisitNodesInternal( - const VisitorFunction& visitor_func, const CallGraphNode* node, + const VisitorFunction& visitor_func, const CallGraphNode& node, tensorflow::gtl::FlatSet* visited) const; // The HLO module represented by this call graph. diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 65900fd4f86cd07d5d956da0df429d30fcdf7561..ab0ea47d024d871be88bfcab957810deb1ecac99 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" @@ -28,7 +29,7 @@ limitations under the License. namespace xla { namespace { -using testing::UnorderedMatcher; +using ::testing::UnorderedElementsAre; class CallGraphTest : public HloTestBase { protected: @@ -60,14 +61,15 @@ class CallGraphTest : public HloTestBase { // Build and return a computation which takes a scalar and calls (kCall) the // given computation with value 'callsites' number of times. std::unique_ptr MakeCallingComputation( - HloComputation* map_computation, int64 callsites) { - HloComputation::Builder builder(TestName() + ".CallingComputation"); + HloComputation* callee_computation, int64 callsites, + const string& suffix = ".CallingComputation") { + HloComputation::Builder builder(TestName() + suffix); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* last_value = param0; for (int64 i = 0; i < callsites; ++i) { last_value = builder.AddInstruction(HloInstruction::CreateCall( - kScalarShape, {last_value}, map_computation)); + kScalarShape, {last_value}, callee_computation)); } return builder.Build(); } @@ -93,17 +95,15 @@ TEST_F(CallGraphTest, SingletonComputation) { HloModule module(TestName()); HloComputation* computation = module.AddEntryComputation(MakeScalarComputation()); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); EXPECT_EQ(1, call_graph->nodes().size()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* node, - call_graph->GetNode(computation)); - EXPECT_EQ(computation, node->computation()); - EXPECT_TRUE(node->callsites().empty()); - EXPECT_TRUE(node->callees().empty()); - EXPECT_TRUE(node->caller_callsites().empty()); - EXPECT_TRUE(node->callers().empty()); - EXPECT_EQ(CallContext::kSequential, node->context()); + const CallGraphNode& node = call_graph->GetNode(computation); + EXPECT_EQ(computation, node.computation()); + EXPECT_TRUE(node.callsites().empty()); + EXPECT_TRUE(node.callees().empty()); + EXPECT_TRUE(node.caller_callsites().empty()); + EXPECT_TRUE(node.callers().empty()); + EXPECT_EQ(CallContext::kSequential, node.context()); } TEST_F(CallGraphTest, UnreachableComputation) { @@ -115,19 +115,17 @@ TEST_F(CallGraphTest, UnreachableComputation) { HloComputation* unreachable_computation = module.AddEmbeddedComputation(MakeScalarComputation()); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); EXPECT_EQ(2, call_graph->nodes().size()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, - call_graph->GetNode(entry_computation)); - EXPECT_EQ(entry_computation, entry_node->computation()); - EXPECT_EQ(CallContext::kSequential, entry_node->context()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(CallContext::kSequential, entry_node.context()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* unreachable_node, - call_graph->GetNode(unreachable_computation)); - EXPECT_EQ(unreachable_computation, unreachable_node->computation()); - EXPECT_EQ(CallContext::kSequential, unreachable_node->context()); + const CallGraphNode& unreachable_node = + call_graph->GetNode(unreachable_computation); + EXPECT_EQ(unreachable_computation, unreachable_node.computation()); + EXPECT_EQ(CallContext::kSequential, unreachable_node.context()); } TEST_F(CallGraphTest, ParallelComputation) { @@ -136,30 +134,27 @@ TEST_F(CallGraphTest, ParallelComputation) { HloModule module(TestName()); HloComputation* map_computation = module.AddEmbeddedComputation(MakeScalarComputation()); - HloComputation* entry_computation = module.AddEmbeddedComputation( + HloComputation* entry_computation = module.AddEntryComputation( MakeMappingComputation(map_computation, /*callsites=*/5)); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); EXPECT_EQ(2, call_graph->nodes().size()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, - call_graph->GetNode(entry_computation)); - EXPECT_EQ(entry_computation, entry_node->computation()); - EXPECT_EQ(CallContext::kSequential, entry_node->context()); - EXPECT_EQ(5, entry_node->callsites().size()); - EXPECT_EQ(1, entry_node->callees().size()); - EXPECT_TRUE(entry_node->caller_callsites().empty()); - EXPECT_TRUE(entry_node->callers().empty()); - - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* map_node, - call_graph->GetNode(map_computation)); - EXPECT_EQ(map_computation, map_node->computation()); - EXPECT_EQ(CallContext::kParallel, map_node->context()); - EXPECT_TRUE(map_node->callsites().empty()); - EXPECT_TRUE(map_node->callees().empty()); - EXPECT_EQ(5, map_node->caller_callsites().size()); - EXPECT_EQ(1, map_node->callers().size()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(CallContext::kSequential, entry_node.context()); + EXPECT_EQ(5, entry_node.callsites().size()); + EXPECT_EQ(1, entry_node.callees().size()); + EXPECT_TRUE(entry_node.caller_callsites().empty()); + EXPECT_TRUE(entry_node.callers().empty()); + + const CallGraphNode& map_node = call_graph->GetNode(map_computation); + EXPECT_EQ(map_computation, map_node.computation()); + EXPECT_EQ(CallContext::kParallel, map_node.context()); + EXPECT_TRUE(map_node.callsites().empty()); + EXPECT_TRUE(map_node.callees().empty()); + EXPECT_EQ(5, map_node.caller_callsites().size()); + EXPECT_EQ(1, map_node.callers().size()); } TEST_F(CallGraphTest, SequentialComputations) { @@ -168,30 +163,27 @@ TEST_F(CallGraphTest, SequentialComputations) { HloModule module(TestName()); HloComputation* called_computation = module.AddEmbeddedComputation(MakeScalarComputation()); - HloComputation* entry_computation = module.AddEmbeddedComputation( + HloComputation* entry_computation = module.AddEntryComputation( MakeCallingComputation(called_computation, /*callsites=*/3)); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); EXPECT_EQ(2, call_graph->nodes().size()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, - call_graph->GetNode(entry_computation)); - EXPECT_EQ(entry_computation, entry_node->computation()); - EXPECT_EQ(CallContext::kSequential, entry_node->context()); - EXPECT_EQ(3, entry_node->callsites().size()); - EXPECT_EQ(1, entry_node->callees().size()); - EXPECT_TRUE(entry_node->caller_callsites().empty()); - EXPECT_TRUE(entry_node->callers().empty()); - - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* called_node, - call_graph->GetNode(called_computation)); - EXPECT_EQ(called_computation, called_node->computation()); - EXPECT_EQ(CallContext::kSequential, called_node->context()); - EXPECT_TRUE(called_node->callsites().empty()); - EXPECT_TRUE(called_node->callees().empty()); - EXPECT_EQ(3, called_node->caller_callsites().size()); - EXPECT_EQ(1, called_node->callers().size()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(CallContext::kSequential, entry_node.context()); + EXPECT_EQ(3, entry_node.callsites().size()); + EXPECT_EQ(1, entry_node.callees().size()); + EXPECT_TRUE(entry_node.caller_callsites().empty()); + EXPECT_TRUE(entry_node.callers().empty()); + + const CallGraphNode& called_node = call_graph->GetNode(called_computation); + EXPECT_EQ(called_computation, called_node.computation()); + EXPECT_EQ(CallContext::kSequential, called_node.context()); + EXPECT_TRUE(called_node.callsites().empty()); + EXPECT_TRUE(called_node.callees().empty()); + EXPECT_EQ(3, called_node.caller_callsites().size()); + EXPECT_EQ(1, called_node.callers().size()); } TEST_F(CallGraphTest, ContextBothComputations) { @@ -209,34 +201,31 @@ TEST_F(CallGraphTest, ContextBothComputations) { HloInstruction* map = builder.AddInstruction( HloInstruction::CreateMap(kScalarShape, {call}, subcomputation)); HloComputation* entry_computation = - module.AddEmbeddedComputation(builder.Build()); + module.AddEntryComputation(builder.Build()); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); EXPECT_EQ(2, call_graph->nodes().size()); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, - call_graph->GetNode(entry_computation)); - EXPECT_EQ(entry_computation, entry_node->computation()); - EXPECT_EQ(2, entry_node->callsites().size()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(2, entry_node.callsites().size()); - const CallSite& call_callsite = entry_node->callsites()[0]; + const CallSite& call_callsite = entry_node.callsites()[0]; EXPECT_EQ(call, call_callsite.instruction()); - EXPECT_MATCH(call_callsite.called_computations(), - UnorderedMatcher(subcomputation)); + EXPECT_THAT(call_callsite.called_computations(), + UnorderedElementsAre(subcomputation)); EXPECT_EQ(CallContext::kSequential, call_callsite.context()); - EXPECT_EQ(entry_node->GetCallSite(call), &call_callsite); + EXPECT_EQ(entry_node.GetCallSite(call), &call_callsite); - const CallSite& map_callsite = entry_node->callsites()[1]; + const CallSite& map_callsite = entry_node.callsites()[1]; EXPECT_EQ(map, map_callsite.instruction()); - EXPECT_MATCH(map_callsite.called_computations(), - UnorderedMatcher(subcomputation)); + EXPECT_THAT(map_callsite.called_computations(), + UnorderedElementsAre(subcomputation)); EXPECT_EQ(CallContext::kParallel, map_callsite.context()); - EXPECT_EQ(entry_node->GetCallSite(map), &map_callsite); + EXPECT_EQ(entry_node.GetCallSite(map), &map_callsite); - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* sub_node, - call_graph->GetNode(subcomputation)); - EXPECT_EQ(CallContext::kBoth, sub_node->context()); + const CallGraphNode& sub_node = call_graph->GetNode(subcomputation); + EXPECT_EQ(CallContext::kBoth, sub_node.context()); } TEST_F(CallGraphTest, ComplexGraph) { @@ -282,27 +271,24 @@ TEST_F(CallGraphTest, ComplexGraph) { entry_computation = module.AddEntryComputation(builder.Build()); } - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); EXPECT_EQ(5, call_graph->nodes().size()); // Entry computation has one while instruction calling two computations // (cond_computation and a_computation). - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* entry_node, - call_graph->GetNode(entry_computation)); - ASSERT_EQ(1, entry_node->callsites().size()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + ASSERT_EQ(1, entry_node.callsites().size()); const std::vector& called_computations = - entry_node->callsites()[0].called_computations(); - EXPECT_MATCH(called_computations, - UnorderedMatcher(cond_computation, a_computation)); - EXPECT_EQ(CallContext::kSequential, entry_node->context()); - - TF_ASSIGN_OR_ASSERT_OK(const CallGraphNode* c_node, - call_graph->GetNode(c_computation)); - EXPECT_TRUE(c_node->callsites().empty()); - EXPECT_MATCH(c_node->callers(), - UnorderedMatcher(a_computation, b_computation)); - EXPECT_EQ(CallContext::kBoth, c_node->context()); + entry_node.callsites()[0].called_computations(); + EXPECT_THAT(called_computations, + UnorderedElementsAre(cond_computation, a_computation)); + EXPECT_EQ(CallContext::kSequential, entry_node.context()); + + const CallGraphNode& c_node = call_graph->GetNode(c_computation); + EXPECT_TRUE(c_node.callsites().empty()); + EXPECT_THAT(c_node.callers(), + UnorderedElementsAre(a_computation, b_computation)); + EXPECT_EQ(CallContext::kBoth, c_node.context()); // Visit the graph and verify nodes were visited in callee-before-caller // order. @@ -335,15 +321,14 @@ TEST_F(CallGraphTest, VisitSingletonComputation) { HloModule module(TestName()); HloComputation* computation = module.AddEntryComputation(MakeScalarComputation()); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); std::vector visited; TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) { visited.push_back(node.computation()); return Status::OK(); })); - EXPECT_MATCH(visited, UnorderedMatcher(computation)); + EXPECT_THAT(visited, UnorderedElementsAre(computation)); } TEST_F(CallGraphTest, VisitUnreachableComputation) { @@ -353,8 +338,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) { module.AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module.AddEmbeddedComputation(MakeScalarComputation()); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); // Test visitation of only reachable nodes. { @@ -379,8 +363,8 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) { }, /*visit_unreachable_nodes=*/true)); EXPECT_EQ(visited.size(), 2); - EXPECT_MATCH(visited, - UnorderedMatcher(entry_computation, unreachable_computation)); + EXPECT_THAT(visited, UnorderedElementsAre(entry_computation, + unreachable_computation)); } } @@ -388,15 +372,15 @@ TEST_F(CallGraphTest, VisitWithError) { // Test that the call graph visitor properly propagates errors. HloModule module(TestName()); module.AddEntryComputation(MakeScalarComputation()); - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr call_graph, - CallGraph::Build(&module)); + std::unique_ptr call_graph = CallGraph::Build(&module); Status status = call_graph->VisitNodes( [](const CallGraphNode&) { return InternalError("Visitation failed"); }); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.code(), tensorflow::error::INTERNAL); - ASSERT_MATCH(status.error_message(), testing::HasSubstr("Visitation failed")); + ASSERT_THAT(status.error_message(), + ::testing::HasSubstr("Visitation failed")); } } // namespace diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc new file mode 100644 index 0000000000000000000000000000000000000000..86f7d6478244dec390b355f2c97a85d85d82a79c --- /dev/null +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -0,0 +1,128 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/compile_only_service.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +/* static */ StatusOr> +CompileOnlyService::NewService(perftools::gputools::Platform* platform) { + ServiceOptions default_options; + default_options.set_platform(platform); + return NewService(default_options); +} + +/* static */ StatusOr> +CompileOnlyService::NewService(const ServiceOptions& options) { + perftools::gputools::Platform* platform = options.platform(); + if (platform == nullptr) { + TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); + } + + TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, + CreateComputeConstantBackend()); + std::unique_ptr service( + new CompileOnlyService(compiler, std::move(compute_constant_backend))); + return std::move(service); +} + +CompileOnlyService::CompileOnlyService( + Compiler* compiler, std::unique_ptr compute_constant_backend) + : Service(/*backend=*/nullptr, std::move(compute_constant_backend)), + compiler_(compiler) { + runs_in_client_process_ = true; +} + +StatusOr>> +CompileOnlyService::CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options) { + std::vector> hlo_modules; + for (const AotComputationInstance& instance : computations) { + TF_ASSIGN_OR_RETURN(UserComputation * user_computation, + computation_tracker_.Resolve(instance.computation)); + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + + // Dump computation proto state if flag is set. + legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); + const string& directory_path = flags->xla_dump_computations_to; + if (!directory_path.empty()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr session_module, + computation_tracker_.SnapshotComputation(versioned_handle.handle)); + string filename = tensorflow::strings::StrCat( + "computation_", versioned_handle.handle.handle(), "__", + session_module->entry().name(), "__version_", + versioned_handle.version); + TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, + *session_module)); + } + + TF_ASSIGN_OR_RETURN( + std::shared_ptr program_shape, + user_computation->ComputeProgramShape(versioned_handle.version)); + + HloModuleConfig hlo_module_config(*program_shape); + auto* computation_layout = + hlo_module_config.mutable_entry_computation_layout(); + if (flags->xla_hlo_profile) { + hlo_module_config.enable_hlo_profiling(true); + } + for (int i = 0; i < instance.argument_layouts.size(); ++i) { + const Shape& argument_layout = *instance.argument_layouts[i]; + if (ShapeUtil::IsTuple(argument_layout)) { + return Unimplemented("tuple arguments not supported yet"); + } + TF_RETURN_IF_ERROR( + computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + argument_layout)); + } + TF_RETURN_IF_ERROR( + computation_layout->mutable_result_layout()->CopyLayoutFromShape( + *instance.result_layout)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, + computation_tracker_.BuildHloModule( + versioned_handle, &hlo_module_config, + /*include_unreachable_instructions=*/true)); + hlo_modules.push_back(std::move(hlo_module)); + } + + return compiler_->CompileAheadOfTime(std::move(hlo_modules), + MakeHloDumper(), options); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h new file mode 100644 index 0000000000000000000000000000000000000000..6dae49e3e1acf144847d44af4507880d8bf2efc4 --- /dev/null +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -0,0 +1,125 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_ + +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/service.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// An XLA Service specialization for ahead-of-time compilation. This only +// instantiates a Compiler object for the relevant platform; it does not +// instantiate or require an execution backend. +class CompileOnlyService : public Service { + public: + // Factory for creating a CompileOnlyService. The parameter platform is the + // platform that the service should target. If platform is null then the + // default platform is used. + static StatusOr> NewService( + perftools::gputools::Platform* platform); + static StatusOr> NewService( + const ServiceOptions& options); + + // A description of a computation to compile using CompileAheadOfTime. + struct AotComputationInstance { + ComputationHandle computation; + std::vector argument_layouts; + const Shape* result_layout = nullptr; + }; + + // Compiles a list of computations for ahead-of-time execution. This is + // intended for use in static compilation. See + // |CompileOnlyClient::CompileAheadOfTime| for additional details. + StatusOr>> + CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& Options); + + // Override Service methods that require or imply the existence of an + // execute backend. Note that this does not include TransferToClient and + // TransferToClientInProcess, as computing contants produces global data + // that we may wish to transfer. + tensorflow::Status Execute(const ExecuteRequest* arg, + ExecuteResponse* result) override { + return Unimplemented("CompileOnlyService does not support execution."); + } + tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) override { + return Unimplemented("CompileOnlyService does not support execution."); + } + tensorflow::Status GetDeviceHandles( + const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override { + return Unimplemented("CompileOnlyService does not support devices."); + } + tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) override { + return Unimplemented("CompileOnlyService does not support execution."); + } + tensorflow::Status WaitForExecution( + const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override { + return Unimplemented("CompileOnlyService does not support execution."); + } + tensorflow::Status TransferToServer( + const TransferToServerRequest* arg, + TransferToServerResponse* result) override { + return Unimplemented( + "CompileOnlyService does not support device data transfers."); + } + tensorflow::Status TransferToInfeed( + const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override { + return Unimplemented( + "CompileOnlyService does not support device data transfers."); + } + tensorflow::Status TransferFromOutfeed( + const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override { + return Unimplemented( + "CompileOnlyService does not support device data transfers."); + } + tensorflow::Status TransferToServerInProcess( + const TransferToServerInProcessRequest* arg, + TransferToServerInProcessResponse* result) override { + return Unimplemented( + "CompileOnlyService does not support device data transfers."); + } + tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override { + return Unimplemented("CompileOnlyService does not support devices."); + } + + private: + explicit CompileOnlyService( + Compiler* compiler, std::unique_ptr compute_constant_backend); + CompileOnlyService(const CompileOnlyService&) = delete; + void operator=(const CompileOnlyService&) = delete; + + // The compiler for the target platform. This is included in place of + // the Service::execute_backend_'s compiler, since execute_backend_ is a + // nullptr in CompileOnlyService. + Compiler* compiler_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILE_ONLY_SERVICE_H_ diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 6f43c9b8040e9b21e7c0fcf86e2dc5b8ff8c6475..1876417c03a03ec80a05dac3d0936ef6db60055c 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -112,26 +112,22 @@ class Compiler { // // Use the overload below to compile computations that run in parallel. virtual StatusOr> Compile( - std::unique_ptr module, - std::unique_ptr module_config, HloDumper dump_hlo, + std::unique_ptr module, HloDumper dump_hlo, perftools::gputools::StreamExecutor* executor) = 0; // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules, and returns a corresponding // sequence of executable objects. virtual StatusOr>> Compile( - std::vector> hlo_module, - std::vector> module_config, - HloDumper dump_hlo, + std::vector> modules, HloDumper dump_hlo, std::vector stream_exec) = 0; // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. virtual StatusOr>> - CompileAheadOfTime( - std::vector> module, - std::vector> module_config, - HloDumper dump_hlo, const AotCompilationOptions& options) = 0; + CompileAheadOfTime(std::vector> modules, + HloDumper dump_hlo, + const AotCompilationOptions& options) = 0; ///// // The Compiler class also serves as a point to register compiler objects diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc index f78806bce82f7f524ba2bf80fbf602ad49e103c7..7e59f03773132b05590fd71d2e2e918d52fe5d98 100644 --- a/tensorflow/compiler/xla/service/computation_tracker.cc +++ b/tensorflow/compiler/xla/service/computation_tracker.cc @@ -169,6 +169,7 @@ void ComputationTracker::ComputeComputationPostOrder( StatusOr> ComputationTracker::BuildHloModule( const VersionedComputationHandle& entry_handle, + const HloModuleConfig* config, bool include_unreachable_instructions) const { tensorflow::mutex_lock lock(computation_mutex_); @@ -208,7 +209,12 @@ StatusOr> ComputationTracker::BuildHloModule( string module_name = tensorflow::strings::StrCat(entry_computation->name(), "_module"); - auto module = MakeUnique(module_name, entry_handle); + std::unique_ptr module; + if (config == nullptr) { + module = MakeUnique(module_name, entry_handle); + } else { + module = MakeUnique(module_name, entry_handle, *config); + } for (auto versioned_handle : post_order) { UserComputation* computation = ResolveInternal(versioned_handle.handle).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/computation_tracker.h b/tensorflow/compiler/xla/service/computation_tracker.h index 1922908747c6ef3b74c5b87d3c3924e5ffb38fc5..c7ca357398a9351ed8647fdef256b2af255eab0f 100644 --- a/tensorflow/compiler/xla/service/computation_tracker.h +++ b/tensorflow/compiler/xla/service/computation_tracker.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/user_computation.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" @@ -72,12 +73,15 @@ class ComputationTracker { // Builds an HLO module using the specified computation as the entry. The // module will include the entry computation as well as all computations which // are called directly or indirectly from the entry computation via operations - // like "map". If include_unreachable_instructions is true, then instructions + // like "map". config is the HLO module configuration to use for the + // constructed module; pass nullptr for "no configuration". + // If include_unreachable_instructions is true, then instructions // which are not reachable from the root are lowered into HloInstructions // including unreachable parameters. This ensures the entry HloComputation has // the same program shape (ProgramShape) as the entry UserComputation. StatusOr> BuildHloModule( const VersionedComputationHandle& entry_handle, + const HloModuleConfig* config, bool include_unreachable_instructions = true) const; string ToString() const; diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 7dae49acad388e6d18a8cb1e4ea70244616978bb..3a1a9fe8709e33c7cfe56f4d8648ee2151e3bdd0 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -16,19 +16,20 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" #include -#include -#include #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" @@ -36,6 +37,9 @@ namespace xla { namespace { +using tensorflow::gtl::FlatMap; +using tensorflow::gtl::FlatSet; + // InstructionCopier encapsulates indices at which to copy 'instruction'. // All 'instruction' users in 'copy_users' are updated to use the copy. // @@ -52,7 +56,7 @@ namespace { // // Example two-element tuple with one element that needs a copy: // -// Tuple // instruction +// original-instruction // / \ // GTE(0) GTE(1) // | | @@ -60,23 +64,54 @@ namespace { // \ / // Tuple // copied-instruction // +// As an optimization, if the original instruction is itself a Tuple +// instruction, we elide the unnecessary extra GTE and Tuple instructions, +// and just insert the copy into a new Tuple instruction, with control +// dependencies to ensure the copy occurs after any possible interference. class InstructionCopier { public: - InstructionCopier(const bool init_value, HloInstruction* instruction, - const std::vector& copy_users); + InstructionCopier(HloInstruction* instruction, + const std::vector& copy_users) + : instruction_(instruction), + copy_users_(copy_users), + indices_to_copy_(instruction->shape()), + control_predecessors_(instruction->shape()) {} + + // Sets indices that are read-only, and thus do not need to be copied. + void SetReadOnlyIndices(const ShapeTree& read_only_indices) { + read_only_indices_ = read_only_indices; + } + + // Sets copy overrides, which are copy instructions to use at each index. This + // is used to share a single copy of read-only entry parameters and constants + // between multiple While loops. + void SetCopyOverrides(const ShapeTree& copy_overrides) { + copy_overrides_ = copy_overrides; + } // Returns true if all recorded indices are false (returns true otherwise). bool HasAllIndicesFalse() const; // Records instruction buffer indices which point-to a Parameter or Constant. - tensorflow::Status RecordIndicesWhichPointToParamOrConstant( + Status RecordIndicesWhichPointToParamOrConstant( const TuplePointsToAnalysis& points_to_analysis); // Records instruction buffer indices to copy which are necessary to ensure: // *) PointsToSet of 'instruction_' is unambiguous and distinct. // *) No liveness interference between 'instruction_' and 'other_instruction'. - tensorflow::Status RecordIndicesToCopyForColocatingBuffers( - BufferLiveness* liveness, HloInstruction* other_instruction); + // + // If 'read_only_indices_out' is non-null, read-only indices are set to true. + Status RecordIndicesToCopyForColocatingBuffers( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree* read_only_indices_out); + + // Records control predecessors to add for inserted copy instructions. + // 'parameter' must have the same shape as the instruction that will be + // copied, and must define all buffers in the shape. Control predecessors are + // only recorded for indices that have already been marked for copying. + Status RecordControlPredecessors( + const TuplePointsToAnalysis& points_to_analysis, + HloInstruction* parameter); // Inserts copies of 'instruction' buffers at indices in 'indices_to_copy', // and replaces all uses for instructions in 'copy_users_' with copy. @@ -88,15 +123,29 @@ class InstructionCopier { const std::vector& copy_users() const { return copy_users_; } private: + // Does the given index represent a read-only buffer? + bool IsReadOnlyIndex(const ShapeIndex& index) const { + return !ShapeUtil::IsNil(read_only_indices_.shape()) && + read_only_indices_.element(index); + } + + // Returns the copy override at the given index, or nullptr. + HloInstruction* GetCopyOverride(const ShapeIndex& index) const { + return ShapeUtil::IsNil(copy_overrides_.shape()) + ? nullptr + : copy_overrides_.element(index); + } + // Records instruction buffer indices which have ambiguous or non-distinct // points-to sets. - tensorflow::Status RecordAmbiguousOrNonDistinctIndices( + Status RecordAmbiguousOrNonDistinctIndices( const TuplePointsToAnalysis& points_to_analysis); - // Records instruction buffer indices which have interferring live ranges + // Records instruction buffer indices which have interfering live ranges // with 'other_instruction' buffers at same index. - tensorflow::Status RecordIndicesWhichInterfereWithOtherInstruction( - BufferLiveness* liveness, HloInstruction* other_instruction); + Status RecordIndicesWhichInterfereWithOtherInstruction( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree* read_only_indices_out); // Recursively inserts copies of 'instruction' tuple elements at indices // specified in 'indices_to_copy', and returns the copy of 'instruction'. @@ -107,28 +156,25 @@ class InstructionCopier { } HloInstruction* instruction_; - std::vector copy_users_; + const std::vector copy_users_; ShapeTree indices_to_copy_; + ShapeTree> control_predecessors_; + ShapeTree read_only_indices_; + ShapeTree copy_overrides_; }; -InstructionCopier::InstructionCopier( - const bool init_value, HloInstruction* instruction, - const std::vector& copy_users) - : instruction_(instruction), - copy_users_(copy_users), - indices_to_copy_(instruction->shape(), init_value) {} - bool InstructionCopier::HasAllIndicesFalse() const { bool all_indices_false = true; - TF_CHECK_OK(indices_to_copy_.ForEachElement([&all_indices_false]( - const ShapeIndex& /*index*/, bool /*is_leaf*/, const bool& data) { - if (data) all_indices_false = false; - return tensorflow::Status::OK(); - })); + TF_CHECK_OK(indices_to_copy_.ForEachElement( + [&all_indices_false](const ShapeIndex& /*index*/, bool /*is_leaf*/, + bool data) { + if (data) all_indices_false = false; + return tensorflow::Status::OK(); + })); return all_indices_false; } -tensorflow::Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant( +Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant( const TuplePointsToAnalysis& points_to_analysis) { const PointsToSet& points_to = points_to_analysis.GetPointsToSet(instruction_); @@ -141,41 +187,47 @@ tensorflow::Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant( // Multiple buffers within a parameter/constant may be live out, so collect // a set of indices at which to copy first. - TF_RETURN_IF_ERROR(points_to.ForEachElement([this]( - const ShapeIndex& index, bool /*is_leaf*/, - const std::vector& buffers) { - for (auto buffer : buffers) { - // pointee is the HloInstruction producing the buffer which may be - // liveout. - HloInstruction* pointee = buffer->instruction(); - if (pointee->opcode() == HloOpcode::kParameter || - pointee->opcode() == HloOpcode::kConstant) { - VLOG(2) << "Parameter or constant buffer " << buffer->ToString() - << " index: " << tensorflow::str_util::Join(index, ",") - << " may be live out of computation: " << pointee->ToString(); - RecordIndex(index); - } - } - return tensorflow::Status::OK(); - })); - return tensorflow::Status::OK(); + TF_RETURN_IF_ERROR(points_to.ForEachElement( + [this](const ShapeIndex& index, bool /*is_leaf*/, + const std::vector& buffers) { + if (IsReadOnlyIndex(index)) { + return Status::OK(); + } + for (const LogicalBuffer* buffer : buffers) { + // pointee is the HloInstruction producing the buffer which may be + // liveout. + HloInstruction* pointee = buffer->instruction(); + if (pointee->opcode() == HloOpcode::kParameter || + pointee->opcode() == HloOpcode::kConstant) { + VLOG(2) << "Parameter or constant buffer " << buffer->ToString() + << " index: " << tensorflow::str_util::Join(index, ",") + << " may be live out of computation: " + << pointee->ToString(); + RecordIndex(index); + break; + } + } + return Status::OK(); + })); + return Status::OK(); } -tensorflow::Status InstructionCopier::RecordIndicesToCopyForColocatingBuffers( - BufferLiveness* liveness, HloInstruction* other_instruction) { +Status InstructionCopier::RecordIndicesToCopyForColocatingBuffers( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree* read_only_indices_out) { TF_RETURN_IF_ERROR( - RecordAmbiguousOrNonDistinctIndices(liveness->points_to_analysis())); + RecordAmbiguousOrNonDistinctIndices(liveness.points_to_analysis())); TF_RETURN_IF_ERROR(RecordIndicesWhichInterfereWithOtherInstruction( - liveness, other_instruction)); - return tensorflow::Status::OK(); + liveness, other_instruction, read_only_indices_out)); + return Status::OK(); } -tensorflow::Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( +Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( const TuplePointsToAnalysis& points_to_analysis) { const PointsToSet& points_to = points_to_analysis.GetPointsToSet(instruction_); // Mapping from LogicalBuffer to index (used to detect non-distinct indices). - std::unordered_map> + FlatMap> buffer_to_source_indices; TF_RETURN_IF_ERROR(points_to.ForEachElement([this, &buffer_to_source_indices]( const ShapeIndex& index, bool /*is_leaf*/, @@ -191,22 +243,18 @@ tensorflow::Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( } } // For each 'buffer': record a mapping from 'buffer' to 'index'. - for (auto& buffer : buffers) { - auto it = buffer_to_source_indices.find(buffer); - if (it == buffer_to_source_indices.end()) { - buffer_to_source_indices.insert({buffer, std::vector()}); - } + for (const LogicalBuffer* buffer : buffers) { buffer_to_source_indices[buffer].push_back(index); } - return tensorflow::Status::OK(); + return Status::OK(); })); // Record all non-distinct indices detected in 'buffer_to_source_indices'. - for (auto& buff_to_src : buffer_to_source_indices) { + for (const auto& buff_to_src : buffer_to_source_indices) { if (buff_to_src.second.size() == 1) { continue; } - for (auto& src_index : buff_to_src.second) { + for (const ShapeIndex& src_index : buff_to_src.second) { // Record non-distinct points-to set at 'src_index'. if (!indices_to_copy_.element(src_index)) { VLOG(2) << "Adding copy of buffer for instruction: " @@ -217,23 +265,26 @@ tensorflow::Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( } } } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status -InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction( - BufferLiveness* liveness, HloInstruction* other_instruction) { +Status InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree* read_only_indices_out) { // Record all buffer indices for 'instruction_', which interfere with // 'other_instruction' at the same index. TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshape( instruction_->shape(), - [this, &liveness, &other_instruction](const Shape& /*subshape*/, - const ShapeIndex& index) { + [this, &liveness, other_instruction, read_only_indices_out]( + const Shape& /*subshape*/, const ShapeIndex& index) { + if (IsReadOnlyIndex(index)) { + return Status::OK(); + } if (indices_to_copy_.element(index)) { // Return if previous pass already set index. - return tensorflow::Status::OK(); + return Status::OK(); } - auto& points_to_analysis = liveness->points_to_analysis(); + const auto& points_to_analysis = liveness.points_to_analysis(); // Lookup buffers for 'instruction_' and 'other_instruction'. const std::vector instruction_buffers = points_to_analysis.GetPointsToSet(instruction_).element(index); @@ -252,20 +303,24 @@ InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction( // then that buffer is not updated on the path between the two // instructions. Therefore, any other (possibly interference-causing) // users of that buffer from 'other_instruction' will see the same data, - // irrespecive of whether we insert a copy of this buffer at + // irrespective of whether we insert a copy of this buffer at // 'instruction_' or not. if (other_instruction_buffers.size() == 1 && other_instruction_buffers[0]->id() == instruction_buffer->id()) { - return tensorflow::Status::OK(); + if (read_only_indices_out != nullptr) { + *read_only_indices_out->mutable_element(index) = true; + } + return Status::OK(); } - // We cant say anything about the ambiguity of 'other_instruction' at + // We can't say anything about the ambiguity of 'other_instruction' at // this point, so we need to check interference between the single // buffer in the points-to set of 'instruction_' and all buffers in // 'other_instruction_buffers'. - for (auto& other_buffer : other_instruction_buffers) { - if (liveness->MayInterfere(*instruction_buffer, *other_buffer)) { + for (const LogicalBuffer* other_buffer : other_instruction_buffers) { + if (liveness.MayInterfere(*instruction_buffer, *other_buffer)) { VLOG(2) << "Adding copy of buffer for instruction: " << instruction_->name() + << " instruction_buffer: " << instruction_buffer->ToString() << " at index: " << tensorflow::str_util::Join(index, ",") << " because of interference with buffer: " << other_buffer->ToString(); @@ -273,40 +328,89 @@ InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction( break; } } - return tensorflow::Status::OK(); + return Status::OK(); })); - return tensorflow::Status::OK(); + return Status::OK(); +} + +// This is called when 'instruction_' is a while body root, and 'parameter' is +// the while body parameter. We record all users of all aliases of 'parameter' +// as control predecessors, so that when we add a copy of 'instruction_', we can +// mark the control dependencies. This is necessary because points-to and +// liveness analysis doesn't know about the aliasing between the while body root +// and param. Without these control dependencies, the copy might get scheduled +// to run at a point that interferes with users of the buffer. +Status InstructionCopier::RecordControlPredecessors( + const TuplePointsToAnalysis& points_to_analysis, + HloInstruction* parameter) { + return indices_to_copy_.ForEachElement( + [this, &points_to_analysis, parameter](const ShapeIndex& index, + bool /*is_leaf*/, bool will_copy) { + if (will_copy) { + TF_ASSIGN_OR_RETURN( + const LogicalBuffer* buffer, + points_to_analysis.GetBufferDefinedAt(parameter, index)); + for (const BufferAlias& alias : + points_to_analysis.GetBufferAliases(*buffer)) { + for (HloInstruction* user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + user, points_to_analysis)) { + continue; + } + + if (user != instruction_) { + control_predecessors_.mutable_element(index)->push_back(user); + } + } + } + } + return Status::OK(); + }); } // Recursively inserts copies of 'instruction' tuple element buffers at // indices in 'indices_to_copy_', expanding tuples as needed. -// TODO(b/31159897) Remove superfluous Tuple->GTE->Tuple expressions. HloInstruction* InstructionCopier::CopyTuple(HloInstruction* instruction, ShapeIndex* index) { - std::vector element_copies; const int64 num_tuple_elements = ShapeUtil::TupleElementCount(instruction->shape()); + std::vector elem_copies(num_tuple_elements); for (int64 i = 0; i < num_tuple_elements; ++i) { - HloInstruction* gte = instruction->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, i)); - HloInstruction* element_copy; + HloInstruction* elem; + if (instruction->opcode() == HloOpcode::kTuple) { + // If the instruction is already a Tuple instruction, we know that the + // element buffers are aliased, so we can just grab the operand directly. + elem = instruction->mutable_operand(i); + } else { + // Otherwise we need to add a GTE to unpack the element out of the tuple. + elem = instruction->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, + i)); + } index->push_back(i); - if (ShapeUtil::IsTuple(gte->shape())) { - element_copy = CopyTuple(gte, index); + if (ShapeUtil::IsTuple(elem->shape())) { + elem_copies[i] = CopyTuple(elem, index); + } else if (!indices_to_copy_.element(*index)) { + elem_copies[i] = elem; + } else if (HloInstruction* copy_override = GetCopyOverride(*index)) { + elem_copies[i] = copy_override; } else { - if (indices_to_copy_.element(*index)) { - element_copy = gte->parent()->AddInstruction( - HloInstruction::CreateUnary(gte->shape(), HloOpcode::kCopy, gte)); - } else { - element_copy = gte; + HloInstruction* elem_copy = elem->parent()->AddInstruction( + HloInstruction::CreateUnary(elem->shape(), HloOpcode::kCopy, elem)); + for (HloInstruction* control_predecessor : + control_predecessors_.element(*index)) { + VLOG(2) << "Adding control dependency from " + << control_predecessor->ToString() << " to " + << elem_copy->ToString(); + TF_CHECK_OK(control_predecessor->AddControlDependencyTo(elem_copy)); } + elem_copies[i] = elem_copy; } index->pop_back(); - element_copies.push_back(element_copy); } return instruction->parent()->AddInstruction( - HloInstruction::CreateTuple(element_copies)); + HloInstruction::CreateTuple(elem_copies)); } // Inserts copies of 'instruction_' buffers at indices in 'indices_to_copy_'. @@ -327,8 +431,85 @@ HloInstruction* InstructionCopier::Copy() { return copy; } +// The 'read_only_indices' are initialized based on points-to analysis on the +// while body corresponding to 'while_hlo'. If the init buffer corresponding to +// a read-only index aliases with an entry parameter (or constant), it cannot be +// considered read-only, and must be copied. This is necessary because some +// backends don't support entry-parameter (or constant) aliasing with regular +// instructions. This function performs this fix-up of 'read_only_indices'. +// +// Returns a ShapeTree of copy_overrides, which implements an optimization to +// allow multiple while loops that share the same read-only entry parameters to +// share a single copy. +StatusOr> +RevertReadOnlyIndicesForEntryParamsAndConstants( + const HloInstruction* while_hlo, + const TuplePointsToAnalysis& points_to_analysis, + ShapeTree* read_only_indices, + FlatMap* shared_copies) { + const HloInstruction* init_hlo = while_hlo->operand(0); + const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo); + ShapeTree copy_overrides(init_hlo->shape()); + TF_RETURN_IF_ERROR(points_to.ForEachElement( + [init_hlo, read_only_indices, shared_copies, ©_overrides]( + const ShapeIndex& index, bool /*is_leaf*/, + const std::vector& buffers) { + // Look for read-only entry parameters. + if (!read_only_indices->element(index)) { + return Status::OK(); + } + for (const LogicalBuffer* buffer : buffers) { + HloInstruction* pointee = buffer->instruction(); + const HloComputation* computation = pointee->parent(); + const bool is_entry_parameter = + pointee->opcode() == HloOpcode::kParameter && + computation == computation->parent()->entry_computation(); + const bool is_constant = pointee->opcode() == HloOpcode::kConstant; + if (!is_entry_parameter && !is_constant) { + continue; + } + // We have found an entry parameter or constant that is read-only in + // the while body. These buffers are managed by the caller, and cannot + // be aliased with non-parameter buffers. Revert this read-only index, + // to allow it to be copied. + *read_only_indices->mutable_element(index) = false; + + // Optimization to allow multiple while loops that share the same + // read-only entry parameters (or constants) to share a single copy. + // Only unambiguous array-shaped buffers are allowed, to reduce code + // complexity. The shape of the entry parameter must be identical to + // the shape of the init_hlo at this index, to ensure there were no + // intervening bitcast or GTE instructions, which are also hard to + // handle. + const Shape& pointee_shape = pointee->shape(); + const Shape& init_shape = + ShapeUtil::GetSubshape(init_hlo->shape(), index); + if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) && + ShapeUtil::Equal(pointee_shape, init_shape)) { + HloInstruction** copy = &(*shared_copies)[pointee]; + if (*copy == nullptr) { + *copy = + pointee->parent()->AddInstruction(HloInstruction::CreateUnary( + pointee_shape, HloOpcode::kCopy, pointee)); + } + // Add the copy as an override. + *copy_overrides.mutable_element(index) = *copy; + } + + // We've already reverted the read-only index and handled the + // single-copy optimization above, so there's nothing more to do. + break; + } + return Status::OK(); + })); + return copy_overrides; +} + } // anonymous namespace +// NOTE: This is only called by gpu::CopyInsertion. It's not called here in the +// base class, since the regular CopyInsertion logic above selectively copies +// tuple elements, while this method assumes all buffers need to be deep copied. StatusOr CopyInsertion::FindOrInsertCopy(HloInstruction* hlo) { auto copy_it = inserted_copies_.find(hlo); if (copy_it == inserted_copies_.end()) { @@ -347,85 +528,96 @@ StatusOr CopyInsertion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN( std::unique_ptr liveness, BufferLiveness::Run(module, MakeUnique(module))); - auto& points_to_analysis = liveness->points_to_analysis(); + const auto& points_to_analysis = liveness->points_to_analysis(); XLA_VLOG_LINES(2, points_to_analysis.ToString()); XLA_VLOG_LINES(2, module->ToString()); - // Gather references to all while body computations in 'module'. - std::unordered_set while_body_computations; - // Gather references to all while instructions in 'module' by computation. - std::unordered_map> - while_instructions; + // Gather all while body computations and while instructions. + FlatSet while_body_computations; + std::vector while_instructions; for (auto& computation : module->computations()) { for (auto& instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kWhile) { - continue; - } - while_body_computations.insert(instruction->while_body()); - auto it = while_instructions.find(computation.get()); - if (it == while_instructions.end()) { - while_instructions.insert( - {computation.get(), std::vector()}); + if (instruction->opcode() == HloOpcode::kWhile) { + while_body_computations.insert(instruction->while_body()); + while_instructions.push_back(instruction.get()); } - while_instructions[computation.get()].emplace_back(instruction.get()); } } + // Collect instruction buffer indices to copy in 'instructions_to_copy'. + std::vector instructions_to_copy; + + // Add copies of computation root instructions, if needed. + FlatMap> while_body_read_only_indices; for (auto& computation : module->computations()) { VLOG(2) << "computation " << computation->name(); - - // Collect instruction buffer indices to copy in 'instructions_to_copy'. - std::vector instructions_to_copy; - - // Add copies of while 'init' operand instructions (if needed). - // TODO(b/33301720) Remove redundant while instruction copies. - auto it = while_instructions.find(computation.get()); - if (it != while_instructions.end()) { - for (auto& while_hlo : it->second) { - // Create InstructionCopier for init operand of while instruction. - HloInstruction* init_hlo = while_hlo->mutable_operand(0); - instructions_to_copy.push_back( - InstructionCopier(/*init_value=*/false, init_hlo, {while_hlo})); - InstructionCopier& init_copier = instructions_to_copy.back(); - // Record 'init' buffer indices which point-to a Constant or Parameter. - TF_RETURN_IF_ERROR(init_copier.RecordIndicesWhichPointToParamOrConstant( - liveness->points_to_analysis())); - // Record indices necessary to colocate while and init operand buffers. - TF_RETURN_IF_ERROR(init_copier.RecordIndicesToCopyForColocatingBuffers( - liveness.get(), while_hlo)); - } - } - - // Create InstructionCopier for computation root instruction. - instructions_to_copy.push_back(InstructionCopier( - /*init_value=*/false, computation->root_instruction(), {})); - InstructionCopier& root_copier = instructions_to_copy.back(); - + InstructionCopier root_copier(computation->root_instruction(), + /*copy_users=*/{}); if (while_body_computations.count(computation.get()) > 0) { - // Record root indices to copy for while body sub-computations. - // We do not need to call RecordIndicesWhichPointToParamOrConstant for - // the while root instruction here, because any neccessary copies needed - // to avoid constant or parameters in the output are handled by while.init - // operand copy insertion above (which will share an allocation). + // Record root indices to copy for while body sub-computations. We do not + // need to call RecordIndicesWhichPointToParamOrConstant for the while + // body root instruction here, because any necessary copies needed to + // avoid constants or parameters in the output are handled by while.init + // operand copy insertion below (which will share an allocation). + HloInstruction* while_body_param = computation->parameter_instruction(0); + ShapeTree read_only_indices(while_body_param->shape()); TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers( - liveness.get(), computation->parameter_instruction(0))); - } else if (copy_param_and_const_) { + *liveness, while_body_param, &read_only_indices)); + while_body_read_only_indices[computation.get()] = read_only_indices; + + // Mark control predecessors, based on the body param, for any copies + // we'll be inserting. This ensures the copy doesn't run too early. + TF_RETURN_IF_ERROR(root_copier.RecordControlPredecessors( + points_to_analysis, while_body_param)); + } else { // Record root indices to copy for general computations. TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant( - liveness->points_to_analysis())); + points_to_analysis)); } + instructions_to_copy.push_back(root_copier); + } - for (auto& to_copy : instructions_to_copy) { - if (to_copy.HasAllIndicesFalse()) { - continue; - } - changed = true; + // Add copies of while 'init' operand instructions, if needed. 'shared_copies' + // is used to ensure that multiple while loops can share a single copy of the + // same entry parameter or constant, if all loops use it read-only. + // + // TODO(b/33301720) Remove redundant while instruction copies. + FlatMap shared_copies; + for (HloInstruction* while_hlo : while_instructions) { + // Fix read_only_indices to account for entry parameters and constants. Also + // initialize copy_overrides, which ensures a single copy for each read-only + // entry parameter or constant that is used in multiple while loops. + ShapeTree* read_only_indices = + &while_body_read_only_indices[while_hlo->while_body()]; + TF_ASSIGN_OR_RETURN( + const ShapeTree copy_overrides, + RevertReadOnlyIndicesForEntryParamsAndConstants( + while_hlo, points_to_analysis, read_only_indices, &shared_copies)); + // Create InstructionCopier for init operand of while instruction. + HloInstruction* init_hlo = while_hlo->mutable_operand(0); + InstructionCopier init_copier(init_hlo, {while_hlo}); + init_copier.SetReadOnlyIndices(*read_only_indices); + init_copier.SetCopyOverrides(copy_overrides); + // Record 'init' buffer indices which point-to a Constant or Parameter. + TF_RETURN_IF_ERROR(init_copier.RecordIndicesWhichPointToParamOrConstant( + points_to_analysis)); + // Record indices necessary to colocate while and init operand buffers. + TF_RETURN_IF_ERROR(init_copier.RecordIndicesToCopyForColocatingBuffers( + *liveness, while_hlo, /*read_only_indices_out=*/nullptr)); + instructions_to_copy.push_back(init_copier); + } - // Copy instruction at recorded buffer indices. - HloInstruction* copy = to_copy.Copy(); - if (to_copy.instruction() == computation->root_instruction()) { - computation->set_root_instruction(copy); - } + for (InstructionCopier& to_copy : instructions_to_copy) { + if (to_copy.HasAllIndicesFalse()) { + continue; + } + changed = true; + + // Copy instruction at recorded buffer indices. + HloComputation* computation = to_copy.instruction()->parent(); + HloInstruction* copy = to_copy.Copy(); + if (to_copy.instruction() == computation->root_instruction()) { + computation->set_root_instruction(copy); } } diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index ce91ac0de56f3fc1101c38cee838c0b0593214ad..28bb62e40c7674960dbb1bb63dc8967b06956028 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -32,9 +33,6 @@ namespace xla { // different lifetimes than computation results. class CopyInsertion : public HloPassInterface { public: - explicit CopyInsertion(bool copy_param_and_const = true) - : copy_param_and_const_(copy_param_and_const) {} - ~CopyInsertion() override {} tensorflow::StringPiece name() const override { return "copy-insertion"; } // Run the pass on the given module. Returns whether the module was changed @@ -46,13 +44,9 @@ class CopyInsertion : public HloPassInterface { // duplicate copies. StatusOr FindOrInsertCopy(HloInstruction* hlo); - // Determines whether to insert copies if the root instruction is, or - // points-to, any constant or parameter instruction. - const bool copy_param_and_const_; - // A map containing all copies inserted during the copy insertion pass. The // key is the copied instruction and the value is the copy. - std::unordered_map inserted_copies_; + tensorflow::gtl::FlatMap inserted_copies_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 4c26b2de124b0b42f6de1ebdf82d4584f2904cab..661f682e38a3cefd09f36eb0e42084d35491e196 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -20,18 +20,23 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/compiler/xla/test_helpers.h" +namespace op = xla::testing::opcode_matchers; namespace xla { namespace { +using ::testing::UnorderedElementsAre; + class CopyInsertionTest : public HloTestBase { protected: void InsertCopies(HloModule* module) { @@ -51,43 +56,6 @@ class CopyInsertionTest : public HloTestBase { EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter); } } - - // OperandTree is a test helper class that simplifies the expression of - // an expected tree of operands (starting at some root instruction) in a - // unit test. - // Each HLO instruction is represented as a node in the OperandTree. - struct OperandTree { - // The expected opcode for this OperandTree node. - HloOpcode opcode; - // The set of operands expected for this OperandTree node. - std::vector operands; - // If non-null, a pointer to the expected HloInstruction at this node. - const HloInstruction* instruction = nullptr; - - // Returns a mutable reference to operand 'i' of this node. - OperandTree& op(int i) { - if (i >= operands.size()) { - operands.resize(i + 1); - } - return operands[i]; - } - - // Check that 'instruction' and its operands match expected values recorded - // in OperandTree. - void Check(const HloInstruction* instruction) { - EXPECT_EQ(opcode, instruction->opcode()); - if (instruction != nullptr) { - EXPECT_EQ(instruction, instruction); - } - if (operands.empty()) { - return; - } - EXPECT_EQ(operands.size(), instruction->operand_count()); - for (int i = 0; i < instruction->operand_count(); ++i) { - operands[i].Check(instruction->operand(i)); - } - } - }; }; TEST_F(CopyInsertionTest, SingleParameter) { @@ -97,25 +65,16 @@ TEST_F(CopyInsertionTest, SingleParameter) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({x})); - ExpectEqUnordered(x->users(), {tuple}); + EXPECT_THAT(x->users(), UnorderedElementsAre(tuple)); HloModule module(TestName()); module.AddEntryComputation(builder.Build()); HloInstruction* old_root = module.entry_computation()->root_instruction(); InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); - - // Check path from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(module.entry_computation()->root_instruction(), + op::Tuple(op::Copy(old_root->operand(0)))); } TEST_F(CopyInsertionTest, SingleConstant) { @@ -125,25 +84,16 @@ TEST_F(CopyInsertionTest, SingleConstant) { HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); - ExpectEqUnordered(constant->users(), {tuple}); + EXPECT_THAT(constant->users(), UnorderedElementsAre(tuple)); HloModule module(TestName()); module.AddEntryComputation(builder.Build()); HloInstruction* old_root = module.entry_computation()->root_instruction(); InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); - // Check path from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(module.entry_computation()->root_instruction(), + op::Tuple(op::Copy(old_root->operand(0)))); } TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { @@ -172,30 +122,10 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { HloInstruction* old_root = module.entry_computation()->root_instruction(); InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); - // "constant2" and parameter "x" are pointed to by the tuple and should be - // copied. - - // Check all paths from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_root; - - op_tree.op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).op(0).instruction = old_root; - - op_tree.op(2).opcode = HloOpcode::kGetTupleElement; - op_tree.op(2).op(0).opcode = HloOpcode::kTuple; - op_tree.op(2).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(module.entry_computation()->root_instruction(), + op::Tuple(op::Copy(old_root->operand(0)), + op::Copy(old_root->operand(1)), old_root->operand(2))); } TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { @@ -219,32 +149,19 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); - ExpectEqUnordered(constant1->users(), {tuple1}); - ExpectEqUnordered(constant2->users(), {tuple1, tuple2}); - ExpectEqUnordered(constant3->users(), {tuple2}); + EXPECT_THAT(constant1->users(), UnorderedElementsAre(tuple1)); + EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2)); + EXPECT_THAT(constant3->users(), UnorderedElementsAre(tuple2)); HloModule module(TestName()); module.AddEntryComputation(builder.Build()); HloInstruction* old_root = module.entry_computation()->root_instruction(); InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); - - // Check all paths from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kSelect; - op_tree.op(0).op(0).op(0).instruction = old_root; - - op_tree.op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kSelect; - op_tree.op(1).op(0).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(module.entry_computation()->root_instruction(), + op::Tuple(op::Copy(op::GetTupleElement(old_root)), + op::Copy(op::GetTupleElement(old_root)))); } TEST_F(CopyInsertionTest, BitcastParameter) { @@ -259,19 +176,13 @@ TEST_F(CopyInsertionTest, BitcastParameter) { HloModule module(TestName()); module.AddEntryComputation(builder.Build()); - ExpectEqUnordered(x->users(), {bitcast}); + EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); HloInstruction* old_root = module.entry_computation()->root_instruction(); InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); - // Check path from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kCopy; - op_tree.op(0).opcode = HloOpcode::kBitcast; - op_tree.op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(module.entry_computation()->root_instruction(), + op::Copy(old_root)); } TEST_F(CopyInsertionTest, BitcastConstant) { @@ -287,19 +198,13 @@ TEST_F(CopyInsertionTest, BitcastConstant) { HloModule module(TestName()); module.AddEntryComputation(builder.Build()); - ExpectEqUnordered(constant->users(), {bitcast}); + EXPECT_THAT(constant->users(), UnorderedElementsAre(bitcast)); HloInstruction* old_root = module.entry_computation()->root_instruction(); InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); - - // Check path from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kCopy; - op_tree.op(0).opcode = HloOpcode::kBitcast; - op_tree.op(0).instruction = old_root; - op_tree.Check(new_root); + EXPECT_THAT(module.entry_computation()->root_instruction(), + op::Copy(old_root)); } TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { @@ -314,21 +219,13 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { HloModule module(TestName()); module.AddEntryComputation(builder.Build()); - ExpectEqUnordered(x->users(), {bitcast}); + EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); HloInstruction* old_root = module.entry_computation()->root_instruction(); InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); - - // Check path from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_root; - op_tree.Check(new_root); + EXPECT_THAT(module.entry_computation()->root_instruction(), + op::Tuple(op::Copy(old_root->operand(0)))); } TEST_F(CopyInsertionTest, NestedTupleParameter) { @@ -339,10 +236,11 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) { // Param shape is: ((F32[], S32[1,2,3]), F32[42]) builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), - ShapeUtil::MakeShape(S32, {1, 2, 3})}), - ShapeUtil::MakeShape(F32, {42})}), + 0, + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeShape(S32, {1, 2, 3})}), + ShapeUtil::MakeShape(F32, {42})}), "param0")); HloModule module(TestName()); @@ -356,30 +254,13 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) { HloInstruction* new_root = module.entry_computation()->root_instruction(); EXPECT_NE(old_root, new_root); - // Check all paths from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).op(0).op(0).opcode = HloOpcode::kParameter; - op_tree.op(0).op(0).op(0).op(0).op(0).instruction = old_root; - - op_tree.op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(1).opcode = HloOpcode::kCopy; - op_tree.op(0).op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(1).op(0).op(0).op(0).opcode = HloOpcode::kParameter; - op_tree.op(0).op(1).op(0).op(0).op(0).instruction = old_root; - - op_tree.op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kParameter; - op_tree.op(1).op(0).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT( + new_root, + op::Tuple( + op::Tuple( + op::Copy(op::GetTupleElement(op::GetTupleElement(old_root))), + op::Copy(op::GetTupleElement(op::GetTupleElement(old_root)))), + op::Copy(op::GetTupleElement(old_root)))); } TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { @@ -389,10 +270,11 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { // Param shape is: ((F32[], S32[1,2,3]), F32[42]) auto param = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), - ShapeUtil::MakeShape(S32, {1, 2, 3})}), - ShapeUtil::MakeShape(F32, {42})}), + 0, + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeShape(S32, {1, 2, 3})}), + ShapeUtil::MakeShape(F32, {42})}), "param0")); // The return value of the computation is the zero-th elemnt of the nested @@ -407,23 +289,10 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { HloInstruction* old_root = module.entry_computation()->root_instruction(); InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); - - // Check all paths from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).instruction = old_root; - - op_tree.op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(module.entry_computation()->root_instruction(), + op::Tuple(op::Copy(op::GetTupleElement(old_root)), + op::Copy(op::GetTupleElement(old_root)))); } TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { @@ -456,15 +325,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { HloInstruction* old_root = module.entry_computation()->root_instruction(); InsertCopies(&module); - HloInstruction* new_root = module.entry_computation()->root_instruction(); - // Check path from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kCopy; - op_tree.op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(module.entry_computation()->root_instruction(), + op::Copy(old_root)); } class WhileCopyInsertionTest : public CopyInsertionTest { @@ -528,7 +391,6 @@ class WhileCopyInsertionTest : public CopyInsertionTest { } // Builds a While body computation with read-only tuple element 0. - // both input tuple elements. // EX: // Body({in0, in1}) // out0 = in0 @@ -563,11 +425,14 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // out0 = Add(in0, 1) // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) // Tuple(out0, out1) - std::unique_ptr BuildIndependentBodyComputation() { + std::unique_ptr BuildIndependentBodyComputation( + bool nested = false) { auto builder = HloComputation::Builder(TestName() + ".Body"); // Create param instruction to access loop state. + const Shape& loop_state_shape = + nested ? nested_loop_state_shape_ : loop_state_shape_; auto loop_state = builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); // Update the induction variable GTE(0). auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( @@ -578,16 +443,30 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); // Update data GTE(1). - auto data = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + HloInstruction* data = nullptr; + if (nested) { + data = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + nested_tuple_shape_, loop_state, 1)); + data = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, data, 0)); + } else { + data = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); + } auto update = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR1( {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); - // add0 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) + // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); // Create output Tuple. - builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + if (nested) { + auto nested_tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add1, add1})); + builder.AddInstruction(HloInstruction::CreateTuple({add0, nested_tuple})); + } else { + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + } return builder.Build(); } @@ -640,8 +519,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // Builds a While instruction using 'condition' and 'body' sub-computations. // Init operand is initialized to zeros of appropriate shape. - void BuildWhileInstruction(HloComputation* condition, HloComputation* body, - bool nested = false) { + HloInstruction* BuildWhileInstruction(HloComputation* condition, + HloComputation* body, + bool nested = false) { auto builder = HloComputation::Builder(TestName() + ".While"); auto induction_var_init = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); @@ -655,17 +535,18 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction::CreateTuple({data_init, data_init})); auto loop_state_init = builder.AddInstruction( HloInstruction::CreateTuple({induction_var_init, inner_init})); - builder.AddInstruction(HloInstruction::CreateWhile( + auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape_, condition, body, loop_state_init)); module_.AddEntryComputation(builder.Build()); - return; + return while_hlo; } auto loop_state_init = builder.AddInstruction( HloInstruction::CreateTuple({induction_var_init, data_init})); - builder.AddInstruction(HloInstruction::CreateWhile( + auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape_, condition, body, loop_state_init)); module_.AddEntryComputation(builder.Build()); + return while_hlo; } HloInstruction* BuildWhileInstruction_InitPointsToConstant() { @@ -743,12 +624,14 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction* BuildWhileInstructionWithCustomInit( const Shape& loop_state_shape, HloInstruction* data_init, HloComputation::Builder* builder) { + const bool nested = + ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); auto induction_var_init = builder->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); auto condition = - module_.AddEmbeddedComputation(BuildConditionComputation()); + module_.AddEmbeddedComputation(BuildConditionComputation(nested)); auto body = - module_.AddEmbeddedComputation(BuildIndependentBodyComputation()); + module_.AddEmbeddedComputation(BuildIndependentBodyComputation(nested)); auto loop_state_init = builder->AddInstruction( HloInstruction::CreateTuple({induction_var_init, data_init})); auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile( @@ -781,14 +664,20 @@ class WhileCopyInsertionTest : public CopyInsertionTest { TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { auto condition = module_.AddEmbeddedComputation(BuildConditionComputation()); auto body = module_.AddEmbeddedComputation(BuildIndependentBodyComputation()); - BuildWhileInstruction(condition, body); + auto while_hlo = BuildWhileInstruction(condition, body); + const HloInstruction* old_init = while_hlo->operand(0); HloInstruction* old_root = body->root_instruction(); InsertCopies(&module_); HloInstruction* new_root = body->root_instruction(); + const HloInstruction* new_init = while_hlo->operand(0); // No copies should be inserted so root should not be updated. - CHECK_EQ(old_root, new_root); + EXPECT_EQ(old_root, new_root); + + // Both init indices need copies. + EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } // Tests while body computation with dependent tuple elements: @@ -798,39 +687,25 @@ TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { // out1 = Add(BCast(in0), in1) // Tuple(out0, out1) // -// CopyInsertion pass should generate: +// CopyInsertion pass should convert the root instruction to: // -// Tuple // old root -// / \ -// GTE(0) GTE(1) -// | | -// Copy | -// \ / -// Tuple // new root +// Tuple(Copy(out0), out1) // TEST_F(WhileCopyInsertionTest, DependentTupleElements) { auto condition = module_.AddEmbeddedComputation(BuildConditionComputation()); auto body = module_.AddEmbeddedComputation(BuildDependentBodyComputation()); - BuildWhileInstruction(condition, body); + auto while_hlo = BuildWhileInstruction(condition, body); + const HloInstruction* old_init = while_hlo->operand(0); HloInstruction* old_root = body->root_instruction(); InsertCopies(&module_); HloInstruction* new_root = body->root_instruction(); + const HloInstruction* new_init = while_hlo->operand(0); - // Check all paths from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_root; - - op_tree.op(1).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(new_root, + op::Tuple(op::Copy(old_root->operand(0)), old_root->operand(1))); + EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } // Tests while body computation with read-only tuple element 0: @@ -846,20 +721,110 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements) { // \ / // TUPLE (root) // -// CopyInsertion pass should not generate any copies. -// +// CopyInsertion pass should not generate any copies for the while body. TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) { auto condition = module_.AddEmbeddedComputation(BuildConditionComputation()); auto body = module_.AddEmbeddedComputation( BuildDependentBodyOneReadOnlyComputation()); - BuildWhileInstruction(condition, body); + auto while_hlo = BuildWhileInstruction(condition, body); + const HloInstruction* old_init = while_hlo->operand(0); HloInstruction* old_root = body->root_instruction(); InsertCopies(&module_); HloInstruction* new_root = body->root_instruction(); + const HloInstruction* new_init = while_hlo->operand(0); - // No copies should be inserted so root should not be updated. - CHECK_EQ(old_root, new_root); + // No copies should be inserted in the body, so root should not be updated. + EXPECT_EQ(old_root, new_root); + + // Both indices need copies, even though Index 0 is read-only, since both are + // constants, which must be copied. + EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); +} + +// Same as above, but with two while loops, sharing entry parameters. +TEST_F(WhileCopyInsertionTest, + DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) { + auto condition1 = module_.AddEmbeddedComputation(BuildConditionComputation()); + auto condition2 = module_.AddEmbeddedComputation(BuildConditionComputation()); + auto body1 = module_.AddEmbeddedComputation( + BuildDependentBodyOneReadOnlyComputation()); + auto body2 = module_.AddEmbeddedComputation( + BuildDependentBodyOneReadOnlyComputation()); + + auto builder = HloComputation::Builder(TestName() + ".While"); + auto iter_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "data")); + auto loop_init = builder.AddInstruction( + HloInstruction::CreateTuple({iter_param, data_param})); + + auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape_, condition1, body1, loop_init)); + auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape_, condition2, body2, loop_init)); + module_.AddEntryComputation(builder.Build()); + + InsertCopies(&module_); + + // Both while loops share a single copy of iter_param, since index 0 is + // read-only in the body. + EXPECT_EQ(while_hlo1->operand(0)->operand(0), + while_hlo2->operand(0)->operand(0)); + EXPECT_THAT(while_hlo1->operand(0)->operand(0), op::Copy(iter_param)); + + // Each while loop gets its own copy of data_param, since index 1 is not + // read-only in the body. + EXPECT_NE(while_hlo1->operand(0)->operand(1), + while_hlo2->operand(0)->operand(1)); + EXPECT_THAT(while_hlo1->operand(0)->operand(1), op::Copy(data_param)); + EXPECT_THAT(while_hlo2->operand(0)->operand(1), op::Copy(data_param)); +} + +// Same as above, but with two while loops, sharing non-parameters. +TEST_F(WhileCopyInsertionTest, + DependentTupleElements_OneReadOnly_TwoLoops_NonParams) { + auto condition1 = module_.AddEmbeddedComputation(BuildConditionComputation()); + auto condition2 = module_.AddEmbeddedComputation(BuildConditionComputation()); + auto body1 = module_.AddEmbeddedComputation( + BuildDependentBodyOneReadOnlyComputation()); + auto body2 = module_.AddEmbeddedComputation( + BuildDependentBodyOneReadOnlyComputation()); + + auto builder = HloComputation::Builder(TestName() + ".While"); + auto iter_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, induction_variable_shape_, "iter")); + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "data")); + // Add dummy ops to ensure loop_init elements aren't entry parameters. + auto iter_value = builder.AddInstruction(HloInstruction::CreateUnary( + iter_param->shape(), HloOpcode::kExp, iter_param)); + auto data_value = builder.AddInstruction(HloInstruction::CreateUnary( + data_param->shape(), HloOpcode::kExp, data_param)); + auto loop_init = builder.AddInstruction( + HloInstruction::CreateTuple({iter_value, data_value})); + + auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape_, condition1, body1, loop_init)); + auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape_, condition2, body2, loop_init)); + module_.AddEntryComputation(builder.Build()); + + InsertCopies(&module_); + + // No copies of iter_value are necessary, since index 0 is read-only in both + // while bodies. + EXPECT_EQ(while_hlo1->operand(0)->operand(0), iter_value); + EXPECT_EQ(while_hlo2->operand(0)->operand(0), iter_value); + + // Each while loop gets its own copy of data_value, since index 1 is not + // read-only in the body. + EXPECT_NE(while_hlo1->operand(0)->operand(1), + while_hlo2->operand(0)->operand(1)); + EXPECT_THAT(while_hlo1->operand(0)->operand(1), op::Copy(data_value)); + EXPECT_THAT(while_hlo2->operand(0)->operand(1), op::Copy(data_value)); } // Tests while body computation with nested tuple elements: @@ -872,7 +837,8 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) { // Add Reverse // | | // -// CopyInsertion pass should generate: +// CopyInsertion pass will conceptually generate the following, but with the +// actual GTE and Tuple instructions optimized away: // // Tuple // old root // / \ @@ -898,104 +864,41 @@ TEST_F(WhileCopyInsertionTest, NestedTupleElements) { HloInstruction* old_root = body->root_instruction(); InsertCopies(&module_); - HloInstruction* new_root = body->root_instruction(); - - // Check all paths from 'new_root' to 'old_root'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).instruction = old_root; - - op_tree.op(1).opcode = HloOpcode::kTuple; - - op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).op(0).op(0).instruction = old_root; - op_tree.op(1).op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(1).op(0).op(0).op(0).instruction = old_root; - - op_tree.Check(new_root); + EXPECT_THAT(body->root_instruction(), + op::Tuple(old_root->operand(0), + op::Tuple(old_root->operand(1)->operand(0), + op::Copy(old_root->operand(1)->operand(1))))); } // Tests while init instruction which points-to a constant. // // init = Tuple(Constant(S32, {}), Constant(F32, {8})) // -// CopyInsertion pass should generate: -// -// Tuple // old init -// / \ -// GTE(0) GTE(1) -// | | -// Copy Copy -// \ / -// Tuple // new init +// CopyInsertion pass should add copies for both constants. // TEST_F(WhileCopyInsertionTest, InitPointsToConstant) { auto while_hlo = BuildWhileInstruction_InitPointsToConstant(); auto old_init = while_hlo->operand(0); InsertCopies(&module_); - auto new_init = while_hlo->operand(0); - - // Check all paths from 'new_init' to 'old_init'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_init; - - op_tree.op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).op(0).instruction = old_init; - - op_tree.Check(new_init); + EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } // Tests while init instruction which points-to a parameter. // // init = Tuple(Constant(S32, {}), Parameter(F32, {8})) // -// CopyInsertion pass should generate: -// -// Tuple // old init -// / \ -// GTE(0) GTE(1) -// | | -// Copy Copy -// \ / -// Tuple // new init +// CopyInsertion pass should add copies for both the constant and parameter. // TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { auto while_hlo = BuildWhileInstruction_InitPointsToParameter(); auto old_init = while_hlo->operand(0); InsertCopies(&module_); - auto new_init = while_hlo->operand(0); - - // Check all paths from 'new_init' to 'old_init'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_init; - op_tree.op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).op(0).instruction = old_init; - - op_tree.Check(new_init); + EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } // Tests while init instruction which has an ambiguous points-to set. @@ -1003,7 +906,8 @@ TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { // select = Select(pred, tuple1, tuple2) // init = Tuple(Constant(S32, {}), Parameter(F32, {8})) // -// CopyInsertion pass should generate: +// CopyInsertion pass will conceptually generate the following, but with some of +// the actual GTE and Tuple instructions optimized away: // // Tuple // old init // / \ @@ -1025,39 +929,21 @@ TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous(); auto old_init = while_hlo->operand(0); InsertCopies(&module_); - auto new_init = while_hlo->operand(0); - - // Check all paths from 'new_init' to 'old_init'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_init; - op_tree.op(1).opcode = HloOpcode::kTuple; - - op_tree.op(1).op(0).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).op(0).op(0).op(0).instruction = old_init; - - op_tree.op(1).op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(1).op(0).op(0).op(0).instruction = old_init; - - op_tree.Check(new_init); + EXPECT_THAT( + while_hlo->operand(0), + op::Tuple( + op::Copy(old_init->operand(0)), + op::Tuple(op::Copy(op::GetTupleElement(old_init->operand(1))), + op::Copy(op::GetTupleElement(old_init->operand(1)))))); } // Tests while init instruction which has a non-distinct points-to set. // // init = Tuple(Constant(S32, {}), Tuple({vec_one, vec_one})) // -// CopyInsertion pass should generate: +// CopyInsertion pass will conceptually generate the following, but with some of +// the actual GTE and Tuple instructions optimized away: // // Tuple // old init // / \ @@ -1079,71 +965,28 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct(); auto old_init = while_hlo->operand(0); InsertCopies(&module_); - auto new_init = while_hlo->operand(0); - - // Check all paths from 'new_init' to 'old_init'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_init; - - op_tree.op(1).opcode = HloOpcode::kTuple; - - op_tree.op(1).op(0).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).op(0).op(0).op(0).instruction = old_init; - op_tree.op(1).op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(1).op(0).op(0).op(0).instruction = old_init; - - op_tree.Check(new_init); + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(old_init->operand(0)), + op::Tuple(op::Copy(old_init->operand(1)->operand(0)), + op::Copy(old_init->operand(1)->operand(0))))); } -// Tests while init instruction buffer which interfers with while result buffer. +// Tests while init instruction buffer which interferes with while result buffer. // // init_data = Broadcast(...) // add_unrelated = Add(init_data) // takes a reference to cause interference // init = Tuple(Constant(S32, {}), init_data)) // -// CopyInsertion pass should generate: -// -// Tuple // old init -// / \ -// GTE(0) GTE(1) -// | | -// Copy Copy -// \ / -// Tuple // new init +// CopyInsertion pass should copy both operands. // TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { auto while_hlo = BuildWhileInstruction_InitPointsToInterfering(); auto old_init = while_hlo->operand(0); InsertCopies(&module_); - auto new_init = while_hlo->operand(0); - - // Check all paths from 'new_init' to 'old_init'. - OperandTree op_tree; - op_tree.opcode = HloOpcode::kTuple; - - op_tree.op(0).opcode = HloOpcode::kCopy; - op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(0).op(0).op(0).instruction = old_init; - - op_tree.op(1).opcode = HloOpcode::kCopy; - op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement; - op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple; - op_tree.op(1).op(0).op(0).instruction = old_init; - op_tree.Check(new_init); + EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index e9963528111994c4918861eaa52ab915fe34fd93..affb5f99066d8278c583c469d97e78646d52f3c6 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -53,13 +53,13 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/port:initialize", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", @@ -97,6 +97,7 @@ cc_library( name = "simple_orc_jit", srcs = ["simple_orc_jit.cc"], hdrs = ["simple_orc_jit.h"], + linkopts = ["-ldl"], deps = [ ":compiler_functor", ":cpu_runtime", diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index d18141af83e4653e18d3b0118d0892f41db5b69b..b42702dbe1abe3db838159bda2665743e416a2d5 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" @@ -28,6 +29,8 @@ limitations under the License. namespace xla { namespace cpu { +using ::testing::ElementsAre; + class ConvCanonicalizationTest : public HloTestBase { public: ConvCanonicalizationTest() { @@ -96,17 +99,14 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { // The input is in CNHW order. input_reshape should produce // NHWC for the convolution to hit the Eigen fast path. - EXPECT_TRUE(ContainersEqual(input_reshape->dimensions(), - std::vector({1, 2, 3, 0}))); + EXPECT_THAT(input_reshape->dimensions(), ElementsAre(1, 2, 3, 0)); // The kernel is in OIHW order. kernel_reshape should produce // HWIO for the convolution to hit the Eigen fast path. - EXPECT_TRUE(ContainersEqual(kernel_reshape->dimensions(), - std::vector({2, 3, 1, 0}))); + EXPECT_THAT(kernel_reshape->dimensions(), ElementsAre(2, 3, 1, 0)); // The output of the canonical convolution is in NHWC order (the same as // input_reshape's order). output_reshape should restore that order to the // order of the computation root (CNHW). - EXPECT_TRUE(ContainersEqual(output_reshape->dimensions(), - std::vector({3, 0, 1, 2}))); + EXPECT_THAT(output_reshape->dimensions(), ElementsAre(3, 0, 1, 2)); } TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index c5433d4b89d7ccab0f04e9ab2787ce150417b669..97458f0fcc344f42f6d7244b1f812e29666437e1 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -39,7 +39,6 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/port/initialize.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" @@ -58,6 +57,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" @@ -192,16 +192,16 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { } // It is important to recurse for "while" or else we risk overly coarse // profiling information. - Status HandleWhile(HloInstruction* xla_while, HloInstruction* /*init*/, - HloComputation* condition, HloComputation* body) override { + Status HandleWhile(HloInstruction* xla_while) override { TF_RETURN_IF_ERROR(DefaultAction(xla_while)); CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_); - TF_RETURN_IF_ERROR( - condition->root_instruction()->Accept(&candidates_for_condition)); + TF_RETURN_IF_ERROR(xla_while->while_condition()->root_instruction()->Accept( + &candidates_for_condition)); CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_); - TF_RETURN_IF_ERROR(body->root_instruction()->Accept(&candidates_for_body)); + TF_RETURN_IF_ERROR(xla_while->while_body()->root_instruction()->Accept( + &candidates_for_body)); return Status::OK(); } @@ -210,9 +210,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { }; } // namespace -Status CpuCompiler::RunHloPasses(HloModule* hlo_module, - HloModuleConfig* module_config, - HloDumper dump_hlo) { +Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) { // Optimization pipeline. HloPassPipeline pipeline("CPU", dump_hlo); pipeline.AddInvariantChecker(); @@ -232,12 +230,18 @@ Status CpuCompiler::RunHloPasses(HloModule* hlo_module, pass.AddPass(); pass.AddPass(); } - pipeline.AddPass(PotentiallyImplementedAsEigenDot); - pipeline.AddPass(); + pipeline.AddPass( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return PotentiallyImplementedAsEigenDot(dot) + ? candidate_operands + : TransposeFolding::OperandIndices{}; + }, + TransposeFolding::NeverFoldTranspose); pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); pipeline.AddPass( - module_config->mutable_entry_computation_layout()); + module->mutable_config()->mutable_entry_computation_layout()); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>( @@ -250,10 +254,13 @@ Status CpuCompiler::RunHloPasses(HloModule* hlo_module, if (flags->xla_cpu_parallel) { pipeline.AddPass(); } - // Copy insertion should be performed immediately before IR emission to - // avoid inserting unnecessary copies (later pass adds an instruction which - // materializes the value) or missing a necessary copy (later pass removes - // an instruction which materializes a value). + // Copy insertion should be performed immediately before IR emission to avoid + // inserting unnecessary copies (later pass adds an instruction which + // materializes the value) or missing a necessary copy (later pass removes an + // instruction which materializes a value). DCE must be run immediately before + // (and sometime after) copy insertion, to avoid dead code from interfering + // with the rewrites. + pipeline.AddPass(); pipeline.AddPass(); if (flags->xla_cpu_parallel) { // Re-run the outlining, in case any copies were inserted into the entry @@ -261,7 +268,8 @@ Status CpuCompiler::RunHloPasses(HloModule* hlo_module, pipeline.AddPass(); } pipeline.AddPass(); - return pipeline.Run(hlo_module).status(); + pipeline.AddPass(); + return pipeline.Run(module).status(); } namespace { @@ -295,8 +303,7 @@ llvm::CodeGenOpt::Level CodeGenOptLevel() { } // namespace StatusOr> CpuCompiler::Compile( - std::unique_ptr hlo_module, - std::unique_ptr module_config, HloDumper dump_hlo, + std::unique_ptr module, HloDumper dump_hlo, se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); @@ -304,17 +311,16 @@ StatusOr> CpuCompiler::Compile( auto llvm_context = MakeUnique(); auto llvm_module = MakeUnique("__compute_module", *llvm_context); - auto jit = MakeUnique(CompilerTargetOptions(*module_config), + auto jit = MakeUnique(CompilerTargetOptions(module->config()), CodeGenOptLevel()); llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); - TF_RETURN_IF_ERROR( - RunHloPasses(hlo_module.get(), module_config.get(), dump_hlo)); + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), dump_hlo)); - HloComputation* computation = hlo_module->entry_computation(); + HloComputation* computation = module->entry_computation(); std::unordered_map hlo_to_profile_idx; - if (module_config->hlo_profiling_enabled()) { + if (module->config().hlo_profiling_enabled()) { TF_ASSIGN_OR_RETURN( hlo_to_profile_idx, CollectProfileCandidates::GetCandidatesForComputation(computation)); @@ -331,8 +337,8 @@ StatusOr> CpuCompiler::Compile( // uses data dependencies for determining order. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run(hlo_module.get(), - MakeUnique(hlo_module.get()), + BufferAssigner::Run(module.get(), + MakeUnique(module.get()), [this](const LogicalBuffer& buffer) { return ShapeSizeBytes(buffer.shape()); }, @@ -363,13 +369,13 @@ StatusOr> CpuCompiler::Compile( // The parallel preparation should have ensured that the top-level // computation consists solely of Call instructions. TF_RET_CHECK(instruction->opcode() == HloOpcode::kCall) - << hlo_module->ToString(); + << module->ToString(); HloComputation* to_apply = instruction->to_apply(); parallel_computations.emplace(to_apply, instruction); } - IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, - llvm_module.get(), &hlo_to_profile_idx); + IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), + &hlo_to_profile_idx); std::unique_ptr> function_names( new std::map()); for (auto embedded_computation : @@ -403,9 +409,9 @@ StatusOr> CpuCompiler::Compile( // JIT compile the LLVM IR module to in-memory machine code. jit->AddModule(std::move(llvm_module)); cpu_executable.reset(new ParallelCpuExecutable( - std::move(jit), std::move(assignment), std::move(hlo_module), - std::move(module_config), std::move(function_names), - std::move(hlo_to_profile_idx), std::move(aligned_constants))); + std::move(jit), std::move(assignment), std::move(module), + std::move(function_names), std::move(hlo_to_profile_idx), + std::move(aligned_constants))); if (flags->xla_cpu_embed_ir) { static_cast(*cpu_executable) @@ -417,7 +423,7 @@ StatusOr> CpuCompiler::Compile( // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*hlo_module, + CreateMemoryMinimizingSequence(*module, [this](const LogicalBuffer& buffer) { return ShapeSizeBytes(buffer.shape()); })); @@ -426,20 +432,20 @@ StatusOr> CpuCompiler::Compile( // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run(hlo_module.get(), - MakeUnique(hlo_module.get(), - module_sequence), - [this](const LogicalBuffer& buffer) { - return ShapeSizeBytes(buffer.shape()); - }, - kMemoryAlignment)); + BufferAssigner::Run( + module.get(), + MakeUnique(module.get(), module_sequence), + [this](const LogicalBuffer& buffer) { + return ShapeSizeBytes(buffer.shape()); + }, + kMemoryAlignment)); // Each computation is a single function. Emit all embedded computations // before the entry computation. The order of computations returned from // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. - IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, - llvm_module.get(), &hlo_to_profile_idx); + IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), + &hlo_to_profile_idx); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { TF_RETURN_IF_ERROR( @@ -466,10 +472,9 @@ StatusOr> CpuCompiler::Compile( // JIT compile the LLVM IR module to in-memory machine code. jit->AddModule(std::move(llvm_module)); - cpu_executable.reset( - new CpuExecutable(std::move(jit), std::move(assignment), - std::move(hlo_module), std::move(module_config), - function_name, std::move(hlo_to_profile_idx))); + cpu_executable.reset(new CpuExecutable( + std::move(jit), std::move(assignment), std::move(module), function_name, + std::move(hlo_to_profile_idx))); if (flags->xla_cpu_embed_ir) { static_cast(*cpu_executable) @@ -481,27 +486,24 @@ StatusOr> CpuCompiler::Compile( } StatusOr>> CpuCompiler::Compile( - std::vector> hlo_modules, - std::vector> module_configs, - HloDumper dump_hlos, std::vector stream_execs) { + std::vector> modules, HloDumper dump_hlos, + std::vector stream_execs) { return Unimplemented( "Compilation of multiple HLO modules is not yet supported on CPU."); } StatusOr>> -CpuCompiler::CompileAheadOfTime( - std::vector> hlo_modules, - std::vector> module_configs, - HloDumper dump_hlo, const AotCompilationOptions& aot_options) { - TF_RET_CHECK(hlo_modules.size() == module_configs.size()); - TF_RET_CHECK(!hlo_modules.empty()); +CpuCompiler::CompileAheadOfTime(std::vector> modules, + HloDumper dump_hlo, + const AotCompilationOptions& aot_options) { + TF_RET_CHECK(!modules.empty()); // We can pass just one llvm::TargetOptions when we compile the LLVM module, // so we bail if the configs have conflicting flags. At the moment, the only // flag that needs to be consistent is fast-math. - bool fast_math_disabled = module_configs[0]->fast_math_disabled(); - for (const auto& module_config : module_configs) { - if (module_config->fast_math_disabled() != fast_math_disabled) { + bool fast_math_disabled = modules[0]->config().fast_math_disabled(); + for (const auto& module : modules) { + if (module->config().fast_math_disabled() != fast_math_disabled) { return InvalidArgument( "All HLO module configs must have the same value for " "fast_math_disabled."); @@ -559,7 +561,7 @@ CpuCompiler::CompileAheadOfTime( std::unique_ptr target_machine = WrapUnique(target->createTargetMachine( triple.getTriple(), cpu_name, features, - CompilerTargetOptions(*module_configs[0]), reloc_model, + CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::CodeModel::Default, opt_level)); // Compile must be thread-safe so create a new LLVM context for the module. @@ -575,15 +577,14 @@ CpuCompiler::CompileAheadOfTime( } std::vector> results; - for (size_t i = 0; i < hlo_modules.size(); ++i) { - HloModule* hlo_module = hlo_modules[i].get(); - HloModuleConfig* module_config = module_configs[i].get(); + for (size_t i = 0; i < modules.size(); ++i) { + HloModule* module = modules[i].get(); - TF_RETURN_IF_ERROR(RunHloPasses(hlo_module, module_config, dump_hlo)); + TF_RETURN_IF_ERROR(RunHloPasses(module, dump_hlo)); TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*hlo_module, + CreateMemoryMinimizingSequence(*module, [this](const LogicalBuffer& buffer) { return ShapeSizeBytes(buffer.shape()); })); @@ -593,16 +594,15 @@ CpuCompiler::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, BufferAssigner::Run( - hlo_module, - MakeUnique(hlo_module, module_sequence), + module, MakeUnique(module, module_sequence), [this](const LogicalBuffer& buffer) { return ShapeSizeBytes(buffer.shape()); }, kMemoryAlignment)); - IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, &llvm_module, + IrEmitter ir_emitter(*module, *assignment, &llvm_module, /*hlo_to_profile_idx=*/nullptr); - HloComputation* computation = hlo_module->entry_computation(); + HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { TF_RETURN_IF_ERROR( @@ -672,8 +672,10 @@ int64 CpuCompiler::ShapeSizeBytes(const Shape& shape) const { } // namespace cpu } // namespace xla -REGISTER_MODULE_INITIALIZER(cpu_compiler, { +static bool InitModule() { xla::Compiler::RegisterCompilerFactory(se::host::kHostPlatformId, []() { return xla::MakeUnique(); }); -}); + return true; +} +static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index a32aa84ea51123f76551ad617cc914a53d4ca4d1..cadafa83320e17e6baddfc64dcaa8a988de6360d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" @@ -113,21 +112,17 @@ class CpuCompiler : public Compiler { ~CpuCompiler() override {} StatusOr> Compile( - std::unique_ptr hlo_module, - std::unique_ptr module_config, HloDumper dump_hlo, + std::unique_ptr module, HloDumper dump_hlo, perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> Compile( - std::vector> hlo_module, - std::vector> module_config, - HloDumper dump_hlo, + std::vector> modules, HloDumper dump_hlo, std::vector stream_exec) override; StatusOr>> - CompileAheadOfTime( - std::vector> module, - std::vector> module_config, - HloDumper dump_hlo, const AotCompilationOptions& options) override; + CompileAheadOfTime(std::vector> modules, + HloDumper dump_hlo, + const AotCompilationOptions& options) override; perftools::gputools::Platform::Id PlatformId() const override; @@ -139,8 +134,7 @@ class CpuCompiler : public Compiler { // Runs the HLO passes which are necessary for both optimizations and // correctness. - Status RunHloPasses(HloModule* hlo_module, HloModuleConfig* module_config, - HloDumper dump_hlo); + Status RunHloPasses(HloModule* hlo_module, HloDumper dump_hlo); TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 88283e6010ea784e2a977a80adbe6315782f7fdc..a4fcea7aec83fc64fa40fc28d4713a651290641c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/shape_tree.h" @@ -53,11 +52,9 @@ namespace cpu { CpuExecutable::CpuExecutable( std::unique_ptr jit, std::unique_ptr assignment, - std::unique_ptr hlo_module, - std::unique_ptr module_config, - const string& entry_function_name, + std::unique_ptr hlo_module, const string& entry_function_name, std::unordered_map hlo_to_profile_idx) - : Executable(std::move(hlo_module), std::move(module_config)), + : Executable(std::move(hlo_module)), jit_(std::move(jit)), assignment_(std::move(assignment)), hlo_to_profile_idx_(std::move(hlo_to_profile_idx)) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index b04b4e8dd1fd23839a4684f72622e32eca9c3730..0cc0965ae1df6ab64a2f146e02b6e19b43ca81a5 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" @@ -52,7 +51,6 @@ class CpuExecutable : public Executable { std::unique_ptr jit, std::unique_ptr assignment, std::unique_ptr hlo_module, - std::unique_ptr module_config, const string& entry_function_name, std::unordered_map hlo_to_profile_idx); ~CpuExecutable() override {} diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 240da35ef190eb7080947ab7d1da91d8d2dd8973..dc002846e9e6b07c767ddc8af939657c4c51bf23 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -24,6 +24,11 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); + // Output fusion is not currently supported on CPUs. + if (producer->opcode() == HloOpcode::kFusion) { + return false; + } + // Condition for consumer: must be elementwise or a fusion op // (which necessarily only contains elementwise operations) if (!(consumer->opcode() == HloOpcode::kFusion || diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 8e06f0520edfb05c7ec606dcb8e85c5ef997c2c0..253de20f25127bf0ac23d5969e0f16c143396e47 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include #include #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 51c6dc4426f8c40d60ba933ce0a31f8fb9d927c1..2d81ba7882747f77ca93adf71a37172f5f2bff24 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -63,8 +63,8 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { IrEmitter::IrEmitter( - const HloModule& hlo_module, const HloModuleConfig& hlo_module_config, - const BufferAssignment& assignment, llvm::Module* llvm_module, + const HloModule& hlo_module, const BufferAssignment& assignment, + llvm::Module* llvm_module, const std::unordered_map* hlo_to_profile_idx) : assignment_(assignment), module_(llvm_module), @@ -72,8 +72,8 @@ IrEmitter::IrEmitter( ir_builder_(llvm_module->getContext()), hlo_to_profile_idx_(hlo_to_profile_idx), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), - hlo_module_config_(hlo_module_config) { - ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(hlo_module_config)); + hlo_module_config_(hlo_module.config()) { + ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(hlo_module_config_)); } StatusOr IrEmitter::EmitComputation( @@ -201,7 +201,8 @@ void IrEmitter::InitializeIrFunction(const string& function_name, if (&argument == retval) { continue; } - compute_function_->setDoesNotAlias(argument.getArgNo() + 1); + compute_function_->addAttribute(argument.getArgNo() + 1, + llvm::Attribute::NoAlias); } ir_builder_.SetInsertPoint(llvm::BasicBlock::Create( @@ -1136,6 +1137,41 @@ Status IrEmitter::HandleSend(HloInstruction* send) { return Unimplemented("Send is not implemented on CPU. See b/33942983."); } +Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { + if (ShapeUtil::IsScalar(slice->shape())) { + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(slice)); + emitted_value_[slice] = target_address; + return EmitMemcpy(*operand, *slice); + } + return DefaultAction(slice); +} + +Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice, + HloInstruction* operand, + HloInstruction* /*start_indices*/) { + if (ShapeUtil::IsScalar(dynamic_slice->shape())) { + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(dynamic_slice)); + emitted_value_[dynamic_slice] = target_address; + return EmitMemcpy(*operand, *dynamic_slice); + } + return DefaultAction(dynamic_slice); +} + +Status IrEmitter::HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, + HloInstruction* /*operand*/, + HloInstruction* update, + HloInstruction* /*start_indices*/) { + if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(dynamic_update_slice)); + emitted_value_[dynamic_update_slice] = target_address; + return EmitMemcpy(*update, *dynamic_update_slice); + } + return DefaultAction(dynamic_update_slice); +} + Status IrEmitter::HandleRecv(HloInstruction* recv) { // TODO(b/33942983): Support Send/Recv on CPU. return Unimplemented("Recv is not implemented on CPU. See b/33942983."); @@ -1265,13 +1301,12 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { } } -Status IrEmitter::HandleCall( - HloInstruction* call, tensorflow::gtl::ArraySlice operands, - HloComputation* computation) { +Status IrEmitter::HandleCall(HloInstruction* call) { + HloComputation* computation = call->to_apply(); llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation); std::vector parameter_addresses; - for (HloInstruction* operand : operands) { + for (const HloInstruction* operand : call->operands()) { parameter_addresses.push_back(GetEmittedValueFor(operand)); } @@ -1322,9 +1357,9 @@ Status IrEmitter::HandleCustomCall( return Status::OK(); } -Status IrEmitter::HandleWhile(HloInstruction* xla_while, HloInstruction* init, - HloComputation* condition, HloComputation* body) { +Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Precondition: Condition computation must return a scalar bool. + HloComputation* condition = xla_while->while_condition(); TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) && condition->root_instruction()->shape().element_type() == PRED) << "While condition computation must return bool"; @@ -1361,12 +1396,14 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while, HloInstruction* init, })); // Set emitted value to that of 'init' with which it shares an allocation. + const HloInstruction* init = xla_while->operand(0); emitted_value_[xla_while] = GetEmittedValueFor(init); // The called computation should have been emitted previously. llvm::Function* condition_ir_function = FindOrDie(emitted_functions_, condition); - llvm::Function* body_ir_function = FindOrDie(emitted_functions_, body); + llvm::Function* body_ir_function = + FindOrDie(emitted_functions_, xla_while->while_body()); // Generating: // while (Condition(while_result)) { @@ -1710,8 +1747,7 @@ StatusOr IrEmitter::EmitTargetAddressForOp( llvm::AttrBuilder attr_builder; attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); - retval->addAttr(llvm::AttributeList::get( - retval->getContext(), retval->getArgNo() + 1, attr_builder)); + retval->addAttrs(attr_builder); } return ir_builder_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo()); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 66bae457e3741332f23abc7d54b8d775aa193ca9..b564b359b07a6ca52193bd0c5934f8563a00346c 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -60,8 +60,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { // llvm_module: the LLVM module to emit IR into. // hlo_to_profile_idx: the mapping from HLO to its index in the profiling // array. - IrEmitter(const HloModule& hlo_module, const HloModuleConfig& module_config, - const BufferAssignment& assignment, llvm::Module* llvm_module, + IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, + llvm::Module* llvm_module, const std::unordered_map* hlo_to_profile_idx); ~IrEmitter() override; @@ -114,6 +114,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloComputation* function) override; Status HandleSelectAndScatter(HloInstruction* instruction) override; Status HandleSend(HloInstruction* send) override; + Status HandleSlice(HloInstruction* slice, + HloInstruction* /*operand*/) override; + Status HandleDynamicSlice(HloInstruction* dynamic_slice, + HloInstruction* /*operand*/, + HloInstruction* /*start_indices*/) override; + Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, + HloInstruction* /*operand*/, + HloInstruction* /*update*/, + HloInstruction* /*start_indices*/) override; Status HandleRecv(HloInstruction* recv) override; Status HandlePad(HloInstruction* pad) override; Status HandleTuple( @@ -125,14 +134,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloComputation* function, tensorflow::gtl::ArraySlice static_operands) override; Status HandleFusion(HloInstruction* fusion) override; - Status HandleCall(HloInstruction* call, - tensorflow::gtl::ArraySlice operands, - HloComputation* computation) override; + Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target) override; - Status HandleWhile(HloInstruction* xla_while, HloInstruction* init, - HloComputation* condition, HloComputation* body) override; + Status HandleWhile(HloInstruction* xla_while) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index 7a4723e8d75588d8ccb711892b4082024695e444..5f7b2c663f7a6a554afda17702160e70ce4e04a0 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -58,12 +57,11 @@ ParallelCpuExecutable::ParallelCpuExecutable( std::unique_ptr jit, std::unique_ptr assignment, std::unique_ptr hlo_module, - std::unique_ptr module_config, std::unique_ptr> function_names, std::unordered_map hlo_to_profile_idx, std::unordered_map> aligned_constants) - : Executable(std::move(hlo_module), std::move(module_config)), + : Executable(std::move(hlo_module)), jit_(std::move(jit)), assignment_(std::move(assignment)), functions_names_(std::move(function_names)), @@ -146,7 +144,7 @@ Status ParallelCpuExecutable::AllocateBuffers( } Status ParallelCpuExecutable::ExecuteComputeFunctions( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { @@ -160,7 +158,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( } Status ParallelCpuExecutable::ExecuteComputeFunctions( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { @@ -214,7 +212,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( void** temps_array = buffer_pointers.data(); uint64* profile_counters_array = profile_counters.data(); - auto* thread_pool = CHECK_NOTNULL(run_options->inter_op_thread_pool()); + auto* thread_pool = CHECK_NOTNULL(run_options->xla_intra_op_thread_pool()); tensorflow::mutex completion_queue_lock; tensorflow::condition_variable completion_queue_cv; std::deque completion_queue; @@ -251,11 +249,12 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( }); auto function = FindOrDie(functions, instruction); // The thread pool entry takes ownership of |operand_buffers|. + const auto* exec_run_options = &run_options->run_options(); thread_pool->Schedule([instruction, &completion_queue, &completion_queue_lock, &completion_queue_cv, - result_buffer, run_options, operand_buffers, + result_buffer, exec_run_options, operand_buffers, temps_array, profile_counters_array, function] { - function(result_buffer, run_options, operand_buffers, temps_array, + function(result_buffer, exec_run_options, operand_buffers, temps_array, profile_counters_array); delete[] operand_buffers; // Push the completed HLO instruction on the queue, the main thread @@ -345,9 +344,8 @@ ParallelCpuExecutable::ExecuteOnStream( const BufferAllocation::Index result_index = result_slice.index(); VLOG(3) << "result index: " << result_index; - TF_RETURN_IF_ERROR(ExecuteComputeFunctions(&run_options->run_options(), - arguments, device_allocations, - hlo_execution_profile)); + TF_RETURN_IF_ERROR(ExecuteComputeFunctions( + run_options, arguments, device_allocations, hlo_execution_profile)); // Mark the buffers that are actually live (used in the output) when the // computation finishes executing. @@ -400,8 +398,8 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - TF_RETURN_IF_ERROR(ExecuteComputeFunctions( - &run_options->run_options(), arguments, buffers, hlo_execution_profile)); + TF_RETURN_IF_ERROR(ExecuteComputeFunctions(run_options, arguments, buffers, + hlo_execution_profile)); // Copy DeviceMemoryBase values which contain the array(s) of the result into // the respective location in ShapedBuffer which is returned to the caller. diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index 7223de9f0798365138cdb26ca9dce07cd0e474e3..a3278c9510e9661f53ecbc729aa500b3636d3f6d 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -52,7 +51,6 @@ class ParallelCpuExecutable : public Executable { std::unique_ptr jit, std::unique_ptr assignment, std::unique_ptr hlo_module, - std::unique_ptr module_config, std::unique_ptr> instruction_functions, std::unordered_map hlo_to_profile_idx, std::unordered_map arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile); Status ExecuteComputeFunctions( - const ExecutableRunOptions* run_options, + const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index 677080a8623224cdd65e35b3116ae57b7b3b3ca2..ee772f5c3967b6671f3d89c8ee3034e78501018b 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -54,7 +54,7 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int lhs_contract_dim = transpose_lhs ? 0 : 1; int rhs_contract_dim = transpose_rhs ? 1 : 0; const Eigen::array dims( - DimPair(lhs_contract_dim, rhs_contract_dim)); + {DimPair(lhs_contract_dim, rhs_contract_dim)}); // Matrix multiply is a special case of the "contract" operation where // the contraction is performed along dimension 1 of the lhs and dimension diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index 384a978873de89526f43556296aaa51c46ac1d3f..6f1c97a2334e08a5ea62b9b7837aa83fa3cde631 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -48,7 +48,7 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, int lhs_contract_dim = transpose_lhs ? 0 : 1; int rhs_contract_dim = transpose_rhs ? 1 : 0; const Eigen::array dims( - DimPair(lhs_contract_dim, rhs_contract_dim)); + {DimPair(lhs_contract_dim, rhs_contract_dim)}); // Matrix multiply is a special case of the "contract" operation where // the contraction is performed along dimension 1 of the lhs and dimension diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 8beb565ab3e220f9b9eebac836c8de8c1fc2e8ee..7c74912a7ab9c388c9911fe8194f268623f0abd1 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -112,13 +112,25 @@ llvm::SmallVector DetectMachineAttributes() { if (llvm::sys::getHostCPUFeatures(host_features)) { for (auto &feature : host_features) { if (feature.second) { - result.push_back(feature.first()); + llvm::StringRef feature_name = feature.first(); + // Skip avx512 for now, it isn't quite ready in LLVM. + if (feature_name.startswith("avx512")) { + continue; + } + result.push_back(feature_name); } } } return result; } +llvm::StringRef GetHostCpuName() { + auto cpu_name = llvm::sys::getHostCPUName(); + // Skip avx512 for now, it isn't quite ready in LLVM. + cpu_name.consume_back("-avx512"); + return cpu_name; +} + CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() { CompilerFunctor::VectorIntrinsics intrinsics; intrinsics.sse_intrinsics = (&runtime::ExpV4F32 != nullptr); @@ -136,13 +148,16 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions &target_options, .setOptLevel(opt_level) .selectTarget( /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", - /*MCPU=*/llvm::sys::getHostCPUName(), + /*MCPU=*/GetHostCpuName(), /*MAttrs=*/DetectMachineAttributes()))), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), compile_layer_(object_layer_, CompilerFunctor(target_machine_.get(), &disassembler_, - opt_level, GetAvailableIntrinsics())) {} + opt_level, GetAvailableIntrinsics())) { + VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() + << " features: " << target_machine_->getTargetFeatureString().str(); +} SimpleOrcJIT::ModuleHandleT SimpleOrcJIT::AddModule( std::unique_ptr module) { diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 351efa82dd21dd9f618ed38cdb54bd2e26fcd5d5..49e9874cda2dd4cc5087b2467442d44bc0245734 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -189,19 +189,16 @@ class DfsHloVisitor { virtual Status HandleTranspose(HloInstruction* transpose) = 0; virtual Status HandleParameter(HloInstruction* parameter) = 0; virtual Status HandleFusion(HloInstruction* fusion) = 0; - virtual Status HandleCall( - HloInstruction* call, - tensorflow::gtl::ArraySlice operands, - HloComputation* computation) = 0; + virtual Status HandleCall(HloInstruction* call) = 0; virtual Status HandleCustomCall( HloInstruction* custom_call, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target) = 0; virtual Status HandleSlice(HloInstruction* slice, HloInstruction* operand) = 0; - virtual Status HandleDynamicSlice( - HloInstruction* slice, - tensorflow::gtl::ArraySlice operands) = 0; + virtual Status HandleDynamicSlice(HloInstruction* dynamic_slice, + HloInstruction* operand, + HloInstruction* start_indices) = 0; virtual Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, HloInstruction* operand, HloInstruction* update, @@ -219,9 +216,7 @@ class DfsHloVisitor { const Window& window, HloComputation* function) = 0; virtual Status HandleSelectAndScatter(HloInstruction* instruction) = 0; - virtual Status HandleWhile(HloInstruction* xla_while, HloInstruction* init, - HloComputation* condition, - HloComputation* body) = 0; + virtual Status HandleWhile(HloInstruction* xla_while) = 0; virtual Status HandlePad(HloInstruction* pad) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 18cfaf83e1cd558928c9fc65452524567f3cbb49..c27710fbdb2cb01776137370a61541c7e44c66c7 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -121,9 +121,7 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { Status HandleFusion(HloInstruction* fusion) override { return DefaultAction(fusion); } - Status HandleCall(HloInstruction* call, - tensorflow::gtl::ArraySlice /*operands*/, - HloComputation* /*computation*/) override { + Status HandleCall(HloInstruction* call) override { return DefaultAction(call); } Status HandleCustomCall( @@ -136,10 +134,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { HloInstruction* /*operand*/) override { return DefaultAction(slice); } - Status HandleDynamicSlice( - HloInstruction* slice, - tensorflow::gtl::ArraySlice /*operands*/) override { - return DefaultAction(slice); + Status HandleDynamicSlice(HloInstruction* dynamic_slice, + HloInstruction* /*operand*/, + HloInstruction* /*start_indices*/) override { + return DefaultAction(dynamic_slice); } Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, HloInstruction* /*operand*/, @@ -188,9 +186,7 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { Status HandleTranspose(HloInstruction* transpose) override { return DefaultAction(transpose); } - Status HandleWhile(HloInstruction* xla_while, HloInstruction* /*init*/, - HloComputation* /*condition*/, - HloComputation* /*body*/) override { + Status HandleWhile(HloInstruction* xla_while) override { return DefaultAction(xla_while); } Status HandleSend(HloInstruction* send) override { diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index a04815dad94484a6f01ebd27d3ec73f547086722..bea1da4044669f5e910af09ba1b65416a69367b5 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -240,14 +240,18 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( return ir_builder_->CreateFDiv(lhs_value, rhs_value); case HloOpcode::kRemainder: return ir_builder_->CreateFRem(lhs_value, rhs_value); - - // The 'O' prefix on the LLVM ops means "ordered" compare where comparisons - // with NAN always return false. + // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered + // comparisons always return false when one of the operands is NaN, whereas + // unordered comparisons return true. + // + // We use ordered comparisons for everything except kNe, where we use an + // unordered comparison. This makes x != y equivalent to !(x == y), and + // matches C++'s semantics. case HloOpcode::kEq: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value, rhs_value, ir_builder_); case HloOpcode::kNe: - return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_ONE, lhs_value, + return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value, rhs_value, ir_builder_); case HloOpcode::kLt: return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value, @@ -739,11 +743,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(operand_idx); auto true_block = llvm_ir::CreateBasicBlock( exit_block, tensorflow::strings::StrCat( - "concat_index_from_operand", operand_idx), + "concat_index_from_operand", operand_idx), ir_builder_); auto false_block = llvm_ir::CreateBasicBlock( exit_block, tensorflow::strings::StrCat( - "concat_index_not_from_operand", operand_idx), + "concat_index_not_from_operand", operand_idx), ir_builder_); auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index eb36aba33a7694c43985b5e5636e7e0fa2ad4794..5a65f829fcd1e854c266b2d958a8f3d6408b87d4 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -41,10 +40,8 @@ namespace xla { // interface that is used for launching compiled programs across platforms. class Executable { public: - explicit Executable(std::unique_ptr hlo_module, - std::unique_ptr module_config) - : hlo_module_(std::move(hlo_module)), - module_config_(std::move(module_config)) {} + explicit Executable(std::unique_ptr hlo_module) + : hlo_module_(std::move(hlo_module)) {} virtual ~Executable() {} // Enqueues the compilation result on the provided stream, passing the given @@ -98,15 +95,17 @@ class Executable { // enabled. If not, the caller should not expect an hlo_execution_profile // passed to ExecuteOnStream above to be populated during execution. bool hlo_profiling_enabled() const { - return module_config_->hlo_profiling_enabled(); + return hlo_module_->config().hlo_profiling_enabled(); } const HloModule& module() const { return *hlo_module_; } - const HloModuleConfig& module_config() const { return *module_config_; } + const HloModuleConfig& module_config() const { return hlo_module_->config(); } // Returns whether this executable has an associated HloModuleConfig. - bool has_module_config() const { return module_config_ != nullptr; } + bool has_module_config() const { + return hlo_module_ != nullptr && hlo_module_->has_config(); + } // Returns the versioned computation handle of the computation computed by // this executable. @@ -117,7 +116,7 @@ class Executable { // The shape (including layout) that results from this execution. This is the // shape of the DeviceMemoryBase result value in ExecuteOnStream above. const Shape& result_shape() const { - return module_config_->entry_computation_layout().result_shape(); + return hlo_module_->config().entry_computation_layout().result_shape(); } // Dumping helpers. @@ -143,10 +142,6 @@ class Executable { // around. std::unique_ptr hlo_module_; - // The configuration used to build this executable (parameter layouts, result - // layout, profiling enabled, etc). - std::unique_ptr module_config_; - // SessionModule this was compiled from. Null if not dumping executions. std::unique_ptr session_module_; diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc new file mode 100644 index 0000000000000000000000000000000000000000..297a4f7599f9c127386b2f53f7ffb987befc456e --- /dev/null +++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" + +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + +// Helper to replace the called computation at a while- or call-instruction. +void ReplaceCalledComputation(HloInstruction* instruction, + HloComputation* computation, + HloComputation* new_computation) { + switch (instruction->opcode()) { + case HloOpcode::kWhile: { + if (computation == instruction->while_condition()) { + instruction->set_while_condition(new_computation); + } else { + CHECK_EQ(computation, instruction->while_body()); + instruction->set_while_body(new_computation); + } + break; + } + case HloOpcode::kCall: { + CHECK_EQ(instruction->to_apply(), computation); + instruction->set_to_apply(new_computation); + break; + } + default: + LOG(FATAL) << "unexpected opcode: " + << HloOpcodeString(instruction->opcode()); + } +} + +// Flatten a single call graph node. Expects to visit nodes in postorder. +Status FlattenNode(const CallGraphNode& node) { + HloComputation* computation = node.computation(); + HloModule* module = computation->parent(); + // Clone callee for all call-sites except the first one. + for (int i = 0; i < node.caller_callsites().size(); ++i) { + CallSite call_site = node.caller_callsites()[i]; + // Only consider sequential call contexts. + if (call_site.context() == CallContext::kParallel) { + continue; + } + CHECK_EQ(call_site.context(), CallContext::kSequential); + + // Skip first element if this computation is only called from a sequential + // context. + if (node.context() != CallContext::kBoth && i == 0) { + continue; + } + + // Clone computation for the remaining sequential context call sites. + HloComputation* clone = + module->AddEmbeddedComputation(computation->Clone()); + ReplaceCalledComputation(call_site.instruction(), computation, clone); + // Clone the sub-tree of all computations called from this node. + std::vector worklist; + worklist.push_back(clone); + while (!worklist.empty()) { + auto current = worklist.back(); + worklist.pop_back(); + for (auto& instruction : current->instructions()) { + if (GetInstructionCallContext(instruction.get()) != + CallContext::kSequential) { + continue; + } + for (auto callee : instruction->called_computations()) { + HloComputation* callee_clone = + module->AddEmbeddedComputation(callee->Clone()); + ReplaceCalledComputation(instruction.get(), callee, callee_clone); + worklist.push_back(callee_clone); + } + } + } + } + return Status::OK(); +} + +} // namespace + +StatusOr FlattenCallGraph::Run(HloModule* module) { + XLA_VLOG_LINES(3, "Before flatten call graph:\n" + module->ToString()); + + std::unique_ptr call_graph = CallGraph::Build(module); + TF_RETURN_IF_ERROR(call_graph->VisitNodes(FlattenNode)); + + XLA_VLOG_LINES(3, "After flatten call graph:\n" + module->ToString()); + return true; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..d3efab3614912e4b0c2c8aa3b80277c326382ed0 --- /dev/null +++ b/tensorflow/compiler/xla/service/flatten_call_graph.h @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Flatten the call graph for an HLO module into a tree. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_FLATTEN_CALL_GRAPH_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_FLATTEN_CALL_GRAPH_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Flattening associates each call site with a unique computation (for +// sequential calling contexts) This simplifies buffer assignment and +// points-to analysis (see b/36865746 for details). +class FlattenCallGraph : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { return "flatten-call-graph"; } + + // Duplicates computations called from multiple call- or while-nodes to + // flatten the call graph. + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_FLATTEN_CALL_GRAPH_H_ diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e03a96fb3f03710cd3062a79aa4955311cf19c1 --- /dev/null +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -0,0 +1,227 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class FlattenCallGraphTest : public HloTestBase { + protected: + // Build and return a trivial computation taking and returning a scalar. + std::unique_ptr MakeScalarComputation() { + HloComputation::Builder builder(TestName() + ".ScalarComputation"); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + builder.AddInstruction( + HloInstruction::CreateUnary(kScalarShape, HloOpcode::kNegate, param0)); + return builder.Build(); + } + + // Build and return a computation which takes a scalar and maps (kMap) the + // given computation to the value 'callsites' number of times. + std::unique_ptr MakeMappingComputation( + HloComputation* map_computation, int64 callsites) { + HloComputation::Builder builder(TestName() + ".MappingComputation"); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + HloInstruction* last_value = param0; + for (int64 i = 0; i < callsites; ++i) { + last_value = builder.AddInstruction(HloInstruction::CreateMap( + kScalarShape, {last_value}, map_computation)); + } + return builder.Build(); + } + + // Build and return a computation which takes a scalar and calls (kCall) the + // given computation with value 'callsites' number of times. + std::unique_ptr MakeCallingComputation( + HloComputation* callee_computation, int64 callsites, + const string& suffix = ".CallingComputation") { + HloComputation::Builder builder(TestName() + suffix); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + HloInstruction* last_value = param0; + for (int64 i = 0; i < callsites; ++i) { + last_value = builder.AddInstruction(HloInstruction::CreateCall( + kScalarShape, {last_value}, callee_computation)); + } + return builder.Build(); + } + + // Build and return a computation which takes a scalar and returns a PRED + // value. + std::unique_ptr MakeConditionComputation() { + HloComputation::Builder builder(TestName() + ".ConditionComputation"); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); + return builder.Build(); + } + + StatusOr RunFlattenCallGraph(HloModule* module) { + FlattenCallGraph flatten; + TF_ASSIGN_OR_RETURN(bool result, flatten.Run(module)); + return result; + } + + const Shape kScalarShape = ShapeUtil::MakeShape(F32, {}); +}; + +TEST_F(FlattenCallGraphTest, ComplexGraph) { + // Test a call graph of a module with several computation called in various + // contexts. The call graph looks like: + // + // entry + // / | + // a | + // / | \ | + // b | cond + // \ | + // c + // + // Calls are made via kCall, kWhile, and kMap instructions. + HloModule module(TestName()); + HloComputation* cond_computation = + module.AddEmbeddedComputation(MakeConditionComputation()); + HloComputation* c_computation = + module.AddEmbeddedComputation(MakeScalarComputation()); + HloComputation* b_computation = module.AddEmbeddedComputation( + MakeMappingComputation(c_computation, /*callsites=*/1)); + + HloComputation* a_computation; + { + HloComputation::Builder builder(TestName() + ".a"); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + HloInstruction* call = builder.AddInstruction( + HloInstruction::CreateCall(kScalarShape, {param0}, c_computation)); + builder.AddInstruction(HloInstruction::CreateWhile( + kScalarShape, cond_computation, b_computation, call)); + a_computation = module.AddEmbeddedComputation(builder.Build()); + } + + HloComputation* entry_computation; + { + HloComputation::Builder builder(TestName() + ".entry"); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + builder.AddInstruction(HloInstruction::CreateWhile( + kScalarShape, cond_computation, a_computation, param0)); + entry_computation = module.AddEntryComputation(builder.Build()); + } + + { + TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(&module)); + EXPECT_TRUE(result); + std::unique_ptr flat_call_graph = CallGraph::Build(&module); + const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); + EXPECT_EQ(1, c_node.caller_callsites().size()); + } +} + +// Test corner case of a computation used as a body and a loop condition. +TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { + HloModule module(TestName()); + HloComputation* cond_computation; + { + HloComputation::Builder builder(TestName() + ".cond"); + HloInstruction* param0 = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(PRED, {}), "param0")); + HloInstruction* false_constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kEq, param0, false_constant)); + cond_computation = module.AddEmbeddedComputation(builder.Build()); + } + + HloComputation* entry_computation; + { + HloComputation::Builder builder(TestName() + ".entry"); + HloInstruction* false_constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateWhile( + ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation, + false_constant)); + entry_computation = module.AddEntryComputation(builder.Build()); + } + + { + std::unique_ptr call_graph = CallGraph::Build(&module); + const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); + EXPECT_EQ(2, cond_node.caller_callsites().size()); + } + + { + TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(&module)); + EXPECT_TRUE(result); + std::unique_ptr call_graph = CallGraph::Build(&module); + const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); + EXPECT_EQ(1, cond_node.caller_callsites().size()); + } +} + +// Test flattening of a nested calling computations. +// +// Entry +// / \ +// \ / +// B +// / \ +// \ / +// C +// +TEST_F(FlattenCallGraphTest, FlattenCalls) { + HloModule module(TestName()); + HloComputation* c_computation = + module.AddEmbeddedComputation(MakeScalarComputation()); + + HloComputation* b_computation = module.AddEmbeddedComputation( + MakeCallingComputation(c_computation, /*callsites=*/2, ".B")); + + module.AddEntryComputation( + MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry")); + + TF_ASSIGN_OR_ASSERT_OK(bool result, RunFlattenCallGraph(&module)); + EXPECT_TRUE(result); + std::unique_ptr call_graph = CallGraph::Build(&module); + EXPECT_EQ(7, module.computations().size()); + + const CallGraphNode& c_node = call_graph->GetNode(c_computation); + EXPECT_EQ(1, c_node.caller_callsites().size()); + + const CallGraphNode& b_node = call_graph->GetNode(b_computation); + EXPECT_EQ(1, b_node.caller_callsites().size()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 9de6d65a27bfcb6747d59eac75f8b13debba0ebd..d26f415fd4bdfec597c70b760942cc406a0d6cfa 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -264,6 +264,8 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/platform/default/build_config:cublas_plugin", + "//tensorflow/core/platform/default/build_config:cudnn_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", ], ) @@ -425,6 +427,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", @@ -529,14 +532,10 @@ cc_test( deps = [ ":instruction_fusion", ":while_transformer", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:copy_insertion", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/core:lib", - "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index f6b7fe1e8ef10e4e66018d887707e587ecfa3465..94acf5a35945a33048038bfae67d46c38a07ef8d 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -125,7 +125,7 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( CHECK_LE(num_dimensions, 3); // cuDNN does not support 1D convolutions. We therefore express 1D // convolutions as 2D convolutions where the first spatial dimension is 1. - // This matches the behaviour of TF (see definition of conv1d in + // This matches the behavior of TF (see definition of conv1d in // tensorflow/python/ops/nn_ops.py). const int effective_num_dimensions = std::max(2, num_dimensions); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 1667ab36792c91cbbf3c6396a673bedff2208045..e57eb0bdee64948290d5eaf15965afcdc8bea0ad 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -113,7 +113,7 @@ StatusOr GpuElementalIrEmitter::EmitMathCall( tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice input_types, PrimitiveType output_type) const { - // Binary math functions tranform are of type [T] -> T. + // Binary math functions transform are of type [T] -> T. for (PrimitiveType input_type : input_types) { if (output_type != input_type) { return Unimplemented("Input type ≠ output type: %s ≠ %s", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index f692f28bd9858ab809732389fcc2908b8fa66a42..b616d958b96c41e9b9021bf375d51a32ef73ceb9 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" #include "tensorflow/compiler/xla/service/gpu/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" @@ -133,8 +134,13 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, pass.AddPass(); } pipeline.AddPass(); - pipeline.AddPass(ImplementedAsGemm); - pipeline.AddPass(); + pipeline.AddPass( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return ImplementedAsGemm(dot) ? candidate_operands + : TransposeFolding::OperandIndices{}; + }, + TransposeFolding::NeverFoldTranspose); pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); @@ -172,16 +178,20 @@ tensorflow::Status PrepareHloModuleForIrEmitting( // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which // materializes the value) or missing a necessary copy (later pass removes an - // instruction which materializes a value). + // instruction which materializes a value). DCE must be run immediately before + // (and sometime after) copy insertion, to avoid dead code from interfering + // with the rewrites. + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); return pipeline.Run(hlo_module).status(); } // Invokes the ptxas tool on the given PTX string, and dumps its output. void DumpPtxasInfo(const string& ptx) { - legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags(); - const string ptxas_path = flags->xla_ptxas_path; + const string ptxas_path = + tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin/ptxas"); // Do not log PTX stats if ptxas is not found at the given path. if (!tensorflow::Env::Default()->FileExists(ptxas_path).ok()) { LOG(WARNING) @@ -222,15 +232,14 @@ GpuCompiler::GpuCompiler() pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {} StatusOr> GpuCompiler::Compile( - std::unique_ptr hlo_module, - std::unique_ptr module_config, HloDumper dump_hlo, + std::unique_ptr module, HloDumper dump_hlo, se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); - TF_RETURN_IF_ERROR(OptimizeHloModule(hlo_module.get(), dump_hlo, + TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), dump_hlo, stream_exec->GetDeviceDescription())); - TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(dump_hlo, hlo_module.get(), - module_config.get())); + TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(dump_hlo, module.get(), + module->mutable_config())); llvm::LLVMContext llvm_context; std::string buffer; @@ -243,7 +252,7 @@ StatusOr> GpuCompiler::Compile( }; llvm_context.setDiagnosticHandler(DiagnosticHandler, &printer); - llvm::Module llvm_module(hlo_module->name().c_str(), llvm_context); + llvm::Module llvm_module(module->name().c_str(), llvm_context); // Set the target triple and the data layout. llvm_module.setTargetTriple(kTargetTriple); llvm_module.setDataLayout(kDataLayout); @@ -251,29 +260,28 @@ StatusOr> GpuCompiler::Compile( // Determine the HLO schedule, which is an ordering of HLO instructions. This // is used by buffer assignment to enable buffer reuse, and the same ordering // must also be used to determine the thunk launch schedule. - std::unique_ptr stream_assignment = - AssignStreams(*hlo_module); + std::unique_ptr stream_assignment = AssignStreams(*module); TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_schedule, - HloSchedule::Build(*hlo_module, *stream_assignment, pointer_size_)); + HloSchedule::Build(*module, *stream_assignment, pointer_size_)); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr buffer_assignment, - BufferAssigner::Run(hlo_module.get(), hlo_schedule->ConsumeHloOrdering(), + BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(), [this](const LogicalBuffer& buffer) { return ShapeSizeBytes(buffer.shape()); }, kMemoryAlignment)); - IrEmitterContext ir_emitter_context(hlo_module.get(), buffer_assignment.get(), + IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(), &stream_exec->GetDeviceDescription(), &llvm_module); - HloComputation* entry_computation = hlo_module->entry_computation(); - IrEmitterUnnested ir_emitter(*module_config, entry_computation, - module_config->has_hybrid_result(), + HloComputation* entry_computation = module->entry_computation(); + IrEmitterUnnested ir_emitter(module->config(), entry_computation, + module->config().has_hybrid_result(), &ir_emitter_context); TF_RETURN_IF_ERROR( entry_computation->root_instruction()->Accept(&ir_emitter)); @@ -302,7 +310,7 @@ StatusOr> GpuCompiler::Compile( cc_minor = 0; } TF_ASSIGN_OR_RETURN(*ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, - *module_config, libdevice_dir_)); + module->config(), libdevice_dir_)); VLOG(2) << "LLVM module after optimizations:"; XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(llvm_module)); @@ -319,8 +327,8 @@ StatusOr> GpuCompiler::Compile( XLA_VLOG_LINES(2, thunk_schedule->ToString()); auto* gpu_executable = - new GpuExecutable(*ptx, std::move(thunk_schedule), std::move(hlo_module), - std::move(module_config), std::move(buffer_assignment)); + new GpuExecutable(*ptx, std::move(thunk_schedule), std::move(module), + std::move(buffer_assignment)); if (flags->xla_gpu_embed_ir) { DCHECK_NE("", ir_module_string_before_opt); gpu_executable->set_ir_module_string(ir_module_string_before_opt); @@ -329,9 +337,8 @@ StatusOr> GpuCompiler::Compile( } StatusOr>> GpuCompiler::Compile( - std::vector> hlo_modules, - std::vector> module_configs, - HloDumper dump_hlos, std::vector stream_execs) { + std::vector> modules, HloDumper dump_hlos, + std::vector stream_execs) { return Unimplemented( "Compilation of multiple HLO modules is not yet supported on GPU."); } @@ -339,7 +346,6 @@ StatusOr>> GpuCompiler::Compile( StatusOr>> GpuCompiler::CompileAheadOfTime( std::vector> module, - std::vector> module_config, HloDumper dump_hlo, const AotCompilationOptions& options) { return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime"); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index 22f492b42294838bf323b70f492d83fa9c7b4ce2..921d683f03066a57bbadeacb6e33c91cadb3c095 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -42,20 +42,16 @@ class GpuCompiler : public Compiler { ~GpuCompiler() override {} StatusOr> Compile( - std::unique_ptr hlo_module, - std::unique_ptr module_config, HloDumper dump_hlo, + std::unique_ptr module, HloDumper dump_hlo, perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> Compile( - std::vector> hlo_module, - std::vector> module_config, - HloDumper dump_hlo, + std::vector> modules, HloDumper dump_hlo, std::vector stream_exec) override; StatusOr>> CompileAheadOfTime( std::vector> module, - std::vector> module_config, HloDumper dump_hlo, AotCompilationOptions const& options) override; perftools::gputools::Platform::Id PlatformId() const override; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 32f0368b4bc523d3d81147a8cbbde745387c21d4..69bcd53e05d5de013be2af1cfdba934cea34af6b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -110,9 +110,8 @@ class HloExecutionProfiler { GpuExecutable::GpuExecutable(tensorflow::StringPiece ptx, std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, - std::unique_ptr module_config, std::unique_ptr assignment) - : Executable(std::move(hlo_module), std::move(module_config)), + : Executable(std::move(hlo_module)), ptx_(ptx), thunk_schedule_(std::move(thunk_schedule)), assignment_(std::move(assignment)) {} diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index e308de79ba582d3497e7f217285ae4b1ed0be1a7..ad178b7249e4a265ca88a45985142b08b1023417 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" @@ -51,7 +50,6 @@ class GpuExecutable : public Executable { GpuExecutable(tensorflow::StringPiece ptx, std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, - std::unique_ptr module_config, std::unique_ptr assignment); // This should be called after set_ir_module_string. diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 34a44ad40548272a0c2a87efadfa1ab2aca7b979..a36dcbbd2faf3258ec2790f51bb2aec3ce834a6c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -46,6 +46,11 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); + // Output fusion is not currently supported on GPUs. + if (producer->opcode() == HloOpcode::kFusion) { + return false; + } + // RNG operations are not currently parallel-friendly on GPU. if (producer->opcode() == HloOpcode::kRng) { return false; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index e8378a7f447cebf8d491e98595188d2391333c58..c6e8a2f78b5a398d9e9d5a684ac4d42520ec20c8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -59,6 +59,11 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { + // We can only do this if the HLO is unnested. + if (hlo.parent() != hlo.GetModule()->entry_computation()) { + return false; + } + // For certain types of Dot, we can call pre-canned BLAS gemm. if (hlo.opcode() == HloOpcode::kDot) { const Shape& lhs_shape = hlo.operand(0)->shape(); @@ -85,6 +90,11 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { } bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { + // We can only do this if the HLO is unnested. + if (hlo.parent() != hlo.GetModule()->entry_computation()) { + return false; + } + // Forward convolution. if (hlo.opcode() == HloOpcode::kConvolution) { const ConvolutionDimensionNumbers& dnums = diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 4d3e9b10b2e69b083d74cf7b56edc5b781991b55..e8c68a6ef72ede8f2f3dd2279a8e43468ce8f35d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -25,16 +25,7 @@ limitations under the License. namespace xla { namespace gpu { -const int64 kWarpSize = 32; - -// Precondition: "hlo" is an operand of a Dot instruction. -// -// Returns whether "hlo" is foldable to its user. -bool IsOperandFoldableToDot(const HloInstruction& hlo); - -// Returns true if GpuCompiler can fold any operands of "dot" into "dot" for -// better performance. -bool CanFoldOperandsIntoDot(const HloInstruction& dot); +constexpr int64 kWarpSize = 32; // Returns true if `hlo` will be implemented as a call to BLAS gemm. bool ImplementedAsGemm(const HloInstruction& hlo); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 5f3ce85f857a96ca0cca6b0bea4bf1e86b971827..36619a845413b19ec2d559252409dae1b96b76e4 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -399,7 +399,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot, llvm::Type* accum_type = target_array.GetElementLlvmType(); llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry( accum_type, // The pointee type of the alloca instruction. - "accum_address", // The name of the alloca instuction. + "accum_address", // The name of the alloca instruction. &ir_builder_); // Initialize the accumulator in the preheader to zero. @@ -549,14 +549,12 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator()); } -Status IrEmitter::HandleCall( - HloInstruction* call, tensorflow::gtl::ArraySlice operands, - HloComputation* computation) { +Status IrEmitter::HandleCall(HloInstruction* call) { std::vector operand_addresses; - for (HloInstruction* operand : operands) { + for (HloInstruction* operand : call->operands()) { operand_addresses.push_back(GetBasePointer(*operand)); } - return EmitCallToNestedComputation(*computation, operand_addresses, + return EmitCallToNestedComputation(*call->to_apply(), operand_addresses, GetBasePointer(*call)); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 1aefee2739978ec05f4094f79acaece39e221bea..513bead62d8db38e550bc550fe2212b6e5dc4baf 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -101,9 +101,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* on_true, HloInstruction* on_false) override; Status HandleFusion(HloInstruction* fusion) override; - Status HandleCall(HloInstruction* call, - tensorflow::gtl::ArraySlice operands, - HloComputation* computation) override; + Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target) override; @@ -249,8 +247,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleTuple( HloInstruction* tuple, tensorflow::gtl::ArraySlice operands) override; - Status HandleWhile(HloInstruction* xla_while, HloInstruction* init, - HloComputation* condition, HloComputation* body) override; + Status HandleWhile(HloInstruction* xla_while) override; Status HandleRng(HloInstruction* random, RandomDistribution distribution) override; Status HandleSelect(HloInstruction* select, HloInstruction* pred, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 9b7aa7c860b14e03c238bd7037f0df832eacfef3..e52e55a1a8199019e2c149a777a4e948f830ce0e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -196,7 +196,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( ir_emitter_context_->buffer_assignment().GetTempAllocation()) { kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, allocation->size()); } - kernel->setDoesNotAlias(temp_buffer_arg_no + 1); + kernel->addAttribute(temp_buffer_arg_no + 1, llvm::Attribute::NoAlias); // Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX // treats it as a CUDA kernel. @@ -1540,10 +1540,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( .EmitLoop(); } -Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while, - HloInstruction* init, - HloComputation* condition, - HloComputation* body) { +Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { + HloComputation* condition = xla_while->while_condition(); TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) && condition->root_instruction()->shape().element_type() == PRED) << "While condition computation must return bool"; diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 485216837dc727bfe8565ff22678dd2fa470bc40..4f34cb77b0390e21350ad146695dd5be67fdabbf 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -396,7 +396,7 @@ StatusOr CompileModuleToPtx(llvm::Module* module, // The LLVM IR verifier performs sanity checking on the IR. This helps // discover problems and report them in a meaningful manner, rather than let - // later passes report obscure assertions becasue of unfulfilled invariants. + // later passes report obscure assertions because of unfulfilled invariants. module_passes.add(llvm::createVerifierPass()); // Create the function-level pass manager. It needs data layout information @@ -405,9 +405,9 @@ StatusOr CompileModuleToPtx(llvm::Module* module, AddOptimizationPasses(flags->opt_level, /*size_level=*/0, target_machine.get(), &module_passes, &function_passes); - // Loop unrolling exposes more opportunites for SROA. Therefore, we run SROA + // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA // again after the standard optimization passes [http://b/13329423]. - // TODO(jingyue): SROA may further expose more optimization opportunites, such + // TODO(jingyue): SROA may further expose more optimization opportunities, such // as more precise alias analysis and more function inlining (SROA may change // the inlining cost of a function). For now, running SROA already emits good // enough code for the evaluated benchmarks. We may want to run more diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc index c10346bbc235d8949525eb2008bac5312395381d..72f6cfd2d60712bb74af3dca2041ed1413004d23 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc @@ -28,7 +28,8 @@ limitations under the License. namespace { static void DieWithSMDiagnosticError(llvm::SMDiagnostic* diagnostic) { - LOG(FATAL) << diagnostic->getLineNo() << ":" << diagnostic->getColumnNo() + LOG(FATAL) << diagnostic->getFilename().str() << ":" + << diagnostic->getLineNo() << ":" << diagnostic->getColumnNo() << ": " << diagnostic->getMessage().str(); } diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 8ac4c5996632587fe4518df5560a1a74d9e8caa6..8f7fce884acc93fd39510ad0826b819a6d9731a7 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -33,7 +33,7 @@ namespace gpu { enum class PartitionStrategy { // Optimized for latency by allowing maximum number of registers per thread. kLatency, - // Optimized for throughtput. This may limit registers per thread and cause + // Optimized for throughput. This may limit registers per thread and cause // longer latency. kThroughput }; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index ec75e1358142764d80152a6d8abbc6d5b72acb9a..61a9e7e9e1bfca3b73e427ef6bbb956aee51c2e7 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -37,7 +37,7 @@ namespace { // patterns to match. // // Each ExprTree node is comprised of an HloOpcode, and a set of operands (each -// of type ExprTree). Operands can be added by specifing the index and HloOpcode +// of type ExprTree). Operands can be added by specifying the index and HloOpcode // of the operand. // // For example, the following computation: @@ -122,10 +122,12 @@ class ExprTree { Status Match(const HloInstruction* instruction, TaggedInstructionMap* tagged_instructions) const { if (opcode_ != instruction->opcode()) { - return InvalidArgument("Unexpected opcode: %s", - HloOpcodeString(instruction->opcode()).c_str()); + return InvalidArgument("got opcode %s, want %s", + HloOpcodeString(instruction->opcode()).c_str(), + HloOpcodeString(opcode_).c_str()); } + VLOG(2) << "Matched " << HloOpcodeString(opcode_) << ": " << tag_; if (!tag_.empty()) { tagged_instructions->insert({tag_, instruction}); } @@ -166,7 +168,7 @@ class MatcherBase { virtual ~MatcherBase() {} // Attempts to match each ExprTree in 'expr_trees_'. - // Returns OK on the first succesful match, error status otherwise. + // Returns OK on the first successful match, error status otherwise. virtual tensorflow::Status Run() { Status status; for (const ExprTree& expr_tree : expr_trees_) { @@ -275,6 +277,7 @@ class WhileConditionComputationMatcher : public MatcherBase { } Status MatchExprTree(const ExprTree& expr_tree) override { + VLOG(2) << "MATCHING while condition"; ExprTree::TaggedInstructionMap tagged_instructions; TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), &tagged_instructions)); @@ -344,10 +347,6 @@ class WhileInitOperandMatcher : public MatcherBase { // // Const // | - // Tuple1 - // | - // GTE0 - // | // Copy // | // Tuple0 @@ -355,15 +354,15 @@ class WhileInitOperandMatcher : public MatcherBase { // While // ExprTree BuildInitExprTree() { - ExprTree gte0(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kConstant, "loop_start"))); - return ExprTree(HloOpcode::kWhile, "while", - ExprTree(HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, gte0))); + return ExprTree( + HloOpcode::kWhile, "while", + ExprTree(HloOpcode::kTuple, tuple_index_, + ExprTree(HloOpcode::kCopy, + ExprTree(HloOpcode::kConstant, "loop_start")))); } Status MatchExprTree(const ExprTree& expr_tree) override { + VLOG(2) << "MATCHING while init"; ExprTree::TaggedInstructionMap tagged_instructions; TF_RETURN_IF_ERROR(expr_tree.Match(while_hlo_, &tagged_instructions)); @@ -375,14 +374,6 @@ class WhileInitOperandMatcher : public MatcherBase { while_hlo->name().c_str()); } - // Get tagged GTE instruction and check 'tuple_index_'. - TF_ASSIGN_OR_RETURN(const HloInstruction* gte, - GetTaggedInstruction("gte", tagged_instructions)); - if (gte->tuple_index() != tuple_index_) { - return InvalidArgument("Unexpected tuple index instruction : %s", - gte->name().c_str()); - } - // Get tagged Constant instruction and parse 'loop_start_'. TF_ASSIGN_OR_RETURN( const HloInstruction* const_hlo, @@ -427,10 +418,6 @@ class WhileBodyComputationMatcher : public MatcherBase { // \ / \ / // Fusion -----------> Add // | - // Tuple1 - // | - // GTE0 - // | // Copy // | // Tuple0 @@ -450,15 +437,13 @@ class WhileBodyComputationMatcher : public MatcherBase { fusion.SetFusedRoot(fused_root); // Build top-level computation. - ExprTree tuple0( - HloOpcode::kTuple, tuple_index_, - ExprTree(HloOpcode::kCopy, - ExprTree(HloOpcode::kGetTupleElement, "gte", - ExprTree(HloOpcode::kTuple, tuple_index_, fusion)))); + ExprTree tuple0(HloOpcode::kTuple, tuple_index_, + ExprTree(HloOpcode::kCopy, fusion)); return tuple0; } Status MatchExprTree(const ExprTree& expr_tree) override { + VLOG(2) << "MATCHING while body"; ExprTree::TaggedInstructionMap tagged_instructions; TF_RETURN_IF_ERROR(expr_tree.Match(computation_->root_instruction(), &tagged_instructions)); diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index ddf9676e378c5445418d30ae767d19ef2fb74be8..a315b9ad11a4a15d4c4d624320283d4467e9bf41 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -17,12 +17,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { namespace { +using ::testing::Eq; +using ::testing::HasSubstr; + class WhileTransformerTest : public HloTestBase { protected: WhileTransformerTest() @@ -135,12 +139,10 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - EXPECT_TRUE(result.ok()); + ASSERT_TRUE(result.ok()); // Check results. - auto tuple = result.ConsumeValueOrDie(); - EXPECT_EQ(0, std::get<0>(tuple)); - EXPECT_EQ(10, std::get<1>(tuple)); - EXPECT_EQ(1, std::get<2>(tuple)); + EXPECT_THAT(result.ConsumeValueOrDie(), + Eq(std::tuple(0, 10, 1))); } TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { @@ -154,12 +156,10 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - EXPECT_TRUE(result.ok()); + ASSERT_TRUE(result.ok()); // Check results. - auto tuple = result.ConsumeValueOrDie(); - EXPECT_EQ(0, std::get<0>(tuple)); - EXPECT_EQ(10, std::get<1>(tuple)); - EXPECT_EQ(1, std::get<2>(tuple)); + EXPECT_THAT(result.ConsumeValueOrDie(), + Eq(std::tuple(0, 10, 1))); } TEST_F(WhileTransformerTest, InvalidLoopLimit) { @@ -173,10 +173,9 @@ TEST_F(WhileTransformerTest, InvalidLoopLimit) { RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - EXPECT_FALSE(result.ok()); - EXPECT_MATCH( - result.status().error_message(), - testing::ContainsRegex("Loop start must be less than loop limit.")); + ASSERT_FALSE(result.ok()); + EXPECT_THAT(result.status().error_message(), + HasSubstr("Loop start must be less than loop limit.")); } TEST_F(WhileTransformerTest, InvalidLoopIncrement) { @@ -190,10 +189,9 @@ TEST_F(WhileTransformerTest, InvalidLoopIncrement) { RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - EXPECT_FALSE(result.ok()); - EXPECT_MATCH( - result.status().error_message(), - testing::ContainsRegex("Loop increment must greater than zero.")); + ASSERT_FALSE(result.ok()); + EXPECT_THAT(result.status().error_message(), + HasSubstr("Loop increment must greater than zero.")); } } // namespace diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 46c0d8edead1eaba518fd1040b7dd7d0d6c79159..645c68e0438f875e9c4c560b875a18a71618e61c 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -53,12 +53,44 @@ std::vector UniqueOperandSourceBuffers( /*static*/ StatusOr HeapSimulator::Run( - std::unique_ptr algorithm, + std::unique_ptr algorithm, const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_fn, + const FlatSet* buffers_to_assign) { + HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign); + const HloComputation* entry_computation = module.entry_computation(); + const std::vector& instruction_sequence = + FindOrDie(module_sequence, entry_computation); + TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation, + instruction_sequence, + points_to_analysis, &module_sequence)); + return heap.Finish(); +} + +/*static*/ +StatusOr HeapSimulator::Run( + std::unique_ptr algorithm, const HloComputation& computation, const std::vector& instruction_sequence, - const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_fn, const FlatSet* buffers_to_assign) { + HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign); + TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, + points_to_analysis, + /*module_sequence=*/nullptr)); + return heap.Finish(); +} + +// Runs a heap simulation for the given 'computation', assuming the given +// 'instruction_sequence'. If 'module_sequence' is non-null, it is used to find +// kCall and kWhile sub-computations, and the heap simulation for those +// sub-computations will be run recursively. +Status HeapSimulator::RunComputation( + const HloComputation& computation, + const std::vector& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const SequentialHloOrdering::HloModuleSequence* module_sequence) { // The goal here is to minimize memory usage, assuming the given sequential // ordering of instructions. The strategy is to walk through the instruction // sequence, calling Alloc and Free on the underlying heap algorithm. The @@ -67,25 +99,29 @@ StatusOr HeapSimulator::Run( // 'live_buffers' tracks the liveness of each buffer that we assign, by // associating it with a set of HloInstructions that need to be visited. When // the set becomes empty, the buffer is no longer used, and can be freed. - HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign); FlatMap> live_buffers; + const HloInstruction* root = computation.root_instruction(); + FlatSet output_source_buffers = + points_to_analysis.GetPointsToSet(root).CreateFlattenedSet(); + for (const HloInstruction* instruction : instruction_sequence) { const std::vector& buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); - const HloInstruction* root = computation.root_instruction(); - FlatSet output_source_buffers = - points_to_analysis.GetPointsToSet(root).CreateFlattenedSet(); - // Initialize live_buffers for each buffer that we're going to assign. The // set of instructions that need to be visited contains all users of all // aliases. The alias itself is not necessary; if it has users, the users // are necessarily scheduled after the alias. And if it has no users, it is // either a dead value or an output, both of which are handled below. + // + // We ignore control dependencies here. The reasoning is that the control + // dependencies have already been accounted for in the ordering of the given + // 'instruction_sequence', and should not otherwise artificially extend the + // lifetime of buffers that aren't already connected by a data dependency. std::vector dead_buffers_to_free; for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { - if (heap.IgnoreBuffer(buffer)) { + if (IgnoreBuffer(buffer)) { continue; } for (const BufferAlias& alias : @@ -122,7 +158,7 @@ StatusOr HeapSimulator::Run( std::vector operand_buffers_to_free; for (const LogicalBuffer* operand_buffer : UniqueOperandSourceBuffers(instruction, points_to_analysis)) { - if (heap.IgnoreBuffer(operand_buffer)) { + if (IgnoreBuffer(operand_buffer)) { continue; } live_buffers[operand_buffer].erase(instruction); @@ -137,10 +173,10 @@ StatusOr HeapSimulator::Run( // happen before dead or operand buffers are freed; the instruction reads // the operand buffers to produce its output. // - // INVARIANT: Either heap.Alloc or heap.ShareBuffer will be called for each - // buffer that we should assign. + // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer + // that we should assign. for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { - if (heap.IgnoreBuffer(buffer)) { + if (IgnoreBuffer(buffer)) { continue; } @@ -151,27 +187,54 @@ StatusOr HeapSimulator::Run( bool shared = false; for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) { if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && + buffer->instruction()->opcode() != HloOpcode::kCopy && CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), buffer->instruction(), buffer->index(), points_to_analysis)) { - heap.ShareBuffer(buffer, operand_buffer); + ShareBuffer(buffer, operand_buffer); shared = true; break; } } if (!shared) { - heap.Alloc(buffer); + Alloc(buffer); } } + // If the whole module is sequential, we can save memory by running the + // heap-simulation for sub-computations inline. E.g. the buffers for the + // condition and body of a kWhile instruction are only live for the duration + // of the instruction itself. + // + // The order that the sub-computations are simulated does not affect + // correctness; since the whole module is sequential, we know that the + // sub-computations will never be run concurrently. + if (module_sequence != nullptr) { + if (instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kWhile) { + for (const HloComputation* called_computation : + instruction->called_computations()) { + const std::vector& called_sequence = + FindOrDie(*module_sequence, called_computation); + TF_RETURN_IF_ERROR(RunComputation(*called_computation, + called_sequence, points_to_analysis, + module_sequence)); + } + } + + // Other sub-computations (e.g. Map, Reduce, ...) are skipped; they are + // assigned "thread-local" allocations, meaning their buffers are not + // allocated up-front at the beginning of the computation. + } + // Free buffers that are no longer live. This is the earliest point that we // can de-allocate; right after the last use of the buffer. for (const LogicalBuffer* buffer : dead_buffers_to_free) { - heap.Free(buffer); + Free(buffer); } for (const LogicalBuffer* buffer : operand_buffers_to_free) { - heap.Free(buffer); + Free(buffer); } } @@ -182,10 +245,10 @@ StatusOr HeapSimulator::Run( const FlatSet& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; - heap.Free(buffer); + Free(buffer); } - return heap.Finish(); + return Status::OK(); } HeapSimulator::HeapSimulator( @@ -304,6 +367,11 @@ HeapSimulator::Result HeapSimulator::Finish() { result.chunk_map.emplace(buffer, chunk); } } + // If we were told to assign specific buffers, make sure we've assigned + // exactly that many buffers. + if (buffers_to_assign_ != nullptr) { + CHECK_EQ(buffers_to_assign_->size(), result.chunk_map.size()); + } } // Fragmentation is the difference between the actual and ideal sizes. diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 0ce2906767898bcace45e296d76f958c50a2b3a7..3d98046261902b41a17a8ab0f9a349634a1e4545 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" @@ -63,17 +64,32 @@ class HeapSimulator { }; // Run the heap simulation with the given algorithm, assuming the given - // sequential ordering of instructions. The 'instruction_sequence' must - // contain a topologically-consistent total ordering of all instructions in - // the computation. The result is invalid if instructions are not run in - // exactly this sequence. + // module_sequence, which must contain a topologically-consistent total + // ordering of all instructions within each computation. The result is invalid + // if instructions are not run in exactly this sequence. + // + // Running heap simulation on the whole module tends to save memory, compared + // to running on a per-computation basis, since we can re-use buffer space for + // called sub-computations. // // If 'buffers_to_assign' is provided, only those buffers are assigned // offsets, otherwise all buffers defined by the instructions are assigned. + static StatusOr Run( + std::unique_ptr algorithm, const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_fn, + const tensorflow::gtl::FlatSet* buffers_to_assign = + nullptr); + + // Same as above, but runs on a single computation. The 'instruction_sequence' + // must contain a topologically-consistent total ordering of all instructions + // in the computation. The result is invalid if instructions are not run in + // exactly this sequence. static StatusOr Run( std::unique_ptr algorithm, - const std::vector& instruction_sequence, const HloComputation& computation, + const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_fn, const tensorflow::gtl::FlatSet* buffers_to_assign = @@ -86,6 +102,12 @@ class HeapSimulator { const tensorflow::gtl::FlatSet* buffers_to_assign); ~HeapSimulator(); + Status RunComputation( + const HloComputation& computation, + const std::vector& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const SequentialHloOrdering::HloModuleSequence* module_sequence); + bool IgnoreBuffer(const LogicalBuffer* buffer) const; void Alloc(const LogicalBuffer* buffer); void Free(const LogicalBuffer* buffer); diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 874bd5f1060c179d5547510c351909069aa935b8..0a6900f73304f7a7b1209807fd3a1e8220484e03 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,13 +19,16 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { @@ -69,6 +72,7 @@ class HeapCallRecorder : public HeapAlgorithm { // sequence against an expected sequence. class HeapSimulatorTracker { public: + // Constructor for testing a single entry computation. HeapSimulatorTracker( const string& name, std::unique_ptr computation, const std::vector& instruction_sequence) { @@ -83,12 +87,48 @@ class HeapSimulatorTracker { auto zero_size = [](const LogicalBuffer& buffer) { return 0; }; auto algorithm = MakeUnique( MakeUnique(&actual_calls_)); - result_ = HeapSimulator::Run(std::move(algorithm), instruction_sequence, - *module_->entry_computation(), - *points_to_analysis_, zero_size) + result_ = HeapSimulator::Run( + std::move(algorithm), *module_->entry_computation(), + instruction_sequence, *points_to_analysis_, zero_size) .ConsumeValueOrDie(); } + explicit HeapSimulatorTracker(const string& name) { + module_ = MakeUnique(name); + } + + // Similar to the single entry computation constructor above, but runs the + // simulation over the entire module. + void RunWholeModule( + const std::vector& full_module_sequence) { + points_to_analysis_ = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + + // Construct the module sequence grouped by computation. + SequentialHloOrdering::HloModuleSequence module_sequence; + tensorflow::gtl::FlatMap reverse_position; + for (int i = 0; i < full_module_sequence.size(); ++i) { + const HloInstruction* instruction = full_module_sequence[i]; + module_sequence[instruction->parent()].push_back(instruction); + reverse_position[instruction] = full_module_sequence.size() - i; + } + + // Hack the size_fn so that it returns a decreasing value as we step through + // the sequence. This lets us ensure the Alloc calls are in the sequence + // order. The Free calls are sorted by LogicalBuffer.id, which is at least + // deterministic. + auto size_fn = [&reverse_position](const LogicalBuffer& buffer) { + return reverse_position[buffer.instruction()]; + }; + auto algorithm = MakeUnique( + MakeUnique(&actual_calls_)); + result_ = HeapSimulator::Run(std::move(algorithm), *module_, + module_sequence, *points_to_analysis_, size_fn) + .ConsumeValueOrDie(); + } + + HloModule* module() { return module_.get(); } + // Returns the buffer defined at the given instruction and index. const LogicalBuffer* BufferAt(const HloInstruction* instruction, const ShapeIndex& index) const { @@ -358,6 +398,86 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { }); } +TEST_F(HeapSimulatorTest, WholeModule) { + HeapSimulatorTracker tracker(TestName()); + + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = + tracker.module()->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + tracker.module()->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, param)); + tracker.module()->AddEntryComputation(builder.Build()); + + tracker.RunWholeModule( + {param, while_op, body_param, cond_param, cond_iter, cond_data, cond_lt}); + tracker.ExpectCallSequence({ + // The entry computation param and while_op are allocated first. + {kAlloc, tracker.BufferAt(param, {})}, + {kAlloc, tracker.BufferAt(param, {0})}, + {kAlloc, tracker.BufferAt(param, {1})}, + {kAlloc, tracker.BufferAt(while_op, {})}, + {kAlloc, tracker.BufferAt(while_op, {0})}, + {kAlloc, tracker.BufferAt(while_op, {1})}, + + // Now the while body param is allocated and freed. + {kAlloc, tracker.BufferAt(body_param, {})}, + {kAlloc, tracker.BufferAt(body_param, {0})}, + {kAlloc, tracker.BufferAt(body_param, {1})}, + {kFree, tracker.BufferAt(body_param, {})}, + {kFree, tracker.BufferAt(body_param, {0})}, + {kFree, tracker.BufferAt(body_param, {1})}, + + // Now the while cond param is allocated. The GTE instructions just alias + // the param elements, so the param tuple can immediately be freed. + {kAlloc, tracker.BufferAt(cond_param, {})}, + {kAlloc, tracker.BufferAt(cond_param, {0})}, + {kAlloc, tracker.BufferAt(cond_param, {1})}, + {kFree, tracker.BufferAt(cond_param, {})}, + + // Now the final cond less-than buffer is allocated. + {kAlloc, tracker.BufferAt(cond_lt, {})}, + + // The order of the remaining Free calls is based on the LogicalBuffer.id, + // which is deterministic, but not obvious. + {kFree, tracker.BufferAt(param, {})}, + {kFree, tracker.BufferAt(param, {0})}, + {kFree, tracker.BufferAt(param, {1})}, + + {kFree, tracker.BufferAt(while_op, {})}, + {kFree, tracker.BufferAt(while_op, {0})}, + {kFree, tracker.BufferAt(while_op, {1})}, + + {kFree, tracker.BufferAt(cond_param, {0})}, + {kFree, tracker.BufferAt(cond_param, {1})}, + {kFree, tracker.BufferAt(cond_lt, {})}, + + {kFinish, nullptr}, + }); +} + // Base class for heap algorithm tests. class HeapAlgorithmTestBase : public ::testing::Test { protected: diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 35f8dcb7ca614f5660850c9022049eea908f323c..2584ad39ae1c58c187d00985919a39dd184c9c63 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -35,10 +35,14 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { +using ::tensorflow::strings::StrCat; + std::unique_ptr HloComputation::Builder::Build( HloInstruction* root_instruction) { int parameter_count = 0; @@ -52,16 +56,17 @@ std::unique_ptr HloComputation::Builder::Build( root_instruction ? root_instruction : last_added_instruction_; CHECK_NE(nullptr, root); - return WrapUnique( - new HloComputation(name_, parameter_count, &instructions_, root)); + return WrapUnique(new HloComputation(name_, parameter_count, &instructions_, + root, is_fusion_computation_)); } HloComputation::HloComputation( const string& name, int parameter_count, std::vector>* instructions, - HloInstruction* root_instruction) + HloInstruction* root_instruction, bool is_fusion_computation) : name_(name), root_instruction_(root_instruction), + is_fusion_computation_(is_fusion_computation), instruction_name_uniquer_(/*separator=*/".") { param_instructions_.resize(parameter_count, nullptr); bool root_found = false; @@ -90,8 +95,7 @@ HloInstruction* HloComputation::AddInstruction( HloInstruction* HloComputation::AddInstructionInternal( std::unique_ptr instruction) { // Generate a unique name for the instruction. - instruction->set_name( - instruction_name_uniquer_.GetUniqueName(instruction->name())); + instruction->UniquifyName(&instruction_name_uniquer_); Reparent(instruction.get()); HloInstruction* pinst = instruction.get(); instruction_iterators_[pinst] = @@ -99,19 +103,77 @@ HloInstruction* HloComputation::AddInstructionInternal( return pinst; } -void HloComputation::Reparent(HloInstruction* instruction) { +HloInstruction* HloComputation::AddParameter( + std::unique_ptr instruction) { + CHECK(instruction->opcode() == HloOpcode::kParameter); + CHECK(is_fusion_computation_); + CHECK(root_instruction_->fusion_instruction() != nullptr); + instruction->SetParentFusion(root_instruction_->fusion_instruction()); + CHECK(root_instruction_->fusion_instruction()->operand_count() == + param_instructions_.size()); instruction->set_parent(this); - if (instruction->opcode() == HloOpcode::kFusion) { - for (auto& i : instruction->fused_instructions()) { - Reparent(i.get()); + param_instructions_.push_back(instruction.get()); + AddInstructionInternal(std::move(instruction)); + return instructions_.back().get(); +} + +Status HloComputation::RemoveParameter(int64 param_no) { + CHECK_GE(param_no, 0); + CHECK_LT(param_no, param_instructions_.size()); + CHECK(is_fusion_computation_); + CHECK(root_instruction_->fusion_instruction() != nullptr); + HloInstruction* param_instruction = param_instructions_[param_no]; + auto param_instruction_iterator = param_instructions_.begin() + param_no; + param_instructions_.erase(param_instruction_iterator); + // Throw removed fused parameter instruction away. + TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); + + while (param_no < param_instructions_.size()) { + param_instruction = param_instructions_[param_no]; + string param_name = param_instruction->parameter_name(); + // Fusion parameters are named foo.param_1, bar.param_2, etc. We are + // renumbering the parameters so replace the final number in the name with + // the updated value. + const string param_underscore = ".param_"; + size_t index = param_name.rfind(param_underscore); + if (index == string::npos) { + string after_param = name().substr(index + param_underscore.size()); + int64 numeric_suffix; + if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { + param_name = + StrCat(param_name.substr(0, index), param_underscore, param_no); + } } + + HloInstruction* new_instr = + AddInstructionInternal(HloInstruction::CreateParameter( + param_no, param_instruction->shape(), param_name)); + TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); + new_instr->SetParentFusion(root_instruction_->fusion_instruction()); + param_instructions_[param_no] = new_instr; + TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); + param_no++; } + + return Status::OK(); } -/* static */ bool HloComputation::IsRemovable(const HloOpcode& opcode) { - return !(opcode == HloOpcode::kParameter || opcode == HloOpcode::kRecv || - opcode == HloOpcode::kSend || opcode == HloOpcode::kTrace || - opcode == HloOpcode::kOutfeed); +void HloComputation::Reparent(HloInstruction* instruction) { + instruction->set_parent(this); +} + +bool HloComputation::IsRemovable(const HloInstruction* instruction) { + // If the instruction has control predecessors or successors then we cannot + // remove the instruction without violating ordering constraints (added, for + // example, to avert interference due to buffer aliasing). + if (!instruction->control_predecessors().empty() || + !instruction->control_successors().empty()) { + return false; + } + const HloOpcode opcode = instruction->opcode(); + return !((opcode == HloOpcode::kParameter && !is_fusion_computation_) || + opcode == HloOpcode::kRecv || opcode == HloOpcode::kSend || + opcode == HloOpcode::kTrace || opcode == HloOpcode::kOutfeed); } Status HloComputation::RemoveInstructionAndUnusedOperands( @@ -119,7 +181,7 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( TF_RET_CHECK(root_instruction() != instruction); TF_RET_CHECK(instruction->user_count() == 0); - TF_RET_CHECK(HloComputation::IsRemovable(instruction->opcode())); + TF_RET_CHECK(IsRemovable(instruction)); std::unordered_set removed; std::queue worklist; worklist.push(instruction); @@ -128,8 +190,7 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( worklist.pop(); if (removed.count(item) != 0 || item->user_count() != 0 || - item == root_instruction() || - !HloComputation::IsRemovable(item->opcode())) { + item == root_instruction() || !IsRemovable(item)) { continue; } for (int i = 0; i < item->operand_count(); ++i) { @@ -145,7 +206,7 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( Status HloComputation::RemoveInstruction(HloInstruction* instruction) { VLOG(2) << "Removing instruction " << instruction->name() << " from computation " << name(); - TF_RET_CHECK(IsRemovable(instruction->opcode())); + TF_RET_CHECK(IsRemovable(instruction)); TF_RET_CHECK(root_instruction() != instruction) << "cannot remove root instruction " << instruction->name(); TF_RET_CHECK(instruction->user_count() == 0) @@ -295,21 +356,27 @@ std::list HloComputation::MakeEmbeddedComputationsList() return post_order; } -string HloComputation::ToString() const { +string HloComputation::ToString(int nested_level) const { std::ostringstream s; + for (int i = 0; i < nested_level; i++) { + s << " "; + } s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) << " { \n"; for (const HloInstruction* instruction : MakeInstructionPostOrder()) { + for (int i = 0; i < nested_level; i++) { + s << " "; + } s << " " << instruction->ToString() << "\n"; if (instruction->opcode() == HloOpcode::kFusion) { - tensorflow::gtl::FlatSet added_instructions; - auto fused_instructions = InstructionPostOrderer::GetOrder( - instruction->fused_expression_root(), &added_instructions); - for (const auto& fused_instruction : fused_instructions) { - s << " " << fused_instruction->ToString() << "\n"; - } + s << instruction->fused_instructions_computation()->ToString( + nested_level + 1) + << "\n"; } } + for (int i = 0; i < nested_level; i++) { + s << " "; + } s << "}"; return s.str(); } @@ -583,4 +650,44 @@ Status HloComputation::Accept( return this->Accept(&visitor); } +std::unique_ptr HloComputation::Clone(const string& suffix) { + VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; + auto postorder = MakeInstructionPostOrder(); + std::unordered_map clone_map; + std::vector> instructions; + std::unique_ptr new_instr = nullptr; + for (auto instr : postorder) { + std::vector new_operands; + for (auto operand : instr->operands()) { + HloInstruction* new_operand = FindOrDie(clone_map, operand); + CHECK(new_operand != nullptr); + new_operands.push_back(new_operand); + } + + new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands); + InsertOrDie(&clone_map, instr, new_instr.get()); + instructions.push_back(std::move(new_instr)); + } + Builder builder(name() + suffix); + for (auto& instr : instructions) { + builder.AddInstruction(std::move(instr)); + } + auto result = builder.Build( + /*root_instruction=*/FindOrDie(clone_map, root_instruction())); + + // Clone control dependencies. + for (auto instr : postorder) { + HloInstruction* new_instr = FindOrDie(clone_map, instr); + for (auto successor : instr->control_successors()) { + TF_CHECK_OK( + new_instr->AddControlDependencyTo(FindOrDie(clone_map, successor))); + } + } + return result; +} + +void HloComputation::UniquifyName(NameUniquer* name_uniquer) { + name_ = name_uniquer->GetUniqueName(name_); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index ef3cba6fa08da81d35a3e8b06c8028cba0de8111..62e00a24fbb523e1e30f08141f9e026407a2015d 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -54,8 +54,10 @@ class HloComputation { // Builder class for HloComputation. class Builder { public: - explicit Builder(const string& name) - : name_(name), last_added_instruction_(nullptr) {} + explicit Builder(const string& name, bool is_fusion_computation = false) + : name_(name), + last_added_instruction_(nullptr), + is_fusion_computation_(is_fusion_computation) {} // Build and return an HloComputation. The parameter root_instruction // specifies the already-added instruction to use as the root. If @@ -74,6 +76,7 @@ class HloComputation { private: const string name_; HloInstruction* last_added_instruction_; + bool is_fusion_computation_; std::vector> instructions_; }; @@ -81,6 +84,16 @@ class HloComputation { // the instruction. HloInstruction* AddInstruction(std::unique_ptr instruction); + // Remove the param_no'th parameter from the computation. + // Note this is only applicatable to the computation for the fusion + // instruction. + Status RemoveParameter(int64 param_no); + + // Add new parameter instruction to the computation. + // This should be a new parameter. Instruction will be appended to parameters + // and inserted to the instruction list. + HloInstruction* AddParameter(std::unique_ptr instruction); + // Remove an instruction from the computation. The instruction must have no // users. Instruction is deallocated with this call. Status RemoveInstruction(HloInstruction* instruction); @@ -121,8 +134,12 @@ class HloComputation { const string& name() const { return name_; } + // Use the given NameUniquer to select a unique name for the computation based + // on the computation's existing name. + void UniquifyName(NameUniquer* name_uniquer); + // Return a string representation of the computation. - string ToString() const; + string ToString(int nested_level = 0) const; const std::list>& instructions() const { return instructions_; @@ -219,17 +236,24 @@ class HloComputation { // Same as Accept() above, but the visitor is given as a function. Status Accept(const FunctionVisitor::VisitorFunction& visitor_func) const; - // Returns true if instructions of the given opcode can be removed from the + // Returns a deep copy of this computation including all instructions. + std::unique_ptr Clone(const string& suffix = "clone"); + + // Returns true if the given instruction can be removed from the // computation. Instructions such as parameters and send/receive instructions // cannot be removed without violating invariants of the HLO computation or - // module. - static bool IsRemovable(const HloOpcode& opcode); + // module with the exception of fusion computation. A parameter instruction + // is removable for a fusion computation. + bool IsRemovable(const HloInstruction* instruction); + + // Returns if this computation is a fusion computation. + bool IsFusionComputation() const { return is_fusion_computation_; } private: explicit HloComputation( const string& name, int parameter_count, std::vector>* instructions, - HloInstruction* root_instruction); + HloInstruction* root_instruction, bool is_fusion_computation = false); // Internal helper for adding instructions. HloInstruction* AddInstructionInternal( @@ -237,10 +261,6 @@ class HloComputation { // Helper for setting the parent of instructions that are added to this // computation. - // - // Because we clone HLO instructions without knowing what computation they're - // destined to be added to, this is required to appropriate set the parent on - // fused instruction sequences. void Reparent(HloInstruction* instruction); // Fuses HLOs in instructions_to_fuse into fusion_instruction. @@ -257,9 +277,12 @@ class HloComputation { // Internal helper to collect unreachable roots. std::vector CollectUnreachableRoots() const; - const string name_; + string name_; HloInstruction* root_instruction_; + // A tag shows if this is a fusion computation. + bool is_fusion_computation_; + // Module containing this computation. HloModule* parent_ = nullptr; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 12a568339627bea412dbbf478474df0f7e8190a6..3812653fe3f02f176e556e4bfb3abc6056c0cd01 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -20,15 +20,22 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { +using ::testing::ElementsAre; +using ::testing::UnorderedElementsAre; + class HloComputationTest : public HloTestBase { protected: HloComputationTest() {} @@ -67,8 +74,8 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { auto negate_computation = CreateNegateComputation(); auto map_computation = CreateMapComputation(negate_computation.get()); EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty()); - EXPECT_EQ(map_computation->MakeEmbeddedComputationsList().front(), - negate_computation.get()); + EXPECT_THAT(map_computation->MakeEmbeddedComputationsList(), + ElementsAre(negate_computation.get())); } TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) { @@ -93,10 +100,10 @@ TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) { // GetEmbeddedComputations returns a post order of the embedded computations, // so the negate computation must come first. EXPECT_EQ(negate_computation.get(), *embedded_computations.begin()); - EXPECT_MATCH(testing::ListToVec(embedded_computations), - testing::UnorderedMatcher( - negate_computation.get(), map1_computation.get(), - map2_computation.get())); + EXPECT_THAT( + embedded_computations, + UnorderedElementsAre(negate_computation.get(), map1_computation.get(), + map2_computation.get())); } TEST_F(HloComputationTest, PostOrderSingleton) { @@ -106,7 +113,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto computation = builder.Build(); - EXPECT_EQ(computation->MakeInstructionPostOrder().front(), constant); + EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); } TEST_F(HloComputationTest, PostOrderSimple) { @@ -121,10 +128,8 @@ TEST_F(HloComputationTest, PostOrderSimple) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); auto computation = builder.Build(); - EXPECT_MATCH( - testing::ListToVec( - computation->MakeInstructionPostOrder()), - testing::OrderedMatcher(constant, negate1, negate2)); + EXPECT_THAT(computation->MakeInstructionPostOrder(), + ElementsAre(constant, negate1, negate2)); } TEST_F(HloComputationTest, PostOrderTrace) { @@ -141,10 +146,8 @@ TEST_F(HloComputationTest, PostOrderTrace) { auto computation = builder.Build(); // Trace instructions should be at the end of the sort. - EXPECT_MATCH(testing::ListToVec( - computation->MakeInstructionPostOrder()), - testing::OrderedMatcher(constant, negate1, - negate2, trace)); + EXPECT_THAT(computation->MakeInstructionPostOrder(), + ElementsAre(constant, negate1, negate2, trace)); } TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { @@ -161,10 +164,8 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); auto computation = builder.Build(); - EXPECT_MATCH(testing::ListToVec( - computation->MakeInstructionPostOrder()), - testing::UnorderedMatcher( - constant1, constant2, constant3, constant4)); + EXPECT_THAT(computation->MakeInstructionPostOrder(), + UnorderedElementsAre(constant1, constant2, constant3, constant4)); } TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { @@ -187,9 +188,8 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { auto post_order = computation->MakeInstructionPostOrder(); EXPECT_EQ(6, post_order.size()); - EXPECT_MATCH(testing::ListToVec(post_order), - testing::UnorderedMatcher( - constant1, constant2, constant3, add1, add2, add3)); + EXPECT_THAT(post_order, UnorderedElementsAre(constant1, constant2, constant3, + add1, add2, add3)); } TEST_F(HloComputationTest, VisitWithMultipleRoots) { @@ -253,8 +253,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); - EXPECT_EQ(HloOpcode::kCopy, copy->opcode()); - EXPECT_EQ(constant, copy->operand(0)); + EXPECT_THAT(copy, op::Copy(constant)); } TEST_F(HloComputationTest, DeepCopyTuple) { @@ -271,18 +270,10 @@ TEST_F(HloComputationTest, DeepCopyTuple) { auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); - EXPECT_EQ(HloOpcode::kTuple, tuple_copy->opcode()); - EXPECT_EQ(HloOpcode::kCopy, tuple_copy->operand(0)->opcode()); - const HloInstruction* gte0 = tuple_copy->operand(0)->operand(0); - EXPECT_EQ(HloOpcode::kGetTupleElement, gte0->opcode()); - EXPECT_EQ(0, gte0->tuple_index()); - EXPECT_EQ(tuple, gte0->operand(0)); - - EXPECT_EQ(HloOpcode::kCopy, tuple_copy->operand(1)->opcode()); - const HloInstruction* gte1 = tuple_copy->operand(1)->operand(0); - EXPECT_EQ(HloOpcode::kGetTupleElement, gte1->opcode()); - EXPECT_EQ(1, gte1->tuple_index()); - EXPECT_EQ(tuple, gte1->operand(0)); + EXPECT_THAT(tuple_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), + op::Copy(op::GetTupleElement(tuple)))); + EXPECT_EQ(0, tuple_copy->operand(0)->operand(0)->tuple_index()); + EXPECT_EQ(1, tuple_copy->operand(1)->operand(0)->tuple_index()); } TEST_F(HloComputationTest, CycleDetection) { @@ -302,8 +293,8 @@ TEST_F(HloComputationTest, CycleDetection) { const auto visitor = [](HloInstruction* instruction) { return Status::OK(); }; auto visit_status = computation->Accept(visitor); ASSERT_FALSE(visit_status.ok()); - ASSERT_MATCH(visit_status.error_message(), - testing::ContainsRegex("cycle is detecte")); + ASSERT_THAT(visit_status.error_message(), + ::testing::ContainsRegex("cycle is detecte")); } TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { @@ -322,14 +313,45 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { auto computation = builder.Build(); EXPECT_EQ(4, computation->instruction_count()); + EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); EXPECT_EQ(negate, computation->root_instruction()); ASSERT_IS_OK(computation->RemoveInstructionAndUnusedOperands(dead_add)); EXPECT_EQ(2, computation->instruction_count()); + EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); EXPECT_EQ(negate, computation->root_instruction()); } +TEST_F(HloComputationTest, CloneWithControlDependency) { + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + r0f32_, HloOpcode::kAdd, constant1, constant2)); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); + auto computation = builder.Build(/*root_instruction=*/add); + + TF_CHECK_OK(negate->AddControlDependencyTo(add)); + + auto clone = computation->Clone(); + + auto cloned_add = clone->root_instruction(); + EXPECT_EQ(cloned_add->opcode(), HloOpcode::kAdd); + + auto predecessors = cloned_add->control_predecessors(); + EXPECT_EQ(1, predecessors.size()); + EXPECT_EQ(HloOpcode::kNegate, predecessors[0]->opcode()); + auto successors = predecessors[0]->control_successors(); + EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 9a5345dc13d6db42553e9c343f7c81cd0e6c9d0e..cb0a99d773c57ba9a2fedc2842fe17cd5fe3571e 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -15,16 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" -#include -#include #include -#include #include #include #include #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -34,52 +32,222 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" namespace xla { +namespace { + +template +static std::unique_ptr ConvertIfTypesMatch( + const Literal& src_literal) { + CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); + return LiteralUtil::Convert< + typename primitive_util::PrimitiveTypeToNative::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); +} + +template +static std::unique_ptr ConvertIfDestTypeMatches( + const Literal& src_literal, PrimitiveType primitive_dest_type) { + switch (primitive_dest_type) { +#define CONVERT_IF_TYPES_MATCH(type) \ + case (type): \ + return ConvertIfTypesMatch(src_literal); + CONVERT_IF_TYPES_MATCH(PRED) + CONVERT_IF_TYPES_MATCH(S8) + CONVERT_IF_TYPES_MATCH(S32) + CONVERT_IF_TYPES_MATCH(S64) + CONVERT_IF_TYPES_MATCH(U8) + CONVERT_IF_TYPES_MATCH(U32) + CONVERT_IF_TYPES_MATCH(U64) + CONVERT_IF_TYPES_MATCH(F32) + CONVERT_IF_TYPES_MATCH(F64) +#undef CONVERT_IF_TYPES_MATCH + // Other types are not yet supported. + default: + LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type " + << PrimitiveType_Name(src_literal.shape().element_type()); + } +} + +static std::unique_ptr ConvertIfSrcTypeMatches( + const Literal& src_literal, PrimitiveType primitive_dest_type) { + switch (src_literal.shape().element_type()) { +#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ + case (type): \ + return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); + CONVERT_IF_DEST_TYPE_MATCHES(PRED) + CONVERT_IF_DEST_TYPE_MATCHES(S8) + CONVERT_IF_DEST_TYPE_MATCHES(S32) + CONVERT_IF_DEST_TYPE_MATCHES(S64) + CONVERT_IF_DEST_TYPE_MATCHES(U8) + CONVERT_IF_DEST_TYPE_MATCHES(U32) + CONVERT_IF_DEST_TYPE_MATCHES(U64) + CONVERT_IF_DEST_TYPE_MATCHES(F32) + CONVERT_IF_DEST_TYPE_MATCHES(F64) +#undef CONVERT_IF_DEST_TYPE_MATCHES + // Other types are not yet supported. + default: + LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type " + << PrimitiveType_Name(src_literal.shape().element_type()); + } +} + +} // namespace + +// ConstantFolderVisitor traverses the HLO computation and reduces certain +// constant graph sections, to literals. +class ConstantFolderVisitor : public DfsHloVisitorWithDefault { + public: + // Default visitor action is to do nothing and return OK. + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + Status HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) override; + + Status HandleConvert(HloInstruction* convert, + HloInstruction* operand) override; + + Status HandleReshape(HloInstruction* reshape) override; + + Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; + + Status HandleTranspose(HloInstruction* transpose) override; + + // Returns whether a constant folding operation has occurred. + const bool changed() const { return changed_; } + + // Runs the visitor on a computation and returns whether any changes were + // performed. + static StatusOr Run(HloComputation* computation); + + private: + ConstantFolderVisitor() = default; + + // Replaces the existing HLO instruction old_instruction, with a literal, + // and marks the optimizer status as changed. + // Returns the Status representing the result of the replace operation. + Status ReplaceWithConstant(HloInstruction* old_instruction, + std::unique_ptr literal) { + TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction( + old_instruction, HloInstruction::CreateConstant(std::move(literal)))); + changed_ = true; + return Status::OK(); + } + + // Whether any constant folding operations have occurred. + bool changed_ = false; +}; + +StatusOr ConstantFolderVisitor::Run(HloComputation* computation) { + ConstantFolderVisitor visitor; + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + return visitor.changed(); +} StatusOr HloConstantFolding::Run(HloModule* module) { + XLA_VLOG_LINES(2, + "HloConstantFolding::Run(), before:\n" + module->ToString()); bool changed = false; - for (auto& computation : module->computations()) { - for (auto instruction : computation->MakeInstructionPostOrder()) { - // Skip dead code. - if (instruction->user_count() == 0 && - computation->root_instruction() != instruction) { - continue; - } - // Depending on the opcode, choose how to handle constant operands. - // - // TODO(b/35975797): Fold constant computations for more than reshapes and - // transposes. - switch (instruction->opcode()) { - case HloOpcode::kReshape: { - if (instruction->operand(0)->opcode() == HloOpcode::kConstant) { - TF_ASSIGN_OR_RETURN( - auto reshaped_literal, - LiteralUtil::Reshape( - instruction->operand(0)->literal(), - AsInt64Slice(instruction->shape().dimensions()))); - TF_CHECK_OK(computation->ReplaceWithNewInstruction( - instruction, - HloInstruction::CreateConstant(std::move(reshaped_literal)))); - changed = true; - } - break; - } - case HloOpcode::kTranspose: { - if (instruction->operand(0)->opcode() == HloOpcode::kConstant) { - auto transposed_literal = LiteralUtil::Transpose( - instruction->operand(0)->literal(), instruction->dimensions()); - TF_CHECK_OK(computation->ReplaceWithNewInstruction( - instruction, - HloInstruction::CreateConstant(std::move(transposed_literal)))); - changed = true; - } - break; - } - default: - break; + for (auto& comp : module->computations()) { + TF_ASSIGN_OR_RETURN(bool result, ConstantFolderVisitor::Run(comp.get())); + changed = changed || result; + } + XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString()); + return changed; +} + +Status ConstantFolderVisitor::HandleReshape(HloInstruction* reshape) { + if (reshape->operand(0)->opcode() == HloOpcode::kConstant) { + TF_ASSIGN_OR_RETURN( + auto reshaped_literal, + LiteralUtil::Reshape(reshape->operand(0)->literal(), + AsInt64Slice(reshape->shape().dimensions()))); + return ReplaceWithConstant(reshape, std::move(reshaped_literal)); + } + return Status::OK(); +} + +Status ConstantFolderVisitor::HandleTranspose(HloInstruction* transpose) { + if (transpose->operand(0)->opcode() == HloOpcode::kConstant) { + auto transposed_literal = LiteralUtil::Transpose( + transpose->operand(0)->literal(), transpose->dimensions()); + return ReplaceWithConstant(transpose, std::move(transposed_literal)); + } + return Status::OK(); +} + +Status ConstantFolderVisitor::HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) { + if (operands[0]->opcode() == HloOpcode::kConstant) { + // If all the operands of a concatenate are constant, fold them into a + // single constant tensor. + // The result concatenate dimension is going to be the sum of all the + // concatenate dimensions of the arrays taking part of the operation. + int64 concat_dim = concatenate->dimensions()[0]; + const Shape& reference_shape = operands[0]->shape(); + CHECK(!ShapeUtil::IsTuple(reference_shape)); + int64 rank = ShapeUtil::Rank(reference_shape); + std::vector concat_dimensions(reference_shape.dimensions().begin(), + reference_shape.dimensions().end()); + if (concat_dim < 0) { + concat_dim += rank; + } + for (int64 i = 1; i < operands.size(); ++i) { + const Shape& operand_shape = operands[i]->shape(); + CHECK(!ShapeUtil::IsTuple(operand_shape)); + if (operands[i]->opcode() != HloOpcode::kConstant) { + return Status::OK(); } + // Accumulate the concat dimension from all tensors taking part to the + // operation. + concat_dimensions[concat_dim] += + ShapeUtil::GetDimension(operand_shape, concat_dim); + } + + auto literal = LiteralUtil::CreateFromDimensions( + reference_shape.element_type(), concat_dimensions); + std::vector source_indices(rank, 0); + std::vector dest_indices(concat_dimensions.size(), 0); + for (auto operand : operands) { + const Shape& operand_shape = operand->shape(); + TF_RETURN_IF_ERROR(LiteralUtil::Copy( + operand->literal(), source_indices, literal.get(), dest_indices, + AsInt64Slice(operand_shape.dimensions()))); + dest_indices[concat_dim] += + ShapeUtil::GetDimension(operand_shape, concat_dim); } + return ReplaceWithConstant(concatenate, std::move(literal)); } - return changed; + return Status::OK(); +} + +Status ConstantFolderVisitor::HandleSlice(HloInstruction* slice, + HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kConstant) { + const Shape& shape = slice->shape(); + auto literal = LiteralUtil::CreateFromDimensions( + shape.element_type(), AsInt64Slice(shape.dimensions())); + std::vector dest_indices(slice->slice_starts().size(), 0); + TF_RETURN_IF_ERROR(LiteralUtil::Copy( + operand->literal(), slice->slice_starts(), literal.get(), dest_indices, + AsInt64Slice(shape.dimensions()))); + TF_RETURN_IF_ERROR(ReplaceWithConstant(slice, std::move(literal))); + } + return Status::OK(); +} + +Status ConstantFolderVisitor::HandleConvert(HloInstruction* convert, + HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kConstant) { + const Literal& src_literal = operand->literal(); + std::unique_ptr new_constant = + ConvertIfSrcTypeMatches(src_literal, convert->shape().element_type()); + return ReplaceWithConstant(convert, std::move(new_constant)); + } + return Status::OK(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h index 514bb8164c1e1fa10a36ceeeac63dc946de2ab5a..331480bd029727fa15476cb9ced2e7b7afd170f3 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.h +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h @@ -21,16 +21,14 @@ limitations under the License. namespace xla { -// A pass which performs constant folding in order to avoid unecessary +// A pass which performs constant folding in order to avoid unnecessary // computation on constants. class HloConstantFolding : public HloPassInterface { public: - explicit HloConstantFolding() {} - ~HloConstantFolding() override {} tensorflow::StringPiece name() const override { return "constant_folding"; } - // Run ConstantFolding on the given module. Returns whether the module was - // changed (common subexpressions were found and eliminated). + // Run constant folding operations on the given module. Returns whether the + // module was changed (constant expressions folded). StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a56225da156dfc0a44b6a4b99191a3c7e706561f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -0,0 +1,213 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_constant_folding.h" + +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +using HloConstantFoldingTest = HloTestBase; + +TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { + HloComputation::Builder builder(TestName()); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + builder.AddInstruction( + HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_EQ(LiteralUtil::GetFirstElement( + computation->root_instruction()->literal()), + 42); +} + +TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { + HloComputation::Builder builder(TestName()); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + builder.AddInstruction( + HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_EQ(LiteralUtil::GetFirstElement( + computation->root_instruction()->literal()), + 42.0f); +} + +TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { + HloComputation::Builder builder(TestName()); + HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({42.0f, 19.0f}))); + builder.AddInstruction( + HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Convert(input)); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_EQ( + LiteralUtil::Get(computation->root_instruction()->literal(), {0}), + 42); + EXPECT_EQ( + LiteralUtil::Get(computation->root_instruction()->literal(), {1}), + 19); +} + +TEST_F(HloConstantFoldingTest, Concatenate) { + const struct TestConfig { + int concat_dimension; + tensorflow::gtl::ArraySlice dimensions; + tensorflow::gtl::ArraySlice concat_sizes; + } test_configs[] = { + {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}}, + {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}}, + }; + + for (auto& test_config : test_configs) { + HloComputation::Builder builder(TestName()); + std::vector dimensions(test_config.dimensions.begin(), + test_config.dimensions.end()); + int64 concat_size = 0; + std::vector operands; + for (auto csize : test_config.concat_sizes) { + dimensions[test_config.concat_dimension] = csize; + concat_size += csize; + auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions); + HloInstruction* insn = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + operands.push_back(insn); + } + dimensions[test_config.concat_dimension] = concat_size; + Shape shape = ShapeUtil::MakeShape(F32, dimensions); + builder.AddInstruction(HloInstruction::CreateConcatenate( + shape, operands, test_config.concat_dimension)); + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); + } +} + +TEST_F(HloConstantFoldingTest, Slice) { + HloComputation::Builder builder(TestName()); + const int64 dimensions[] = {11, 8, 7, 5, 9}; + const int64 slice_start[] = {4, 2, 3, 1, 5}; + const int64 slice_limits[] = {10, 8, 6, 5, 9}; + TF_ASSIGN_OR_ASSERT_OK(auto literal, + LiteralTestUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + HloInstruction* literal_instruction = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4}); + builder.AddInstruction(HloInstruction::CreateSlice( + shape, literal_instruction, slice_start, slice_limits)); + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); +} + +TEST_F(HloConstantFoldingTest, TransposeConstantFold) { + HloComputation::Builder builder(TestName()); + const int64 dimensions[] = {11, 8, 7, 5, 9}; + TF_ASSIGN_OR_ASSERT_OK(auto literal, + LiteralTestUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); + auto literal_clone = LiteralUtil::CloneToUnique(*literal); + HloInstruction* literal_instruction = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); + const int64 permutation[] = {1, 2, 0, 4, 3}; + builder.AddInstruction( + HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + HloConstantFolding const_folder; + TF_ASSIGN_OR_ASSERT_OK(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), shape)); + + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + bool matched = true; + LiteralUtil::EachCell( + root->literal(), + [&](tensorflow::gtl::ArraySlice indices, NativeT value) { + std::vector rindexes = Permute(permutation, indices); + matched = matched && (value == LiteralUtil::Get(*literal_clone, + rindexes)); + }); + EXPECT_TRUE(matched); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 8fe1897e75cd0b5f013877b718735d117a5ee06b..38cc74b0f1e640d4e72188416258d9b262053152 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -136,9 +136,9 @@ Status HloCostAnalysis::HandleSlice(HloInstruction* slice, return Status::OK(); } -Status HloCostAnalysis::HandleDynamicSlice( - HloInstruction* slice, - tensorflow::gtl::ArraySlice operands) { +Status HloCostAnalysis::HandleDynamicSlice(HloInstruction* dynamic_slice, + HloInstruction* operand, + HloInstruction* start_indices) { return Status::OK(); } @@ -357,7 +357,9 @@ Status HloCostAnalysis::HandleRng(HloInstruction* random, Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) { // Compute the cost of the fused expression. HloInstruction* fused_expression_root = fusion->fused_expression_root(); - HloCostAnalysis visitor(shape_size_); + // Don't compute sizes inside of fused ops. We don't use the size here and the + // operations inside might not have a layout. + HloCostAnalysis visitor([](const Shape&) { return 0; }); TF_RETURN_IF_ERROR(fused_expression_root->Accept(&visitor)); // Attribute the cost of the fused expression to the fusion node. @@ -366,11 +368,9 @@ Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) { return Status::OK(); } -Status HloCostAnalysis::HandleCall( - HloInstruction* call, tensorflow::gtl::ArraySlice operands, - HloComputation* computation) { +Status HloCostAnalysis::HandleCall(HloInstruction* call) { HloCostAnalysis computation_visitor(shape_size_); - TF_RETURN_IF_ERROR(computation->Accept(&computation_visitor)); + TF_RETURN_IF_ERROR(call->to_apply()->Accept(&computation_visitor)); current_flop_count_ = computation_visitor.flop_count(); current_transcendental_count_ = computation_visitor.transcendental_count(); @@ -394,18 +394,15 @@ Status HloCostAnalysis::HandleSort(HloInstruction* sort, return Status::OK(); } -Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while, - HloInstruction* init, - HloComputation* condition, - HloComputation* body) { +Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while) { // Since the number of iterations of the while node is not statically // determined, we cannot precisely compute the cost of a while node. For now // compute the cost of a single iteration. // TODO(b/26346211): Improve the cost analysis for while node. HloCostAnalysis body_visitor(shape_size_); - TF_RETURN_IF_ERROR(body->Accept(&body_visitor)); + TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&body_visitor)); HloCostAnalysis condition_visitor(shape_size_); - TF_RETURN_IF_ERROR(condition->Accept(&condition_visitor)); + TF_RETURN_IF_ERROR(xla_while->while_condition()->Accept(&condition_visitor)); current_flop_count_ = body_visitor.flop_count() + condition_visitor.flop_count(); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index e6f059f53379df51c9f0b99e0e01f34f1aebb52a..b2c40f75ca4e833f1f5529977564b0e3a7ca25b1 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -84,16 +84,14 @@ class HloCostAnalysis : public DfsHloVisitor { tensorflow::gtl::ArraySlice dimensions, HloComputation* function_handle) override; Status HandleFusion(HloInstruction* fusion) override; - Status HandleCall(HloInstruction* call, - tensorflow::gtl::ArraySlice operands, - HloComputation* computation) override; + Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target) override; Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; - Status HandleDynamicSlice( - HloInstruction* slice, - tensorflow::gtl::ArraySlice operands) override; + Status HandleDynamicSlice(HloInstruction* dynamic_slice, + HloInstruction* operand, + HloInstruction* start_indices) override; Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, HloInstruction* operand, HloInstruction* update, @@ -115,8 +113,7 @@ class HloCostAnalysis : public DfsHloVisitor { Status HandlePad(HloInstruction* pad) override; Status HandleReshape(HloInstruction* reshape) override; Status HandleTranspose(HloInstruction* transpose) override; - Status HandleWhile(HloInstruction* xla_while, HloInstruction* init, - HloComputation* condition, HloComputation* body) override; + Status HandleWhile(HloInstruction* xla_while) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -136,7 +133,7 @@ class HloCostAnalysis : public DfsHloVisitor { int64 bytes_accessed() const { return bytes_accessed_; } private: - // An FMA counts as two floating point operations in these analyses. + // An FMA counts as two floating point operations in these analyzes. static constexpr int64 kFmaFlops = 2; // Utility function to handle all element-wise operations. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 9f1c91d41c6bbe8f4cd61120ab0e260097214187..f71ffeb887a6a066a1516b941ca5bf237efc2890 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -126,8 +126,10 @@ class HloCostAnalysisTest : public ::testing::Test { auto user_computation = user_computation_status.ConsumeValueOrDie(); VersionedComputationHandle versioned_handle = user_computation->GetVersionedHandle(); - return std::move( - computation_tracker_.BuildHloModule(versioned_handle).ValueOrDie()); + return std::move(computation_tracker_ + .BuildHloModule(versioned_handle, + /*config=*/nullptr) + .ValueOrDie()); } Client* client_; @@ -375,6 +377,33 @@ TEST_F(FusionCostAnalysis, LoopFusion) { EXPECT_EQ(fusion_analysis.transcendental_count(), 4); } +TEST_F(FusionCostAnalysis, NoLayout) { + Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4, 5}); + // Instructions within a fused op may have no layout. + Shape shape_without_layout = shape_with_layout; + shape_without_layout.clear_layout(); + + auto c1 = HloInstruction::CreateConstant( + LiteralUtil::CreateR4FromArray4D(Array4D(2, 3, 4, 5))); + auto c2 = + HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3})); + + auto broadcast = + HloInstruction::CreateBroadcast(shape_without_layout, c2.get(), {1}); + auto add = HloInstruction::CreateBinary(shape_with_layout, HloOpcode::kAdd, + c1.get(), broadcast.get()); + + auto fusion = HloInstruction::CreateFusion( + shape_with_layout, HloInstruction::FusionKind::kLoop, add.get()); + fusion->FuseInstruction(broadcast.get()); + + HloCostAnalysis fusion_analysis(ShapeSize); + ASSERT_IS_OK(fusion->Accept(&fusion_analysis)); + + EXPECT_EQ(fusion_analysis.flop_count(), 120); + EXPECT_EQ(fusion_analysis.transcendental_count(), 0); +} + TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index ec8161f55fd56c95bb088a0c539255aed2fe6993..9444382b5270b0f76fa33b598297d24572e5b2c9 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -36,6 +37,8 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/types.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { @@ -88,13 +91,15 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_NE(add->operand(0), add->operand(1)); + EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); - EXPECT_EQ(add->operand(0), add->operand(1)); + auto first_operand = add->operand(0); + EXPECT_THAT(first_operand, ::testing::AnyOf(constant1, constant2)); + EXPECT_THAT(add, op::Add(first_operand, first_operand)); auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); @@ -118,15 +123,13 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(constant1, add->operand(0)); - EXPECT_EQ(constant2, add->operand(1)); + EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/true); EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(constant1, add->operand(0)); - EXPECT_EQ(constant2, add->operand(1)); + EXPECT_THAT(add, op::Add(constant1, constant2)); auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); @@ -185,16 +188,18 @@ TEST_F(HloCseTest, NonscalarConstants) { auto computation = module.AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); + EXPECT_THAT(tuple, + op::Tuple(common_constant1, common_constant2, uncommon_constant)); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(&module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); - - EXPECT_EQ(tuple->operand(0), tuple->operand(1)); - EXPECT_EQ(uncommon_constant, tuple->operand(2)); - EXPECT_TRUE(tuple->operand(0) == common_constant1 || - tuple->operand(0) == common_constant2); + auto first_operand = tuple->operand(0); + EXPECT_THAT(first_operand, + ::testing::AnyOf(common_constant1, common_constant2)); + EXPECT_THAT(tuple, + op::Tuple(first_operand, first_operand, uncommon_constant)); } TEST_F(HloCseTest, IdenticalInstructions) { @@ -215,16 +220,15 @@ TEST_F(HloCseTest, IdenticalInstructions) { auto computation = module.AddEntryComputation(builder.Build()); EXPECT_EQ(5, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); - EXPECT_NE(tuple->operand(1), tuple->operand(2)); - EXPECT_NE(tuple->operand(0), tuple->operand(2)); + EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(&module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(tuple->operand(0), tuple->operand(1)); - EXPECT_EQ(tuple->operand(1), tuple->operand(2)); + auto first_operand = tuple->operand(0); + EXPECT_THAT(first_operand, ::testing::AnyOf(exp1, exp2, exp3)); + EXPECT_THAT(tuple, op::Tuple(first_operand, first_operand, first_operand)); } TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { @@ -249,13 +253,13 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { auto computation = module.AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); + EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/true); EXPECT_FALSE(cse.Run(&module).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); + EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); } TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { @@ -280,13 +284,15 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { auto computation = module.AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); + EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(&module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); - EXPECT_EQ(tuple->operand(0), tuple->operand(1)); + auto first_operand = tuple->operand(0); + EXPECT_THAT(first_operand, ::testing::AnyOf(exp1, exp2)); + EXPECT_THAT(tuple, op::Tuple(first_operand, first_operand)); } TEST_F(HloCseTest, IdenticalExpressions) { @@ -328,14 +334,15 @@ TEST_F(HloCseTest, IdenticalExpressions) { auto computation = module.AddEntryComputation(builder.Build()); EXPECT_EQ(8, computation->instruction_count()); - EXPECT_NE(tuple->operand(0), tuple->operand(1)); + EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2))); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(&module).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); - EXPECT_EQ(tuple->operand(0), tuple->operand(1)); - EXPECT_EQ(HloOpcode::kAdd, tuple->operand(0)->opcode()); + auto operand = tuple->operand(0); + EXPECT_THAT(tuple, op::Tuple(operand, operand)); + EXPECT_THAT(operand, op::Add(op::Negate(), op::Exp())); } TEST_F(HloCseTest, DoNotCombineRng) { @@ -351,12 +358,16 @@ TEST_F(HloCseTest, DoNotCombineRng) { auto rng2 = builder.AddInstruction(HloInstruction::CreateRng( ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, {constant1, constant2})); + builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, rng1, rng2)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(rng1, rng2)); + uint32 count_before = computation->instruction_count(); HloCSE cse(/*is_layout_sensitive=*/false); @@ -364,11 +375,8 @@ TEST_F(HloCseTest, DoNotCombineRng) { uint32 count_after = computation->instruction_count(); EXPECT_EQ(count_before, count_after); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kRng); - EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kRng); - EXPECT_NE(root->operand(0), root->operand(1)); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(rng1, rng2)); } // TODO(b/28245743): Handle impure functions correctly in CSE. @@ -412,16 +420,17 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { } EXPECT_EQ(4, computation->instruction_count()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(op::Map(), op::Map())); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kAdd); - EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kMap); - EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kMap); - EXPECT_NE(root->operand(0), root->operand(1)); + root = computation->root_instruction(); + auto operand = root->operand(0)->operand(0); + EXPECT_THAT(operand, op::Map()); + EXPECT_THAT(root, op::Add(operand, operand)); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index fdfbbf8baf65884fcb1eed846e6ce3eda07bc45d..3755b9e4c005c5e50b149d8dc8c51363eb111868 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -52,7 +52,7 @@ StatusOr HloDCE::Run(HloModule* module) { for (auto& instruction : computation->instructions()) { if (instruction->user_count() == 0 && live_instructions.count(instruction.get()) == 0 && - HloComputation::IsRemovable(instruction->opcode())) { + computation->IsRemovable(instruction.get())) { dead_roots.push_back(instruction.get()); } } diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index dcd9e00c56c76046e6c1de75558637b7e941e57e..4191eaaad06da5baf01cd74e6a52d6aacf396cd6 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -93,5 +94,65 @@ TEST_F(HloDceTest, DeadParameters) { EXPECT_EQ(0, dead_param1->user_count()); } +TEST_F(HloDceTest, ControlDependencies) { + // Verify that instructions with control dependencies are not removed. + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + + // Create two dead instructions: a negate and an add. + auto dead_negate = builder.AddInstruction(HloInstruction::CreateUnary( + constant1->shape(), HloOpcode::kNegate, constant1)); + auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary( + constant1->shape(), HloOpcode::kAdd, constant1, constant2)); + + // Create the same two instructions again, but these will have a control + // dependency added. + auto dead_negate_with_control_dep = + builder.AddInstruction(HloInstruction::CreateUnary( + constant1->shape(), HloOpcode::kNegate, constant1)); + auto dead_add_with_control_dep = + builder.AddInstruction(HloInstruction::CreateBinary( + constant1->shape(), HloOpcode::kAdd, constant1, constant2)); + + // Create a root so the previously added instruction is dead. + builder.AddInstruction(HloInstruction::CreateBinary( + constant1->shape(), HloOpcode::kAdd, constant1, constant2)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + // Add a control dependency between two instructions. + TF_ASSERT_OK(dead_negate_with_control_dep->AddControlDependencyTo( + dead_add_with_control_dep)); + + // Returns whether the given instruction exists in the test computation. + auto has_instruction = [computation](const HloInstruction* instruction) { + for (auto& inst : computation->instructions()) { + if (inst.get() == instruction) { + return true; + } + } + return false; + }; + + EXPECT_EQ(7, computation->instruction_count()); + EXPECT_TRUE(has_instruction(dead_negate)); + EXPECT_TRUE(has_instruction(dead_add)); + EXPECT_TRUE(has_instruction(dead_negate_with_control_dep)); + EXPECT_TRUE(has_instruction(dead_add_with_control_dep)); + + HloDCE dce; + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(5, computation->instruction_count()); + EXPECT_FALSE(has_instruction(dead_negate)); + EXPECT_FALSE(has_instruction(dead_add)); + EXPECT_TRUE(has_instruction(dead_negate_with_control_dep)); + EXPECT_TRUE(has_instruction(dead_add_with_control_dep)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc new file mode 100644 index 0000000000000000000000000000000000000000..e0447d69aa2229e2cb391aac8b2afa8fde6145c1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -0,0 +1,557 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/bitmap.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +template +class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { + public: + explicit TypedVisitor(HloEvaluator* p) : parent_(p) {} + + Status DefaultAction(HloInstruction* hlo_instruction) override { + return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", + HloOpcodeString(hlo_instruction->opcode()).c_str()); + }; + + // TODO(b/35950897): many of the stl functions used in the handlers are not + // overloaded for every XLA primitive types. + + template ::value>::type* = + nullptr> + Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + return elem_operand; + })); + return Status::OK(); + }; + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + return std::abs(elem_operand); + })); + return Status::OK(); + }; + + Status HandleAbs(HloInstruction* abs, HloInstruction* operand) override { + return HandleAbs(abs, operand); + }; + + Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], + ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) { + return std::ceil(elem_operand); + })); + return Status::OK(); + }; + + Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[copy], + ElementWiseUnaryOp(copy, [](ReturnT elem_operand) { + return elem_operand; + })); + return Status::OK(); + }; + + Status HandleExp(HloInstruction* exp, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], + ElementWiseUnaryOp(exp, [](ReturnT elem_operand) { + return std::exp(elem_operand); + })); + return Status::OK(); + }; + + Status HandleFloor(HloInstruction* floor, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[floor], + ElementWiseUnaryOp(floor, [](ReturnT elem_operand) { + return std::floor(elem_operand); + })); + return Status::OK(); + }; + + Status HandleIsFinite(HloInstruction* is_finite, + HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[is_finite], + ElementWiseUnaryOp(is_finite, [](ReturnT elem_operand) { + return std::isfinite(elem_operand); + })); + return Status::OK(); + }; + + Status HandleLog(HloInstruction* log, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], + ElementWiseUnaryOp(log, [](ReturnT elem_operand) { + return std::log(elem_operand); + })); + return Status::OK(); + }; + + Status HandleLogicalNot(HloInstruction* logical_not, + HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[logical_not], + ElementWiseUnaryOp(logical_not, + [](ReturnT elem_operand) { return !elem_operand; })); + return Status::OK(); + }; + + Status HandleNegate(HloInstruction* negate, + HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate], + ElementWiseUnaryOp(negate, [](ReturnT elem_operand) { + return -elem_operand; + })); + return Status::OK(); + }; + + Status HandleSign(HloInstruction* sign, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ReturnT elem_operand) { + return (ReturnT(0) < elem_operand) - + (elem_operand < ReturnT(0)); + })); + return Status::OK(); + }; + + Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], + ElementWiseUnaryOp(tanh, [](ReturnT elem_operand) { + return std::tanh(elem_operand); + })); + return Status::OK(); + }; + + Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) { + return lhs_elem * rhs_elem; + })); + return Status::OK(); + }; + + Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[subtract], + ElementWiseBinaryOp(subtract, [](ReturnT lhs_elem, ReturnT rhs_elem) { + return lhs_elem - rhs_elem; + })); + return Status::OK(); + }; + + Status HandleAdd(HloInstruction* add, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[add], + ElementWiseBinaryOp(add, [](ReturnT lhs_elem, ReturnT rhs_elem) { + return lhs_elem + rhs_elem; + })); + return Status::OK(); + }; + + Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[divide], + ElementWiseBinaryOp(divide, [](ReturnT lhs_elem, ReturnT rhs_elem) { + return lhs_elem / rhs_elem; + })); + return Status::OK(); + }; + + Status HandleCompare(HloInstruction* compare, HloOpcode opcode, + HloInstruction* lhs, HloInstruction* rhs) override { + std::function compare_op; + switch (opcode) { + case HloOpcode::kEq: + compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el == rhs_el; + }; + break; + case HloOpcode::kNe: + compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el != rhs_el; + }; + break; + case HloOpcode::kGe: + compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el >= rhs_el; + }; + break; + case HloOpcode::kGt: + compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el > rhs_el; + }; + break; + case HloOpcode::kLe: + compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el <= rhs_el; + }; + break; + case HloOpcode::kLt: + compare_op = [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el < rhs_el; + }; + break; + default: + LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " + << HloOpcodeString(opcode); + } + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { + return Unimplemented( + "Compare operation with mismatched dimensions, likely due to " + "broadcasting is unsupported."); + } + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + auto result = LiteralUtil::CreateFromShape(compare->shape()); + std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); + do { + LiteralUtil::Set( + result.get(), multi_index, + compare_op(LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index))); + } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + + parent_->evaluated_[compare] = std::move(result); + + return Status::OK(); + }; + + Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[maximum], + ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) { + return std::max(lhs, rhs); + })); + return Status::OK(); + }; + + Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) { + return std::min(lhs_el, rhs_el); + })); + return Status::OK(); + }; + + Status HandlePower(HloInstruction* power, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[power], + ElementWiseBinaryOp(power, [](ReturnT lhs_el, ReturnT rhs_el) { + return std::pow(lhs_el, rhs_el); + })); + return Status::OK(); + }; + + Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[remainder], + ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) { + return std::remainder(lhs_el, rhs_el); + })); + return Status::OK(); + }; + + Status HandleLogicalAnd(HloInstruction* logical_and, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[logical_and], + ElementWiseBinaryOp(logical_and, [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el && rhs_el; + })); + return Status::OK(); + }; + + Status HandleLogicalOr(HloInstruction* logical_or, HloInstruction* lhs, + HloInstruction* rhs) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[logical_or], + ElementWiseBinaryOp(logical_or, [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el || rhs_el; + })); + return Status::OK(); + }; + + Status HandleClamp(HloInstruction* clamp, HloInstruction* min, + HloInstruction* arg, HloInstruction* max) override { + std::function clamp_op = + [](ReturnT low, ReturnT high, ReturnT value) { + return std::max(low, std::min(value, high)); + }; + TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp], + ElementWiseTernaryOp(clamp, std::move(clamp_op))); + return Status::OK(); + }; + + Status HandleSelect(HloInstruction* select, HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) override { + CHECK(!ShapeUtil::IsTuple(select->shape())); + std::function select_op = + [](bool pred, ReturnT on_true, ReturnT on_false) { + if (pred) { + return on_true; + } + return on_false; + }; + TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], + ElementWiseTernaryOp(select, std::move(select_op))); + return Status::OK(); + }; + + Status Preprocess(HloInstruction* hlo) override { + VLOG(2) << hlo->ToString(); + return Status::OK(); + }; + + private: + StatusOr> ElementWiseUnaryOp( + HloInstruction* instruction, + const std::function& unary_op) { + const auto shape = instruction->shape(); + const auto* operand = instruction->operand(0); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!ShapeUtil::SameDimensions(shape, operand->shape())) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(operand->shape()).c_str()); + } + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + + auto result = LiteralUtil::CreateFromShape(shape); + + std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); + do { + LiteralUtil::Set( + result.get(), multi_index, + unary_op(LiteralUtil::Get(operand_literal, multi_index))); + } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + + return std::move(result); + }; + + StatusOr> ElementWiseBinaryOp( + HloInstruction* instruction, + const std::function& binary_op) { + const auto shape = instruction->shape(); + const auto* lhs = instruction->operand(0); + const auto* rhs = instruction->operand(1); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s: ", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str()); + } + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + auto result = LiteralUtil::CreateFromShape(shape); + std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); + do { + LiteralUtil::Set( + result.get(), multi_index, + binary_op(LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index))); + } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + + return std::move(result); + }; + + template + StatusOr> ElementWiseTernaryOp( + HloInstruction* instruction, + const std::function& ternary_op) { + const auto shape = instruction->shape(); + const auto* lhs = instruction->operand(0); + const auto* rhs = instruction->operand(1); + const auto* ehs = instruction->operand(2); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && + ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s vs %s: ", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str(), + ShapeUtil::HumanString(ehs->shape()).c_str()); + } + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); + + auto result = LiteralUtil::CreateFromShape(shape); + std::vector multi_index(ShapeUtil::Rank(result->shape()), 0); + do { + LiteralUtil::Set( + result.get(), multi_index, + ternary_op(LiteralUtil::Get(lhs_literal, multi_index), + LiteralUtil::Get(rhs_literal, multi_index), + LiteralUtil::Get(ehs_literal, multi_index))); + } while (IndexUtil::BumpIndices(result->shape(), &multi_index)); + + return std::move(result); + }; + + HloEvaluator* parent_; +}; + +HloEvaluator::HloEvaluator() { + typed_visitors_[PRED] = MakeUnique>(this); + typed_visitors_[U8] = MakeUnique>(this); + typed_visitors_[U16] = MakeUnique([](HloInstruction*) { + return Unimplemented("unhandled primitive type: U16."); + }); + typed_visitors_[U32] = MakeUnique>(this); + typed_visitors_[U64] = MakeUnique>(this); + typed_visitors_[S8] = MakeUnique>(this); + typed_visitors_[S16] = MakeUnique([](HloInstruction*) { + return Unimplemented("unhandled primitive type: S16."); + }); + typed_visitors_[S32] = MakeUnique>(this); + typed_visitors_[S64] = MakeUnique>(this); + typed_visitors_[F16] = MakeUnique([](HloInstruction*) { + return Unimplemented("unhandled primitive type: F16."); + }); + typed_visitors_[F32] = MakeUnique>(this); + typed_visitors_[F64] = MakeUnique>(this); +} + +StatusOr> HloEvaluator::Evaluate( + HloComputation* computation, + tensorflow::gtl::ArraySlice args) { + arg_literals_ = args; + evaluated_.clear(); + + TF_RETURN_IF_ERROR(computation->Accept(this)); + return std::move(FindOrDie(evaluated_, computation->root_instruction())); +} + +StatusOr> HloEvaluator::Evaluate( + HloInstruction* instruction, + tensorflow::gtl::ArraySlice operands) { + DCHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); + Shape shape = instruction->shape(); + TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); + + arg_literals_ = operands; + evaluated_.clear(); + + // Evaluate operands of Parameter type against the input literals which + // caches the evaluated literal results. + for (const auto operand : instruction->operands()) { + if (operand->opcode() == HloOpcode::kParameter) { + const Literal* input_literal = arg_literals_[operand->parameter_number()]; + VLOG(2) << "Parameter operand evaluated to: " + << LiteralUtil::ToString(*input_literal); + TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); + + evaluated_[operand] = MakeUnique(*input_literal); + } else if (operand->opcode() == HloOpcode::kConstant) { + evaluated_[operand] = MakeUnique(operand->literal()); + } + } + + TF_RETURN_IF_ERROR(instruction->Visit(this)); + return std::move(FindOrDie(evaluated_, instruction)); +} + +Status HloEvaluator::HandleParameter(HloInstruction* parameter) { + VLOG(2) << "HandleParameter: " << parameter->ToString(); + const Literal* input_literal = arg_literals_[parameter->parameter_number()]; + VLOG(2) << "Parameter evaluated to: " + << LiteralUtil::ToString(*input_literal); + DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())); + + evaluated_[parameter] = MakeUnique(*input_literal); + return Status::OK(); +} + +Status HloEvaluator::HandleConstant(HloInstruction* constant, + const Literal& literal) { + VLOG(2) << "HandleConstant: " << constant->ToString(); + DCHECK(ShapeUtil::Equal(constant->shape(), literal.shape())); + + evaluated_[constant] = MakeUnique(literal); + return Status::OK(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h new file mode 100644 index 0000000000000000000000000000000000000000..040fd3d73c8e5887f4b5d2952a088687b099c560 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ + +#include + +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +// Responsible for evaluating HLO and obtain literal as the evaluation results. +// +// This class is not thread-safe. +class HloEvaluator : public DfsHloVisitorWithDefault { + public: + HloEvaluator(); + // Evaluates a HLO computation and an array of pointers to literals. + // Return the evaluated result as literal if successful. + // Precondition: argument literals are corresponds to the input computation's + // parameters in their post-ordering. For e.g., consider the following graph: + // + // * + // / \ + // + Parameter1 + // / \ + // / \ + // Parameter0 Constant + // + // The input literals array will have its first literal map to Parameter0 and + // the second map to Parameter1. + StatusOr> Evaluate( + HloComputation* computation, + tensorflow::gtl::ArraySlice arg_literals); + + // Evaluates a single HLO instruction and an array of pointers to literals. + // Return the evaluated result as literal if successful. + // Precondition: + // 1. argument literals are corresponds to the input instruction's + // parameters in their post-orderring. + // 2. the instruction's operands must be of either Parameter or Constant type. + // TODO(b/35950897): implement more ops other than element-wise ops. + StatusOr> Evaluate( + HloInstruction* instruction, + tensorflow::gtl::ArraySlice arg_literals); + + protected: + // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting + // literal type of each evaluated Handle* method of a TypedVisitor. One + // exception to this is HandleCompare, where the resulting literal type is + // always boolean. + // Note the forward declaration here is necessary to enable TypedVisitor to + // access parent members. + template + class TypedVisitor; + + // Wraps around instruction handling to infer types before dispatching to + // the corresponding typed Visitor. + Status DefaultAction(HloInstruction* hlo) override { + return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get()); + } + + Status HandleParameter(HloInstruction* parameter) override; + + Status HandleConstant(HloInstruction* constant, + const Literal& literal) override; + + private: + // Returns the already-evaluated literal result for the instruction. + // Crash with log if the given instruction has not been evaluated previously. + const Literal& GetEvaluatedLiteralFor(const HloInstruction* hlo) { + auto it = evaluated_.find(hlo); + CHECK(it != evaluated_.end()) + << "could not find evaluated value for: " << hlo->ToString(); + return *(it->second); + } + + // Map from a primitive type to its associated (templated) DfsHloVisitor. + // Note: the hash function here is only needed because current gcc std::hash + // does not specialize for enum types. This should however be fixed in the + // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5 + tensorflow::gtl::FlatMap, + std::hash> + typed_visitors_; + + // Tracks the HLO instruction and its evaluated literal result. + // TODO(b/35950897): have better memory management here to free instructions + // that are no longer a parent for any other subsequent instruction in + // post-orderring. + tensorflow::gtl::FlatMap> + evaluated_; + + // Stores input literals, assuming they are in post-order. Literals are not + // owned by this class, and they must outlive the lifetime of the instance of + // this class. + tensorflow::gtl::ArraySlice arg_literals_; + + TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..443e5ad4f4290ff10b867887ac5ed359a0c8f73a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -0,0 +1,191 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class HloEvaluatorTest : public ::testing::Test { + protected: + HloEvaluatorTest() { evaluator_ = MakeUnique(); } + + std::unique_ptr evaluator_; +}; + +// Verifies that HloEvaluator evaluates a HLO instruction that performs clamp +// with 3 operands. +TEST_F(HloEvaluatorTest, DoesClamp) { + auto low = LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); + auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + + Shape shape = low->shape(); + auto c1 = HloInstruction::CreateConstant(std::move(low)); + auto c2 = HloInstruction::CreateConstant(std::move(high)); + auto c3 = HloInstruction::CreateConstant(std::move(value)); + auto instruction = HloInstruction::CreateTernary( + shape, HloOpcode::kClamp, c1.get(), c2.get(), c3.get()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); + + auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); +} + +// Verifies that HloEvaluator evaluates a HLO instruction that performs select +// with 3 operands. +TEST_F(HloEvaluatorTest, DoesSelect) { + auto pred = LiteralUtil::CreateR2({{true, false}, {false, true}}); + auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + + Shape shape = on_true->shape(); + auto c1 = HloInstruction::CreateConstant(std::move(pred)); + auto c2 = HloInstruction::CreateConstant(std::move(on_true)); + auto c3 = HloInstruction::CreateConstant(std::move(on_false)); + auto instruction = HloInstruction::CreateTernary( + shape, HloOpcode::kSelect, c1.get(), c2.get(), c3.get()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); + + auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); +} + +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise addition with 2 operands. +TEST_F(HloEvaluatorTest, DoesAdd) { + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + + Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); + auto c1 = HloInstruction::CreateConstant(std::move(lhs)); + auto c2 = HloInstruction::CreateConstant(std::move(rhs)); + auto instruction = + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c1.get(), c2.get()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); + + auto expected = LiteralUtil::CreateR2({{3, 4}, {-96, 8}}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); +} + +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise divide with 2 operands. +TEST_F(HloEvaluatorTest, DoesDivide) { + auto lhs_s64 = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs_s64 = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + + Shape shape_s64 = ShapeUtil::MakeShape(S64, {2, 2}); + auto c1_s64 = HloInstruction::CreateConstant(std::move(lhs_s64)); + auto c2_s64 = HloInstruction::CreateConstant(std::move(rhs_s64)); + auto instruction = HloInstruction::CreateBinary(shape_s64, HloOpcode::kDivide, + c1_s64.get(), c2_s64.get()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); + + auto expected = LiteralUtil::CreateR2({{0, 0}, {-25, 1}}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + + auto lhs_f64 = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); + auto rhs_f64 = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); + + Shape shape_f64 = ShapeUtil::MakeShape(F64, {2, 2}); + auto c1_f64 = HloInstruction::CreateConstant(std::move(lhs_f64)); + auto c2_f64 = HloInstruction::CreateConstant(std::move(rhs_f64)); + instruction = HloInstruction::CreateBinary(shape_f64, HloOpcode::kDivide, + c1_f64.get(), c2_f64.get()); + + result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); + + expected = + LiteralUtil::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); +} + +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise abs op with 1 operand. +TEST_F(HloEvaluatorTest, DoesAbs) { + auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); + Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); + auto c1 = HloInstruction::CreateConstant(std::move(operand)); + auto instruction = + HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1.get()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); + + auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); +} + +// Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor +// constant operands. +TEST_F(HloEvaluatorTest, DoesTraveseInstructions) { + HloComputation::Builder builder( + ::testing::UnitTest::GetInstance()->current_test_info()->name()); + + auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); + std::vector args = {lhs.get(), rhs.get(), rhs2.get()}; + + Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); + + auto param_lhs = HloInstruction::CreateParameter(0, shape, "lhs"); + auto param_rhs = HloInstruction::CreateParameter(1, shape, "rhs"); + auto lhs_instruction = HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, param_lhs.get(), param_rhs.get()); + + auto param_rhs2 = HloInstruction::CreateParameter(2, shape, "rhs2"); + auto root_instruction = HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, lhs_instruction.get(), param_rhs2.get()); + + builder.AddInstruction(std::move(root_instruction)); + std::unique_ptr result = + evaluator_->Evaluate(builder.Build().get(), args).ConsumeValueOrDie(); + + auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); + + EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index 447892c8dec9ea0549a35c9ea2b20303c52b9aa2..9e25f1aceb1595b89aee601b294792e9e801c6f3 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -70,6 +70,7 @@ string HloExecutionProfile::ToString( string result; const int64 total_cycles = total_cycles_executed(computation); double clock_rate_ghz = device_description.clock_rate_ghz(); + CHECK_GE(clock_rate_ghz, 1e-9); const auto cycles_to_microseconds = [&](double cycles) { return cycles / clock_rate_ghz / 1000.0; @@ -80,14 +81,19 @@ string HloExecutionProfile::ToString( double nsecs = cycles / clock_rate_ghz; string bytes_per_sec; string bytes_per_cycle; - if (bytes_accessed >= 0) { + if (cycles <= 0 || bytes_accessed < 0) { + bytes_per_sec = ""; + bytes_per_cycle = ""; + } else { bytes_per_sec = tensorflow::strings::HumanReadableNumBytes( bytes_accessed / (nsecs / 1e9)); bytes_per_cycle = tensorflow::strings::HumanReadableNumBytes(bytes_accessed / cycles); - } else { - bytes_per_sec = ""; - bytes_per_cycle = ""; + } + + double cycles_percent = 0; + if (total_cycles > 0) { + cycles_percent = cycles / static_cast(total_cycles) * 100; } tensorflow::strings::StrAppend( @@ -97,8 +103,7 @@ string HloExecutionProfile::ToString( ":: " "%12s/cycle :: " "%s", - cycles, cycles / static_cast(total_cycles) * 100, - cycles_to_microseconds(cycles), + cycles, cycles_percent, cycles_to_microseconds(cycles), flops <= 0 ? "" : HumanReadableNumFlops(flops, nsecs).c_str(), bytes_per_sec.c_str(), bytes_per_cycle.c_str(), name.c_str())); }; @@ -114,26 +119,30 @@ string HloExecutionProfile::ToString( for (const auto& item : items) { const HloInstruction* hlo = item.first; tensorflow::strings::StrAppend(&result, "\n\t"); - int64 flops = hlo == nullptr ? -1 : cost_analysis.flop_count(*hlo); - int64 bytes_accessed = - hlo == nullptr ? -1 : cost_analysis.bytes_accessed(*hlo); - string display = hlo == nullptr ? "" : hlo->ToString(); + const int64 flops = (hlo == nullptr) ? -1 : cost_analysis.flop_count(*hlo); + const int64 bytes_accessed = + (hlo == nullptr) ? -1 : cost_analysis.bytes_accessed(*hlo); + const string display = (hlo == nullptr) ? "" : hlo->ToString(); append_item(item.second, flops, bytes_accessed, display); } - MetricTableReport table; - table.SetMetricName("microseconds"); - table.SetEntryName("ops"); - table.SetShowCategoryTable(); - for (const auto& item : items) { - MetricTableReport::Entry entry; - entry.text = item.first->ToString(); - entry.short_text = item.first->ToString(/*compact_operands=*/true); - entry.category_text = item.first->ToCategory(); - entry.metric = cycles_to_microseconds(item.second); - table.AddEntry(std::move(entry)); + if (total_cycles <= 0) { + result += "****** 0 total cycles ******\n"; + } else { + MetricTableReport table; + table.SetMetricName("microseconds"); + table.SetEntryName("ops"); + table.SetShowCategoryTable(); + for (const auto& item : items) { + MetricTableReport::Entry entry; + entry.text = item.first->ToString(); + entry.short_text = item.first->ToString(/*compact_operands=*/true); + entry.category_text = item.first->ToCategory(); + entry.metric = cycles_to_microseconds(item.second); + table.AddEntry(std::move(entry)); + } + result += table.MakeReport(cycles_to_microseconds(total_cycles)); } - result += table.MakeReport(cycles_to_microseconds(total_cycles)); return result; } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 0af4c99d0a51ab6e4d3048abae1b9c3fb6dca5e6..eb2e5dfb37f33fd138e20ee930a2242cb1db89ea 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" @@ -47,6 +48,73 @@ namespace xla { namespace hlo_graph_dumper { namespace { +// Node color schemes, used by NodeColorAttributes. +enum ColorScheme { + kBlue, + kBrown, + kDarkBlue, + kDarkGreen, + kDarkRed, + kGray, + kGreen, + kOrange, + kPurple, + kRed, + kWhite, + kYellow, +}; + +// Given a ColorScheme, returns an attribute string for a node of that color. +// Sets the node's fill, stroke, and text colors. +// +// Colors are from https://material.io/color. +string NodeColorAttributes(ColorScheme color) { + using std::make_tuple; + + const char *fill_color, *stroke_color, *font_color; + std::tie(fill_color, stroke_color, font_color) = + [color]() -> std::tuple { + switch (color) { + case kBlue: + return make_tuple("#bbdefb", "#8aacc8", "black"); + case kBrown: + return make_tuple("#bcaaa4", "#8c7b75", "black"); + case kDarkBlue: + return make_tuple("#1565c0", "#003c8f", "white"); + case kDarkGreen: + return make_tuple("#2e7d32", "#005005", "white"); + case kDarkRed: + return make_tuple("#b71c1c", "#7f0000", "white"); + case kGray: + return make_tuple("#cfd8dc", "#9ea7aa", "black"); + case kGreen: + return make_tuple("#c8e6c9", "#97b498", "black"); + case kOrange: + return make_tuple("#ffe0b2", "#cbae82", "black"); + case kPurple: + return make_tuple("#e1bee7", "#af8eb5", "black"); + case kRed: + return make_tuple("#ffcdd2", "#cb9ca1", "black"); + case kWhite: + return make_tuple("white", "black", "black"); + case kYellow: + return make_tuple("#fff9c4", "#cbc693", "black"); + } + }(); + + return Printf( + "style=filled, fontcolor=\"%s\", color=\"%s\", fillcolor=\"%s\"", + font_color, stroke_color, fill_color); +} + +// Replaces <> with <>, so that this string is safe(er) for use in a +// graphviz HTML-like string. +string HtmlLikeStringSanitize(tensorflow::StringPiece s) { + return tensorflow::str_util::StringReplace( + tensorflow::str_util::StringReplace(s, "<", "<", /*replace_all=*/true), + ">", ">", /*replace_all=*/true); +} + // Returns the dot graph identifier for the given instruction. string InstructionId(const HloInstruction* instruction) { return Printf("%lld", reinterpret_cast(instruction)); @@ -101,30 +169,36 @@ string InstructionSequenceGraph( param_ports.push_back( Printf("<%s> %s", InstructionId(param).c_str(), label.c_str())); } - StrAppend(&graph_body, param_node_name, - " [shape=record,style=filled,fillcolor=\"lightblue1\",", - "label=\"{parameters | {", Join(param_ports, "|"), "}}\"];\n"); + // (If we wanted the word "parameters" to be bold like the other op names, + // we'd have to make this into an HTML-like table. It is possible but + // complicated; see http://www.graphviz.org/doc/info/shapes.html#html.) + StrAppend(&graph_body, param_node_name, " [shape=record ", + NodeColorAttributes(kOrange), "label=\"{parameters | {", + Join(param_ports, "|"), "}}\"];\n"); } for (auto& instruction : instructions) { - string color = "peachpuff"; - string shape = "ellipse"; - string name = instruction->ExtendedOpcodeStr(); + ColorScheme color = kYellow; + string shape = "box"; + string name = + StrCat("", HtmlLikeStringSanitize(instruction->ExtendedOpcodeStr()), + " ", HtmlLikeStringSanitize(instruction->name())); if (HloOpcode::kConvolution == instruction->opcode()) { - name += ":\\n" + instruction->ConvolutionDimensionNumbersToString() + - "\\n" + window_util::ToString(instruction->window()); + StrAppend( + &name, "
", + HtmlLikeStringSanitize( + instruction->ConvolutionDimensionNumbersToString()), + "
", + HtmlLikeStringSanitize(window_util::ToString(instruction->window()))); } - name += "\\n" + instruction->name(); - if (!instruction->metadata().op_type().empty()) { - StrAppend(&name, "\\n", instruction->metadata().op_type()); - } if (!instruction->metadata().op_name().empty()) { - StrAppend(&name, "\\n", instruction->metadata().op_name()); + StrAppend(&name, "
", + HtmlLikeStringSanitize(instruction->metadata().op_name())); } if (!instruction->metadata().source_file().empty() && instruction->metadata().source_line() != 0) { - StrAppend(&name, "\\n", instruction->metadata().source_file(), ":", + StrAppend(&name, "
", instruction->metadata().source_file(), ":", instruction->metadata().source_line()); } @@ -139,11 +213,8 @@ string InstructionSequenceGraph( case HloOpcode::kAdd: case HloOpcode::kCeil: case HloOpcode::kClamp: - case HloOpcode::kConcatenate: case HloOpcode::kConvert: case HloOpcode::kDivide: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kFloor: @@ -162,64 +233,49 @@ string InstructionSequenceGraph( case HloOpcode::kMultiply: case HloOpcode::kNe: case HloOpcode::kNegate: - case HloOpcode::kPad: case HloOpcode::kPower: case HloOpcode::kRemainder: - case HloOpcode::kReshape: - case HloOpcode::kReverse: case HloOpcode::kSelect: case HloOpcode::kSign: case HloOpcode::kSlice: case HloOpcode::kSort: case HloOpcode::kSubtract: case HloOpcode::kTanh: - case HloOpcode::kTuple: - case HloOpcode::kUpdate: - break; - - case HloOpcode::kBroadcast: - case HloOpcode::kTranspose: - StrAppend(&name, "\\n", "dims={", Join(instruction->dimensions(), ","), - "}"); - break; - case HloOpcode::kGetTupleElement: - StrAppend(&name, "\\nindex=", instruction->tuple_index()); break; case HloOpcode::kRng: - StrAppend(&name, "\\n", + StrAppend(&name, "
", RandomDistribution_Name(instruction->random_distribution())); break; - case HloOpcode::kConstant: - shape = "boxed"; - color = "palegreen"; - if (ShapeUtil::IsScalar(instruction->shape())) { - StrAppend(&name, "\\n", "value=", LiteralUtil::GetAsString( - instruction->literal(), {})); - } + case HloOpcode::kBroadcast: + case HloOpcode::kTranspose: + StrAppend(&name, "
", "dims={", + Join(instruction->dimensions(), ","), "}"); break; case HloOpcode::kBitcast: - case HloOpcode::kCopy: - color = "white"; - break; - case HloOpcode::kCall: - color = "tomato"; - break; - case HloOpcode::kCustomCall: - color = "tomato4"; - StrAppend(&name, "\\n", - "custom_call_target=", instruction->custom_call_target()); + case HloOpcode::kTuple: + case HloOpcode::kTrace: + color = kWhite; break; - case HloOpcode::kDot: - color = "slateblue"; + case HloOpcode::kGetTupleElement: + color = kWhite; + StrAppend(&name, "
index=", instruction->tuple_index()); break; - case HloOpcode::kSend: - color = "purple"; + case HloOpcode::kConcatenate: + case HloOpcode::kCopy: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kPad: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kUpdate: + color = kGreen; break; - case HloOpcode::kRecv: - color = "orange"; + case HloOpcode::kConstant: + color = kBlue; break; - case HloOpcode::kMap: - color = "palevioletred"; + case HloOpcode::kConvolution: + case HloOpcode::kDot: + color = kDarkBlue; break; case HloOpcode::kParameter: // A single record node is created for all the parameter nodes with a @@ -228,38 +284,54 @@ string InstructionSequenceGraph( continue; case HloOpcode::kReduce: StrAppend(&name, " dims=", Join(instruction->dimensions(), ",")); - color = "lightsalmon"; + color = kPurple; break; case HloOpcode::kSelectAndScatter: case HloOpcode::kReduceWindow: - color = "lightsalmon"; - break; - case HloOpcode::kTrace: - color = "white"; + color = kPurple; break; case HloOpcode::kWhile: - color = "forestgreen"; + shape = "ellipse"; + color = kDarkGreen; break; + case HloOpcode::kMap: case HloOpcode::kFusion: - color = "gray"; - break; - case HloOpcode::kConvolution: - color = "red"; - break; - case HloOpcode::kCrossReplicaSum: - color = "turquoise"; + color = kGray; break; + case HloOpcode::kSend: + case HloOpcode::kRecv: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: - color = "blue"; + case HloOpcode::kCrossReplicaSum: + color = kBrown; + break; + case HloOpcode::kCall: + color = kDarkGreen; + break; + case HloOpcode::kCustomCall: + color = kDarkGreen; + StrAppend(&name, "
", + "custom_call_target=", instruction->custom_call_target()); break; } // Create instruction node with appropriate label, shape, and color. + // label is interpreted as an HTML-like string, so newlines must be + // delimited with
, rather than \n. string label = - StrCat(name, "\\n", ShapeUtil::HumanString(instruction->shape())); + StrCat(name, "
", ShapeUtil::HumanString(instruction->shape())); + + if (instruction->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(instruction->shape())) { + auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( + instruction->shape(), /*linear_index=*/0); + StrAppend(&label, " = {", + LiteralUtil::GetAsString(instruction->literal(), elem_idx), + "}"); + } + if (show_addresses) { - Appendf(&label, "\\n[%p]", instruction.get()); + Appendf(&label, "
[%p]", instruction.get()); } if (show_layouts && LayoutUtil::HasLayout(instruction->shape())) { string layout_string; @@ -271,7 +343,7 @@ string InstructionSequenceGraph( layout_string = Join(instruction->shape().layout().minor_to_major(), ","); } - StrAppend(&label, "\\nlayout={", layout_string, "}"); + StrAppend(&label, "
layout={", layout_string, "}"); } if (hlo_execution_profile != nullptr) { auto hlo_cycles_executed = @@ -279,16 +351,16 @@ string InstructionSequenceGraph( auto total_cycles_executed = hlo_execution_profile->total_cycles_executed(*instruction->parent()); if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { - Appendf(&label, "\\n%% of cycles executed=%.2f", + Appendf(&label, "
%% of cycles executed=%.2f", (static_cast(hlo_cycles_executed) / static_cast(total_cycles_executed)) * 100); } } - Appendf(&graph_body, - "%s [label=\"%s\", shape=%s, style=filled, fillcolor=%s];\n", + + Appendf(&graph_body, "%s [label=<%s>, shape=%s, %s];\n", InstructionId(instruction.get()).c_str(), label.c_str(), - shape.c_str(), color.c_str()); + shape.c_str(), NodeColorAttributes(color).c_str()); // Create edges from the instruction's operands to the instruction. int64 operand_number = 0; @@ -318,7 +390,7 @@ string InstructionSequenceGraph( StrCat("cluster_", InstructionId(instruction.get())); StrAppend(&graph_body, "subgraph ", cluster_name, " {\n"); StrAppend(&graph_body, - "label=\"fused expression\";\nstyle=filled;\n" + "label=<fused expression>;\nstyle=\"rounded,filled\";\n" "color=lightgrey;\n"); StrAppend(&graph_body, InstructionSequenceGraph( instruction->fused_instructions(), @@ -348,19 +420,39 @@ string InstructionSequenceGraph( return graph_body; } +// DOT graphs accept a stylesheet as a URL. So naturally, an inline stylesheet +// is a data URI! +// +// We don't perform any escaping on this string, so be careful not to use double +// quotes inside. +static const char* dot_stylesheet = R"( +data:text/css, +@import url(https://fonts.googleapis.com/css?family=Roboto:400,700); +svg text { + font-family: 'Roboto'; + font-size: 12px; +} +)"; + string ComputationToDotGraph(const HloComputation& computation, const string& label, bool show_addresses, bool show_layouts, const HloExecutionProfile* hlo_execution_profile) { - string graph_label = StrCat(label, "\\n", computation.name()); + string graph_label = StrCat(label, "
", computation.name()); if (hlo_execution_profile != nullptr) { auto cycles = hlo_execution_profile->total_cycles_executed(computation); - Appendf(&graph_label, "\\ntotal cycles = %lld (%s)", cycles, + Appendf(&graph_label, "
total cycles = %lld (%s)", cycles, tensorflow::strings::HumanReadableNum(cycles).c_str()); } - string graph = - Printf("digraph G {\nrankdir=TB;\ncompound=true;\nlabel=\"%s\"\n", - graph_label.c_str()); + string graph = Printf( + R"(digraph G { +rankdir=TB; +compound=true; +label=<%s>; +labelloc=t; +stylesheet="%s" +)", + graph_label.c_str(), dot_stylesheet); // Emit embedded computations as subgraph clusters. std::vector intercomputation_edges; @@ -368,7 +460,9 @@ string ComputationToDotGraph(const HloComputation& computation, string graph_body = InstructionSequenceGraph( embedded->instructions(), show_addresses, show_layouts, &intercomputation_edges, hlo_execution_profile); - Appendf(&graph, "subgraph cluster_%s {\nlabel=\"%s\";\n%s}\n", + Appendf(&graph, + "subgraph cluster_%s " + "{\nstyle=rounded;label=<%s>;labelloc=t;\n%s}\n", ComputationId(embedded).c_str(), embedded->name().c_str(), graph_body.c_str()); } @@ -414,14 +508,24 @@ namespace { class FileGraphRenderer : public GraphRendererInterface { public: - string RenderGraph(const string& graph) override { + string RenderGraph(const string& graph, GraphKind graph_kind) override { static std::atomic output_num(0); legacy_flags::HloGraphDumperFlags* flags = legacy_flags::GetHloGraphDumperFlags(); - string path = StrCat(flags->xla_hlo_dump_graph_path, "hlo_graph_", - output_num++, ".XXXXXX.dot"); + string file_extension; + switch (graph_kind) { + case DOT_GRAPH: + file_extension = ".dot"; + break; + case TF_GRAPHDEF: + file_extension = ".pbtxt"; + break; + } + string path = + JoinPath(flags->xla_hlo_dump_graph_path, + StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension)); auto status = Status::OK(); - int fd = mkstemps(&path[0], 4); + int fd = mkstemps(&path[0], file_extension.length()); if (fd < 0) { status = Status(tensorflow::error::Code::UNKNOWN, @@ -446,10 +550,26 @@ XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0); string DumpGraph(const HloComputation& computation, const string& label, bool show_addresses, bool show_layouts, const HloExecutionProfile* hlo_execution_profile) { - string graph = ComputationToDotGraph(computation, label, show_addresses, - show_layouts, hlo_execution_profile); - - string graph_url = GetGraphRenderer()->RenderGraph(graph); + string graph; + string graph_url; + legacy_flags::HloGraphDumperFlags* flags = + legacy_flags::GetHloGraphDumperFlags(); + if (flags->xla_hlo_dump_as_graphdef) { + HloTfGraphBuilder builder; + TF_CHECK_OK(builder.AddComputation(computation)); + CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(), + &graph)); + // TODO(b/37198616): Use the default registered renderers when all + // renderers support rendering GraphDefs. Always dump GraphDefs to files + // for now. + graph_url = FileGraphRenderer().RenderGraph( + graph, GraphRendererInterface::TF_GRAPHDEF); + } else { + graph = ComputationToDotGraph(computation, label, show_addresses, + show_layouts, hlo_execution_profile); + graph_url = GetGraphRenderer()->RenderGraph( + graph, GraphRendererInterface::DOT_GRAPH); + } LOG(INFO) << "computation " << computation.name() << " [" << label << "]: " << graph_url; return graph_url; @@ -467,5 +587,4 @@ void DumpText(const HloModule& module, const string& label, } } // namespace hlo_graph_dumper - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 5f841da1f35c40042fde54dbc03eb7682a8d31cb..8ed50c38473a6f6dd36603e155285e855ff0c5be 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -25,8 +25,25 @@ limitations under the License. namespace xla { namespace hlo_graph_dumper { -// Dumps a graph of the computation to the GraphViz server and returns -// a description of the rendered graph (e.g., a URL). +// Abstract interface for classes that render HLO graphs (e.g. DOT graph, +// tensorflow GraphDef). +class GraphRendererInterface { + public: + enum GraphKind { + DOT_GRAPH, + TF_GRAPHDEF, + }; + + virtual ~GraphRendererInterface() = default; + + // Renders a DOT graph, returning a description of the rendered output + // (e.g., a URL) + virtual string RenderGraph(const string& graph, GraphKind graph_kind) = 0; +}; + +// Dumps a graph of the computation and returns a description of the rendered +// graph (e.g., a URL) based on the renderer. The "best" renderer in the +// registry is used. string DumpGraph(const HloComputation& computation, const string& label, bool show_addresses, bool show_layouts, const HloExecutionProfile* hlo_execution_profile = nullptr); @@ -40,16 +57,6 @@ string DumpGraph(const HloComputation& computation, const string& label, void DumpText(const HloModule& module, const string& label, const string& directory_path, bool do_prefix = true); -// Abstract interface for classes that render DOT graphs. -class GraphRendererInterface { - public: - virtual ~GraphRendererInterface() = default; - - // Renders a DOT graph, returning a description of the rendered output - // (e.g., a URL) - virtual string RenderGraph(const string& graph) = 0; -}; - // Graph renderers may be added using a registration mechanism, e.g.: // XLA_REGISTER_GRAPH_RENDERER(AGraphRendererClass, 100) // The renderer with the highest numeric priority value is used. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 905647c2ed9f30dca31dccd07b6bcf99479ae2aa..bfb2129e13cd22cabb466ca383ce7b6ead90e96f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -27,6 +28,8 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -41,9 +44,10 @@ limitations under the License. namespace xla { -using ::tensorflow::strings::StrAppend; using ::tensorflow::str_util::Join; using ::tensorflow::strings::Printf; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; /* static */ std::unique_ptr HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { @@ -209,10 +213,10 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvolution, shape)); if (window_util::HasBaseDilation(window)) { - instruction->set_name(instruction->name() + "-base-dilated"); + instruction->name_ = instruction->name() + "-base-dilated"; } if (window_util::HasWindowDilation(window)) { - instruction->set_name(instruction->name() + "-window-dilated"); + instruction->name_ = instruction->name() + "-window-dilated"; } instruction->AppendOperand(lhs); instruction->AppendOperand(rhs); @@ -406,7 +410,9 @@ HloInstruction::CreateSelectAndScatter( /* static */ std::unique_ptr HloInstruction::CreateReshape( const Shape& shape, HloInstruction* operand) { CHECK_EQ(ShapeUtil::ElementsIn(shape), - ShapeUtil::ElementsIn(operand->shape())); + ShapeUtil::ElementsIn(operand->shape())) + << "shape: " << ShapeUtil::HumanString(shape) + << " operand: " << ShapeUtil::HumanString(operand->shape()); auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); instruction->AppendOperand(operand); return instruction; @@ -432,6 +438,7 @@ HloInstruction::CreateSelectAndScatter( auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); instruction->fusion_kind_ = fusion_kind; instruction->set_parent(fused_root->parent()); + instruction->set_metadata(fused_root->metadata()); instruction->CloneAndFuseInternal(fused_root); instruction->CheckFusionInstruction(); return instruction; @@ -497,14 +504,20 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK(instruction_to_fuse->IsFusable()); - bool new_fusion_instruction = fused_instructions_.empty(); - fused_instructions_.emplace_back(instruction_to_fuse->Clone()); - HloInstruction* clone = fused_instructions_.back().get(); - clone->parent_fusion_instruction_ = this; - - if (new_fusion_instruction) { - fused_root_ = clone; + HloInstruction* clone = nullptr; + if (fused_instructions_computation_ == nullptr) { + // New fusion instruction. + auto builder = HloComputation::Builder("fused_computation", true); + builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); + fused_instructions_computation_ = builder.Build(); + clone = fused_expression_root(); + clone->parent_fusion_instruction_ = this; } else { + CHECK(fused_instructions_computation_ != nullptr && + fused_instructions_computation_->IsFusionComputation()); + clone = fused_instructions_computation_->AddInstruction( + instruction_to_fuse->Clone(/*suffix=*/"")); + clone->parent_fusion_instruction_ = this; // instruction_to_fuse is necessarily an operand of the fusion instruction. // After fusion this will no longer be the case. Remove the operand from the // operand list and remove its corresponding fused parameter @@ -512,6 +525,8 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // consistent with their index in the fused_parameter_ vector. CHECK(std::find(operands_.begin(), operands_.end(), instruction_to_fuse) != operands_.end()); + const std::vector& fused_parameters_ = + fused_instructions_computation_->parameter_instructions(); for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { if (instruction_to_fuse == operands_[operand_num]) { // replace the fused parameter instruction's uses with the clone. @@ -520,22 +535,9 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // Remove the corresponding fused parameter and operand from their // respective vectors. - fused_parameters_.erase(fused_parameters_.begin() + operand_num); + TF_CHECK_OK( + fused_instructions_computation_->RemoveParameter(operand_num)); operands_.erase(operands_.begin() + operand_num); - - // Renumber fused parameter numbers to match the vector index. - while (operand_num < fused_parameters_.size()) { - fused_parameters_[operand_num]->parameter_number_ = operand_num; - operand_num++; - } - // Throw removed fused parameter instruction away. - auto inst_it = - std::find_if(fused_instructions_.begin(), fused_instructions_.end(), - [=](const std::unique_ptr& inst) { - return inst.get() == fused_parameter; - }); - CHECK(inst_it != fused_instructions_.end()); - fused_instructions_.erase(inst_it); break; } } @@ -544,6 +546,10 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( instruction_to_fuse->RemoveUser(this); } + // Reread the parameters in the computation. + const std::vector& fused_parameters_ = + fused_instructions_computation_->parameter_instructions(); + // Add each operand of the clone as an operand of the fusion instruction. A // complication is that some clone operands may already be operands of the // fusion instruction. @@ -566,16 +572,18 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // instruction. Add it as an operand and add a corresponding fused // parameter instruction. int64 param_no = fused_parameters_.size(); + // Name the parameter after the instruction it represents in the outer + // (non-fusion) computation. Strip the leading "%" from the operand name + // to avoid a double %%. + string param_name = + StrCat(operand->name().substr(1), ".param_", param_no); std::unique_ptr param_instruction = - CreateParameter(param_no, operand->shape(), "fusion_param"); + CreateParameter(param_no, operand->shape(), param_name); - param_instruction->set_parent(parent()); param_instruction->parent_fusion_instruction_ = this; - fused_parameters_.push_back(param_instruction.get()); - fused_instructions_.push_back(std::move(param_instruction)); + fused_param = fused_instructions_computation_->AddParameter( + std::move(param_instruction)); AppendOperand(operand); - - fused_param = fused_instructions_.back().get(); } TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); } @@ -598,18 +606,25 @@ RandomDistribution HloInstruction::random_distribution() const { void HloInstruction::CheckFusionInstruction() const { CHECK_EQ(opcode_, HloOpcode::kFusion); + CHECK(fused_instructions_computation_ != nullptr && + fused_instructions_computation_->IsFusionComputation()); + const std::list>& fused_instructions_ = + fused_instructions_computation_->instructions(); // All instructions owned by this fusion instruction must be fused, and the // parent fusion instruction of the fused instructions must be 'this'. for (auto& instruction : fused_instructions_) { CHECK(instruction->IsFused()); CHECK_EQ(this, instruction->fusion_instruction()); - CHECK_EQ(parent(), instruction->parent()) << instruction->ToString(); + CHECK_EQ(fused_instructions_computation_.get(), instruction->parent()) + << instruction->ToString(); } // Fused root instruction and fused parameters must all be owned by the fusion // instruction. bool root_owned = false; + const std::vector& fused_parameters_ = fused_parameters(); + const HloInstruction* fused_root_ = fused_expression_root(); std::vector parameter_owned(fused_parameters_.size(), false); for (auto& instruction : fused_instructions_) { if (fused_root_ == instruction.get()) { @@ -702,7 +717,8 @@ void HloInstruction::CheckFusionInstruction() const { } std::unique_ptr HloInstruction::CloneWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands) { + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands) { // Explicitly call the factory for the instruction type. This is more robust // in the face of code changes than copying fields explicitly. This also // properly sets the user fields of the operands. @@ -721,8 +737,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kSign: case HloOpcode::kSort: case HloOpcode::kTanh: - CHECK_EQ(operands.size(), 1); - return CreateUnary(shape, opcode_, operands[0]); + CHECK_EQ(new_operands.size(), 1); + return CreateUnary(shape, opcode_, new_operands[0]); // Binary ops. case HloOpcode::kAdd: case HloOpcode::kDivide: @@ -741,93 +757,92 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kRemainder: case HloOpcode::kLogicalAnd: case HloOpcode::kLogicalOr: - CHECK_EQ(operands.size(), 2); - return CreateBinary(shape, opcode_, operands[0], operands[1]); + CHECK_EQ(new_operands.size(), 2); + return CreateBinary(shape, opcode_, new_operands[0], new_operands[1]); // Ternary ops. case HloOpcode::kClamp: case HloOpcode::kSelect: - CHECK_EQ(operands.size(), 3); - return CreateTernary(shape, opcode_, operands[0], operands[1], - operands[2]); + CHECK_EQ(new_operands.size(), 3); + return CreateTernary(shape, opcode_, new_operands[0], new_operands[1], + new_operands[2]); // Other supported ops. case HloOpcode::kBroadcast: - CHECK_EQ(operands.size(), 1); - return CreateBroadcast(shape, operands[0], dimensions_); + CHECK_EQ(new_operands.size(), 1); + return CreateBroadcast(shape, new_operands[0], dimensions_); case HloOpcode::kCall: - return CreateCall(shape, operands, to_apply()); + return CreateCall(shape, new_operands, to_apply()); case HloOpcode::kCustomCall: - return CreateCustomCall(shape, operands, custom_call_target_); + return CreateCustomCall(shape, new_operands, custom_call_target_); case HloOpcode::kConcatenate: - return CreateConcatenate(shape, operands, dimensions(0)); + return CreateConcatenate(shape, new_operands, dimensions(0)); case HloOpcode::kConvert: - CHECK_EQ(operands.size(), 1); - return CreateConvert(shape, operands[0]); + CHECK_EQ(new_operands.size(), 1); + return CreateConvert(shape, new_operands[0]); case HloOpcode::kConvolution: - CHECK_EQ(operands.size(), 2); - return CreateConvolve(shape, operands[0], operands[1], *window_, + CHECK_EQ(new_operands.size(), 2); + return CreateConvolve(shape, new_operands[0], new_operands[1], *window_, *convolution_dimension_numbers_); case HloOpcode::kCrossReplicaSum: - CHECK_EQ(operands.size(), 1); - return CreateCrossReplicaSum(shape, operands[0]); + CHECK_EQ(new_operands.size(), 1); + return CreateCrossReplicaSum(shape, new_operands[0]); case HloOpcode::kGetTupleElement: - CHECK_EQ(operands.size(), 1); - return CreateGetTupleElement(shape, operands[0], tuple_index()); + CHECK_EQ(new_operands.size(), 1); + return CreateGetTupleElement(shape, new_operands[0], tuple_index()); case HloOpcode::kMap: - return CreateMap(shape, operands, to_apply()); + return CreateMap(shape, new_operands, to_apply()); case HloOpcode::kPad: - CHECK_EQ(operands.size(), 2); - return CreatePad(shape, operands[0], operands[1], *padding_config_); + CHECK_EQ(new_operands.size(), 2); + return CreatePad(shape, new_operands[0], new_operands[1], + *padding_config_); case HloOpcode::kReduce: - CHECK_EQ(operands.size(), 2); - return CreateReduce(shape, operands[0], operands[1], dimensions_, + CHECK_EQ(new_operands.size(), 2); + return CreateReduce(shape, new_operands[0], new_operands[1], dimensions_, to_apply()); case HloOpcode::kReduceWindow: - CHECK_EQ(operands.size(), 2); - return CreateReduceWindow(shape, operands[0], operands[1], *window_, - to_apply()); + CHECK_EQ(new_operands.size(), 2); + return CreateReduceWindow(shape, new_operands[0], new_operands[1], + *window_, to_apply()); case HloOpcode::kSelectAndScatter: - CHECK_EQ(operands.size(), 3); - return CreateSelectAndScatter(shape, operands[0], select(), *window_, - operands[1], operands[2], scatter()); - case HloOpcode::kRecv: - CHECK_EQ(operands.size(), 0); - return CreateRecv(shape, channel_id_); + CHECK_EQ(new_operands.size(), 3); + return CreateSelectAndScatter(shape, new_operands[0], select(), *window_, + new_operands[1], new_operands[2], + scatter()); case HloOpcode::kReverse: - CHECK_EQ(operands.size(), 1); - return CreateReverse(shape, operands[0], dimensions_); + CHECK_EQ(new_operands.size(), 1); + return CreateReverse(shape, new_operands[0], dimensions_); case HloOpcode::kRng: - return CreateRng(shape, distribution_, operands); + return CreateRng(shape, distribution_, new_operands); case HloOpcode::kReshape: - CHECK_EQ(operands.size(), 1); - return CreateReshape(shape, operands[0]); - case HloOpcode::kSend: - CHECK_EQ(operands.size(), 1); - return CreateSend(operands[0], channel_id_); + CHECK_EQ(new_operands.size(), 1); + return CreateReshape(shape, new_operands[0]); case HloOpcode::kSlice: - CHECK_EQ(operands.size(), 1); - return CreateSlice(shape, operands[0], slice_starts_, slice_limits_); + CHECK_EQ(new_operands.size(), 1); + return CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_); case HloOpcode::kDynamicSlice: - return CreateDynamicSlice(shape, operands[0], operands[1], + return CreateDynamicSlice(shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); case HloOpcode::kDynamicUpdateSlice: - CHECK_EQ(operands.size(), 3); - return CreateDynamicUpdateSlice(shape, operands[0], operands[1], - operands[2]); + CHECK_EQ(new_operands.size(), 3); + return CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], + new_operands[2]); case HloOpcode::kTranspose: - CHECK_EQ(operands.size(), 1); - return CreateTranspose(shape, operands[0], dimensions_); + CHECK_EQ(new_operands.size(), 1); + return CreateTranspose(shape, new_operands[0], dimensions_); case HloOpcode::kTuple: - return CreateTuple(operands_); + return CreateTuple(new_operands); case HloOpcode::kWhile: - CHECK_EQ(operands.size(), 1); - return CreateWhile(shape, while_condition(), while_body(), operands[0]); + CHECK_EQ(new_operands.size(), 1); + return CreateWhile(shape, while_condition(), while_body(), + new_operands[0]); case HloOpcode::kConstant: return CreateConstant(LiteralUtil::CloneToUnique(*literal_)); case HloOpcode::kFusion: - return CloneFusionWithNewOperands(shape, operands); + return CloneFusionWithNewOperands(shape, new_operands); case HloOpcode::kParameter: return CreateParameter(parameter_number_, shape, parameter_name_); // Unsupported ops for cloning. + case HloOpcode::kRecv: + case HloOpcode::kSend: case HloOpcode::kUpdate: case HloOpcode::kIndex: case HloOpcode::kInfeed: @@ -837,11 +852,46 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( } } +HloInstruction::~HloInstruction() {} + std::unique_ptr HloInstruction::Clone(const string& suffix) { std::unique_ptr clone = CloneWithNewOperands(shape_, operands_); - clone->name_ = name() + "." + suffix; + if (suffix.empty()) { + clone->name_ = name(); + } else { + // If an instruction is cloned multiple times avoid names like + // foo.suffix.suffix.suffix. Instead of repeating the suffix add a numeric + // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the + // clone of foo.suffix2 is named foo.suffix3 and so on. + const string dot_suffix = "." + suffix; + size_t index = name().rfind(dot_suffix); + if (index == string::npos) { + // Existing name does not include ".suffix". + clone->name_ = name() + dot_suffix; + } else { + // Existing name includes ".suffix". Determine if substring after + // ".suffix" is numeric and should be replaced with an incremented number. + string after_suffix = name().substr(index + dot_suffix.size()); + if (after_suffix.empty()) { + // Existing name ends in ".suffix". New name should end in ".suffix2". + clone->name_ = name() + "2"; + } else { + // If names ends with .suffix[0-9]+ then replace with a suffix with the + // numeric value incremented. + int64 numeric_suffix; + if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + clone->name_ = + StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1); + } else { + // Substring after ".suffix" is non-numeric. + clone->name_ = name() + dot_suffix; + } + } + } + } clone->set_parent(parent()); + clone->set_metadata(metadata_); return clone; } @@ -849,6 +899,8 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands) { CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK(parent() != nullptr); + CHECK(fused_instructions_computation_ != nullptr && + fused_instructions_computation_->IsFusionComputation()); auto new_instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); @@ -862,6 +914,11 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( // Create the list of fused parameters by mapping through the cloned, // fused instructions. std::vector new_fused_parameters; + const std::vector& fused_parameters_ = + fused_instructions_computation_->parameter_instructions(); + const std::list>& fused_instructions_ = + fused_instructions_computation_->instructions(); + for (HloInstruction* old_fused_parameter : fused_parameters_) { new_fused_instructions.push_back(old_fused_parameter->Clone()); HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); @@ -892,13 +949,19 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( new_fused_instruction->parent_fusion_instruction_ = new_instruction.get(); InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); } + new_instruction->fusion_kind_ = fusion_kind_; + auto computation_builder = HloComputation::Builder( + fused_instructions_computation_->name() + ".clone", true); // We iterated the fusion instructions in reverse post order which means // that we must reverse our new list of fusion instructions. - std::reverse(new_fused_instructions.begin(), new_fused_instructions.end()); - new_instruction->fusion_kind_ = fusion_kind_; - new_instruction->fused_instructions_ = std::move(new_fused_instructions); - new_instruction->fused_parameters_ = std::move(new_fused_parameters); - new_instruction->fused_root_ = FindOrDie(old_to_new, fused_root_); + for (auto new_fused_instruction_iter = new_fused_instructions.rbegin(); + new_fused_instruction_iter != new_fused_instructions.rend(); + ++new_fused_instruction_iter) { + computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); + } + auto fused_root_ = fused_expression_root(); + new_instruction->fused_instructions_computation_ = + computation_builder.Build(FindOrDie(old_to_new, fused_root_)); new_instruction->set_parent(parent()); new_instruction->CheckFusionInstruction(); return new_instruction; @@ -1020,7 +1083,7 @@ bool HloInstruction::Identical( // general, there is no need to check shape because shape is inferred from the // shape of the operands. if (opcode() != other.opcode() || - !ContainersEqual(operands(), other.operands(), eq_operands)) { + !ContainersEqual(operands(), other.operands(), std::move(eq_operands))) { return false; } @@ -1355,8 +1418,7 @@ string HloInstruction::SignatureString() const { Join(operands_, ", ", [](string* out, HloInstruction* operand) { StrAppend(out, ShapeUtil::HumanString(operand->shape())); }); - return tensorflow::strings::StrCat("(", operands, ") -> ", - ShapeUtil::HumanString(shape())); + return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); } string HloInstruction::ExtendedOpcodeStr() const { @@ -1368,7 +1430,8 @@ string HloInstruction::ExtendedOpcodeStr() const { return opc_name; } -string HloInstruction::ToString(bool compact_operands) const { +string HloInstruction::ToString(bool compact_operands, + bool include_metadata) const { string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. @@ -1390,6 +1453,8 @@ string HloInstruction::ToString(bool compact_operands) const { // Do not show large constants. operands = "{...}"; } + } else if (opcode() == HloOpcode::kParameter) { + operands = Printf("%lld", parameter_number_); } else { tensorflow::gtl::ArraySlice slice(operands_); const int64 kMaxOperandsToShowIfCompact = 4; @@ -1420,8 +1485,8 @@ string HloInstruction::ToString(bool compact_operands) const { if (!slice_starts_.empty() && !slice_limits_.empty()) { std::vector bounds; for (int i = 0; i < slice_starts_.size(); ++i) { - bounds.push_back(tensorflow::strings::StrCat("[", slice_starts_[i], ":", - slice_limits_[i], "]")); + bounds.push_back( + StrCat("[", slice_starts_[i], ":", slice_limits_[i], "]")); } StrAppend(&extra, ", slice={", Join(bounds, ", "), "}"); } @@ -1447,10 +1512,12 @@ string HloInstruction::ToString(bool compact_operands) const { if (opcode() == HloOpcode::kGetTupleElement) { StrAppend(&extra, ", index=", tuple_index()); } - if (!metadata_.op_type().empty() || !metadata_.op_name().empty() || - !metadata_.source_file().empty()) { + if (include_metadata && + (!metadata_.op_type().empty() || !metadata_.op_name().empty() || + !metadata_.source_file().empty())) { StrAppend(&extra, " # metadata=", metadata_.ShortDebugString()); } + return Printf("%s = %s %s(%s)%s", name().c_str(), ShapeUtil::HumanStringWithLayout(shape()).c_str(), ExtendedOpcodeStr().c_str(), operands.c_str(), extra.c_str()); @@ -1503,7 +1570,9 @@ string HloInstruction::ToCategory() const { return "non-elementwise fusion"; } case FusionKind::kInput: - return "reduce fusion"; + return "input fusion"; + case FusionKind::kOutput: + return "output fusion"; case FusionKind::kTransposeDot: return "dot fusion"; case FusionKind::kConvBackwardFilter: @@ -1521,11 +1590,10 @@ string HloInstruction::ToCategory() const { string HloInstruction::FullyQualifiedName() const { if (IsFused()) { - return tensorflow::strings::StrCat(fusion_instruction()->parent()->name(), - "::", fusion_instruction()->name(), - "::", name_); + return StrCat(fusion_instruction()->parent()->name(), + "::", fusion_instruction()->name(), "::", name_); } - return tensorflow::strings::StrCat(parent_->name(), "::", name_); + return StrCat(parent_->name(), "::", name_); } HloInstruction* HloInstruction::tracing() const { return trace_instruction_; } @@ -1552,7 +1620,6 @@ bool HloInstruction::IsFusable() const { // Some kinds of instructions don't make sense to fuse. switch (opcode_) { - case HloOpcode::kFusion: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kParameter: @@ -1569,6 +1636,11 @@ bool HloInstruction::IsFusable() const { } } +HloComputation* HloInstruction::fused_instructions_computation() const { + CHECK_EQ(opcode_, HloOpcode::kFusion); + return fused_instructions_computation_.get(); +} + HloInstruction* HloInstruction::fusion_instruction() const { CHECK(IsFused()); return parent_fusion_instruction_; @@ -1576,25 +1648,32 @@ HloInstruction* HloInstruction::fusion_instruction() const { HloInstruction* HloInstruction::fused_expression_root() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_root_; + CHECK(fused_instructions_computation_ != nullptr && + fused_instructions_computation_->IsFusionComputation()); + return fused_instructions_computation_->root_instruction(); } HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK_GE(parameter_number, 0); - CHECK_LT(parameter_number, fused_parameters_.size()); - return fused_parameters_[parameter_number]; + CHECK(fused_instructions_computation_ != nullptr && + fused_instructions_computation_->IsFusionComputation()); + return fused_instructions_computation_->parameter_instruction( + parameter_number); } const std::vector& HloInstruction::fused_parameters() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_parameters_; + CHECK(fused_instructions_computation_ != nullptr && + fused_instructions_computation_->IsFusionComputation()); + return fused_instructions_computation_->parameter_instructions(); } const std::list>& HloInstruction::fused_instructions() const { CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_; + CHECK(fused_instructions_computation_ != nullptr && + fused_instructions_computation_->IsFusionComputation()); + return fused_instructions_computation_->instructions(); } HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) @@ -1703,7 +1782,7 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kSlice: return visitor->HandleSlice(this, operands_[0]); case HloOpcode::kDynamicSlice: - return visitor->HandleDynamicSlice(this, operands_); + return visitor->HandleDynamicSlice(this, operands_[0], operands_[1]); case HloOpcode::kDynamicUpdateSlice: return visitor->HandleDynamicUpdateSlice(this, operands_[0], operands_[1], operands_[2]); @@ -1716,12 +1795,11 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kRng: return visitor->HandleRng(this, distribution_); case HloOpcode::kWhile: - return visitor->HandleWhile(this, operands_[0], while_condition(), - while_body()); + return visitor->HandleWhile(this); case HloOpcode::kFusion: return visitor->HandleFusion(this); case HloOpcode::kCall: - return visitor->HandleCall(this, operands_, to_apply()); + return visitor->HandleCall(this); case HloOpcode::kCustomCall: return visitor->HandleCustomCall(this, operands_, custom_call_target_); case HloOpcode::kSend: @@ -1740,7 +1818,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { } Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor, - const CompareFunction* operand_order) { + const CompareFunction* operand_order, + bool ignore_control_predecessors) { // Do not visit this HLO node again if it is already visited. if (visitor->DidVisit(*this)) { VLOG(3) << "Not visiting HLO " << name() << " as it was already visited."; @@ -1755,34 +1834,41 @@ Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor, } visitor->SetVisiting(*this); - // Sort operands and control predecessors, if an ordering was provided. Note - // that 'temp_sorted_operands' must live at this scope, since 'operands' will - // point to it if the operands are sorted. The point of the 'operands' - // pointer is to avoid copying the operands in the common case where the - // operands are not sorted. + // Sort operands, if an ordering was provided. 'temp_sorted_operands' must + // live at this scope, since 'operands' will point to it if the operands are + // sorted. The purpose of the 'operands' pointer is to avoid copying the + // operands in the common case where the operands are not sorted. std::vector* operands = &operands_; std::vector temp_sorted_operands; - std::vector predecessors(control_predecessors_.begin(), - control_predecessors_.end()); if (operand_order != nullptr) { temp_sorted_operands = operands_; std::sort(temp_sorted_operands.begin(), temp_sorted_operands.end(), *operand_order); - std::sort(predecessors.begin(), predecessors.end(), *operand_order); operands = &temp_sorted_operands; } - - for (auto operand : *operands) { + for (HloInstruction* operand : *operands) { VLOG(3) << "Going to visit HLO " << operand->name() << " as operand of HLO " << name(); - TF_RETURN_IF_ERROR(operand->AcceptInternal(visitor, operand_order)); - } - - for (auto control_predecessor : predecessors) { - VLOG(3) << "Going to visit HLO " << control_predecessor->name() - << " as a control predecessor of HLO " << name(); - TF_RETURN_IF_ERROR( - control_predecessor->AcceptInternal(visitor, operand_order)); + TF_RETURN_IF_ERROR(operand->AcceptInternal(visitor, operand_order, + ignore_control_predecessors)); + } + + if (!ignore_control_predecessors) { + // This uses the same pointer/vector sorting to avoid extra copies as above. + std::vector* predecessors = &control_predecessors_; + std::vector temp_sorted_predecessors; + if (operand_order != nullptr) { + temp_sorted_predecessors = control_predecessors_; + std::sort(temp_sorted_predecessors.begin(), + temp_sorted_predecessors.end(), *operand_order); + predecessors = &temp_sorted_predecessors; + } + for (HloInstruction* control_predecessor : *predecessors) { + VLOG(3) << "Going to visit HLO " << control_predecessor->name() + << " as a control predecessor of HLO " << name(); + TF_RETURN_IF_ERROR(control_predecessor->AcceptInternal( + visitor, operand_order, ignore_control_predecessors)); + } } TF_RETURN_IF_ERROR(visitor->Preprocess(this)); @@ -1792,9 +1878,11 @@ Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor, return visitor->Postprocess(this); } -Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit) { +Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit, + bool ignore_control_predecessors) { VLOG(2) << "HloInstruction::Accept(" << name() << ")"; - TF_RETURN_IF_ERROR(AcceptInternal(visitor, nullptr)); + TF_RETURN_IF_ERROR( + AcceptInternal(visitor, nullptr, ignore_control_predecessors)); if (call_finish_visit) { TF_RETURN_IF_ERROR(visitor->FinishVisit(this)); } @@ -1805,7 +1893,8 @@ Status HloInstruction::AcceptWithOperandOrder( DfsHloVisitor* visitor, const CompareFunction& operand_order, bool call_finish_visit) { VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")"; - TF_RETURN_IF_ERROR(AcceptInternal(visitor, &operand_order)); + TF_RETURN_IF_ERROR(AcceptInternal(visitor, &operand_order, + /*ignore_control_predecessors=*/false)); if (call_finish_visit) { TF_RETURN_IF_ERROR(visitor->FinishVisit(this)); } @@ -2076,7 +2165,7 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { } return cache[&hlo]; }; - return reuses_parameter_elements(*fused_root_); + return reuses_parameter_elements(*fused_expression_root()); } default: return IsElementwise() ? UseKind::kUse : UseKind::kReuse; @@ -2098,6 +2187,8 @@ string ToString(HloInstruction::FusionKind kind) { return "kLoop"; case HloInstruction::FusionKind::kInput: return "kInput"; + case HloInstruction::FusionKind::kOutput: + return "kOutput"; case HloInstruction::FusionKind::kTransposeDot: return "kTransposeDot"; case HloInstruction::FusionKind::kConvBackwardFilter: @@ -2107,6 +2198,10 @@ string ToString(HloInstruction::FusionKind kind) { } } +std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { + return os << ToString(kind); +} + string HloInstruction::ConvolutionDimensionNumbersToString() const { string result; if (convolution_dimension_numbers_ == nullptr) { @@ -2133,15 +2228,14 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { lhs_dims[dnums.batch_dimension()] = 'b'; lhs_dims[dnums.feature_dimension()] = 'f'; for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { - lhs_dims[dnums.spatial_dimensions(i)] = tensorflow::strings::StrCat(i); + lhs_dims[dnums.spatial_dimensions(i)] = StrCat(i); } std::vector rhs_dims(2 + dnums.kernel_spatial_dimensions().size()); rhs_dims[dnums.kernel_input_feature_dimension()] = "i"; rhs_dims[dnums.kernel_output_feature_dimension()] = "o"; for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { - rhs_dims[dnums.kernel_spatial_dimensions(i)] = - tensorflow::strings::StrCat(i); + rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i); } result += "dim_labels="; @@ -2164,4 +2258,15 @@ bool HloInstruction::CouldBeBitcast() const { } } +HloModule* HloInstruction::GetModule() const { + if (parent_) { + return parent_->parent(); + } + return nullptr; +} + +void HloInstruction::UniquifyName(NameUniquer* name_uniquer) { + name_ = name_uniquer->GetUniqueName(name_); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 6557ca9116312c4bc31b9f0ba734edd11106d1e7..d300d99adec5201b70b0fe4eb65ef5b84362b018 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -22,6 +22,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ #include +#include #include #include #include @@ -33,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -45,18 +47,23 @@ limitations under the License. namespace xla { class HloComputation; +class HloModule; // HLO instructions are the IR used by the high-level compiler. class HloInstruction { public: enum class FusionKind { kLoop, // Fused into a loop. - kInput, // Fused into a reduction kernel. + kInput, // Op's input is fused into the op itself. + kOutput, // Op's output is fused into the op itself. + // REQUIRES: At least one operand buffer must be able + // to alias the output buffer. kTransposeDot, // Fused into a dot with transposed operands. kConvBackwardFilter, // Fused into a backward filter convolution. kConvBackwardInput, // Fused into a backward input convolution. }; + ~HloInstruction(); // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, const Shape& shape, @@ -371,8 +378,12 @@ class HloInstruction { // Performs a postorder DFS visit using this node as the root. If // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when - // complete. - Status Accept(DfsHloVisitor* visitor, bool call_finish_visit = true); + // complete. If ignore_control_predecessors is true, instructions only + // reachable via control dependencies will not be visited, and the postorder + // will not take control dependencies into account. It is as if the control + // dependencies didn't exist in the graph at all. + Status Accept(DfsHloVisitor* visitor, bool call_finish_visit = true, + bool ignore_control_predecessors = false); // Same as Accept() above, but the order of operand and control predecessor // visitation is determined by the given operand order; if compare(A, B) == @@ -418,6 +429,11 @@ class HloInstruction { return parameter_name_; } + void set_parameter_name(const string& str) { + CHECK_EQ(HloOpcode::kParameter, opcode_); + parameter_name_ = str; + } + // Returns the dimension sizes or numbers associated with this instruction. // // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape, @@ -476,7 +492,10 @@ class HloInstruction { string SignatureString() const; // Returns a debugging string that represents this instruction. - string ToString(bool compact_operands = false) const; + string ToString(bool compact_operands = false, + bool include_metadata = true) const; + + string ToStringNoMetadata() const { return ToString(false, false); } // As ToString, but returns a shorter string. string ToShortString() const; @@ -485,7 +504,9 @@ class HloInstruction { // or "elementwise". string ToCategory() const; - // Returns the string concatenation of parent name and this instructions name. + // Returns the string concatenation of parent name and this instructions + // name. This name is guaranteed to be unique among all instructions in the + // HloModule. string FullyQualifiedName() const; // Returns a logging instruction, if the output of this instruction is logged. @@ -534,6 +555,11 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kFusion HloInstruction* fused_expression_root() const; + // Returns the computation for this fused instruction. + // + // Precondition: opcode() == HloOpcode::kFusion + HloComputation* fused_instructions_computation() const; + // Returns the vector of fused instructions inside this fusion // instruction. The order is a reverse postorder of the fused expression (root // is first in the order). @@ -704,8 +730,9 @@ class HloInstruction { // this instruction. const string& name() const { return name_; } - // Sets the string identifier for this instruction. - void set_name(const string& name) { name_ = name; } + // Use the given NameUniquer to select a unique name for the instruction based + // on the instruction's existing name. + void UniquifyName(NameUniquer* name_uniquer); // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } @@ -718,10 +745,21 @@ class HloInstruction { const HloComputation* parent() const { return parent_; } HloComputation* parent() { return parent_; } + // Returns the module for this instruction. + HloModule* GetModule() const; + // Returns whether we could assign input and output layouts to this // instruction to make it a bitcast. bool CouldBeBitcast() const; + // Sets the parent fusion instruction for this instruction. + // + // Precondition: opcode() == HloOpcode::kFusion + void SetParentFusion(HloInstruction* fusion_instruction) { + CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); + parent_fusion_instruction_ = fusion_instruction; + } + private: enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; @@ -758,7 +796,8 @@ class HloInstruction { // Inner DFS traversal function -- this function being called (rather than // Accept above) allows us to distinguish the root of the traversal. Status AcceptInternal(DfsHloVisitor* visitor, - const CompareFunction* operand_order); + const CompareFunction* operand_order, + bool ignore_control_predecessors); // CHECKs various invariants of a fusion instruction. void CheckFusionInstruction() const; @@ -807,22 +846,14 @@ class HloInstruction { // padding of this pad instruction. Only set for pad instructions. std::unique_ptr padding_config_; - // The set of instruction fused into this fusion instruction. Only set for - // fusion instructions. - std::list> fused_instructions_; + // The computation that stores of instructions fused into this fusion + // instruction. Only set for fusion instructions. + std::unique_ptr fused_instructions_computation_; // If this instruction is fused into a fusion instruction, this field points // to the fusion instruction. HloInstruction* parent_fusion_instruction_ = nullptr; - // The vector of parameter instructions inside this fusion instruction. The - // index of the vector is the parameter_number of the parameter instruction. - // This vector is non-empty only for fusion instructions. - std::vector fused_parameters_; - - // The root of the expression fused into this fusion instruction. - HloInstruction* fused_root_ = nullptr; - // The type of the fusion. Used by kFusion only. FusionKind fusion_kind_; @@ -898,6 +929,8 @@ class HloInstruction { string ToString(HloInstruction::FusionKind kind); +std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 8eabaa1c474aa068c423099919d3382f04c7591c..a226ab0d0c43e6df6216e4b0f58ed4270cb03d40 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -21,9 +21,11 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" @@ -31,6 +33,9 @@ limitations under the License. namespace xla { namespace { +using ::testing::ElementsAre; +using ::testing::UnorderedElementsAre; + class HloInstructionTest : public HloTestBase { protected: HloInstructionTest() {} @@ -148,9 +153,9 @@ TEST_F(HloInstructionTest, UserWithTwoOperands) { auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo.get(), bar.get()); - ExpectEqOrdered(add->operands(), {foo.get(), bar.get()}); - ExpectEqUnordered(foo->users(), {add.get()}); - ExpectEqUnordered(bar->users(), {add.get()}); + EXPECT_THAT(add->operands(), UnorderedElementsAre(foo.get(), bar.get())); + EXPECT_THAT(foo->users(), UnorderedElementsAre(add.get())); + EXPECT_THAT(bar->users(), UnorderedElementsAre(add.get())); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(add->Accept(&visitor)); @@ -383,12 +388,13 @@ TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) { EXPECT_EQ(1, foo->user_count()); EXPECT_EQ(2, bar->user_count()); - ExpectEqUnordered(foo->users(), {add_foobar.get()}); - ExpectEqOrdered(add_foobar->operands(), {foo.get(), bar.get()}); + EXPECT_THAT(foo->users(), UnorderedElementsAre(add_foobar.get())); + EXPECT_THAT(add_foobar->operands(), ElementsAre(foo.get(), bar.get())); - ExpectEqUnordered(bar->users(), {add_foobar.get(), add_foofoo.get()}); - ExpectEqOrdered(add_foobar->operands(), {foo.get(), bar.get()}); - ExpectEqOrdered(add_foofoo->operands(), {bar.get(), bar.get()}); + EXPECT_THAT(bar->users(), + UnorderedElementsAre(add_foobar.get(), add_foofoo.get())); + EXPECT_THAT(add_foobar->operands(), ElementsAre(foo.get(), bar.get())); + EXPECT_THAT(add_foofoo->operands(), ElementsAre(bar.get(), bar.get())); } TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { @@ -404,16 +410,17 @@ TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { foo.get(), bar.get()); EXPECT_EQ(2, foo->user_count()); - ExpectEqUnordered(foo->users(), {tuple.get(), add_foobar.get()}); + EXPECT_THAT(foo->users(), + UnorderedElementsAre(tuple.get(), add_foobar.get())); // Replace the use of foo in tuple with bar. ASSERT_IS_OK(foo->ReplaceUseWith(tuple.get(), bar.get())); - ExpectEqUnordered(foo->users(), {add_foobar.get()}); + EXPECT_THAT(foo->users(), UnorderedElementsAre(add_foobar.get())); // Both uses of foo in tuple should have been replaced with bar. - ExpectEqOrdered(tuple->operands(), - {bar.get(), bar.get(), baz.get(), bar.get()}); + EXPECT_THAT(tuple->operands(), + ElementsAre(bar.get(), bar.get(), baz.get(), bar.get())); } TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { @@ -426,7 +433,7 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { auto log = HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo.get()); EXPECT_EQ(2, foo->user_count()); - ExpectEqUnordered(foo->users(), {exp.get(), log.get()}); + EXPECT_THAT(foo->users(), UnorderedElementsAre(exp.get(), log.get())); EXPECT_EQ(0, bar->user_count()); // Replace the use of foo in exp with bar. @@ -434,8 +441,8 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { // The use of foo in log should not have been affected. EXPECT_EQ(1, foo->user_count()); - ExpectEqUnordered(foo->users(), {log.get()}); - ExpectEqOrdered(log->operands(), {foo.get()}); + EXPECT_THAT(foo->users(), UnorderedElementsAre(log.get())); + EXPECT_THAT(log->operands(), ElementsAre(foo.get())); // Bar should now be used in exp. EXPECT_EQ(1, bar->user_count()); @@ -466,7 +473,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) { EXPECT_EQ(0, foo->user_count()); EXPECT_EQ(2, bar->user_count()); - ExpectEqUnordered(bar->users(), {add_foobar.get(), add_foofoo.get()}); + EXPECT_THAT(bar->users(), + UnorderedElementsAre(add_foobar.get(), add_foofoo.get())); } TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { @@ -490,7 +498,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { EXPECT_EQ(0, foo->user_count()); EXPECT_EQ(3, bar->user_count()); - ExpectEqUnordered(bar->users(), {add_foobar.get(), exp.get(), tuple.get()}); + EXPECT_THAT(bar->users(), + UnorderedElementsAre(add_foobar.get(), exp.get(), tuple.get())); } // Simple visitor that collects and post-processes each node in the graph. @@ -558,8 +567,8 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { auto fusion = HloInstruction::CreateFusion( r0f32_, HloInstruction::FusionKind::kLoop, exp.get()); - ExpectEqOrdered(fusion->operands(), {constant.get()}); - ExpectEqUnordered(constant->users(), {fusion.get(), exp.get()}); + EXPECT_THAT(fusion->operands(), ElementsAre(constant.get())); + EXPECT_THAT(constant->users(), UnorderedElementsAre(fusion.get(), exp.get())); } TEST_F(HloInstructionTest, BinaryFusionOp) { @@ -574,9 +583,12 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { auto fusion = HloInstruction::CreateFusion( r0f32_, HloInstruction::FusionKind::kLoop, add.get()); - ExpectEqOrdered(fusion->operands(), {constant1.get(), constant2.get()}); - ExpectEqUnordered(constant1->users(), {fusion.get(), add.get()}); - ExpectEqUnordered(constant2->users(), {fusion.get(), add.get()}); + EXPECT_THAT(fusion->operands(), + ElementsAre(constant1.get(), constant2.get())); + EXPECT_THAT(constant1->users(), + UnorderedElementsAre(fusion.get(), add.get())); + EXPECT_THAT(constant2->users(), + UnorderedElementsAre(fusion.get(), add.get())); } TEST_F(HloInstructionTest, ChainFusionOp) { @@ -593,8 +605,28 @@ TEST_F(HloInstructionTest, ChainFusionOp) { fusion->FuseInstruction(exp2.get()); fusion->FuseInstruction(exp1.get()); - ExpectEqOrdered(fusion->operands(), {constant.get()}); - ExpectEqUnordered(constant->users(), {fusion.get(), exp1.get()}); + EXPECT_THAT(fusion->operands(), ElementsAre(constant.get())); + EXPECT_THAT(constant->users(), + UnorderedElementsAre(fusion.get(), exp1.get())); +} + +TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { + // Create a chain of fused unary ops. + auto constant = + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto exp1 = + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); + auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get()); + OpMetadata metadata; + metadata.set_op_name("tf_op"); + exp1->set_metadata(metadata); + exp2->set_metadata(metadata); + + auto fusion = HloInstruction::CreateFusion( + r0f32_, HloInstruction::FusionKind::kLoop, exp2.get()); + auto* fused = fusion->FuseInstruction(exp1.get()); + EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata())); + EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fused->metadata())); } TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { @@ -626,15 +658,15 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { auto fusion = HloInstruction::CreateFusion( scalar_shape, HloInstruction::FusionKind::kLoop, map_3_y.get()); - ASSERT_EQ(fusion->called_computations().size(), 1); - EXPECT_EQ(fusion->called_computations()[0], computation_y.get()); + EXPECT_THAT(fusion->called_computations(), ElementsAre(computation_y.get())); fusion->FuseInstruction(map_2_x.get()); - ASSERT_EQ(fusion->called_computations().size(), 2); - EXPECT_EQ(fusion->called_computations()[1], computation_x.get()); + EXPECT_THAT(fusion->called_computations(), + ElementsAre(computation_y.get(), computation_x.get())); fusion->FuseInstruction(map_1_x.get()); - ASSERT_EQ(fusion->called_computations().size(), 2); + EXPECT_THAT(fusion->called_computations(), + ElementsAre(computation_y.get(), computation_x.get())); } TEST_F(HloInstructionTest, ComplexFusionOp) { @@ -675,8 +707,9 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { // Operands in the fusion instruction's operands() vector should be in the // order in which their users were added fused. - ExpectEqOrdered(fusion->operands(), {c1.get(), c3.get(), c2.get()}); - ExpectEqUnordered(c1->users(), {add.get(), tuple.get(), fusion.get()}); + EXPECT_THAT(fusion->operands(), ElementsAre(c1.get(), c3.get(), c2.get())); + EXPECT_THAT(c1->users(), + UnorderedElementsAre(add.get(), tuple.get(), fusion.get())); } // Convenience function for comparing two HloInstructions inside of @@ -929,5 +962,44 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { root2->operand(1)->operand(0)->shape())); } +TEST_F(HloInstructionTest, CloneSuffixNames) { + // Test that the suffix string added to cloned instructions is not + // duplicated. Rather a numeric incrementing value should be appended. That + // is, we want "foo.clone2", not "foo.clone.clone". + + // Test cloning the same instruction multiple times. + auto foo = + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "foo"); + EXPECT_EQ(foo->Clone()->name(), "%foo.clone"); + EXPECT_EQ(foo->Clone()->Clone()->name(), "%foo.clone2"); + EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "%foo.clone3"); + + // Test custom suffixes. + EXPECT_EQ(foo->Clone("bar")->name(), "%foo.bar"); + EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "%foo.bar2"); + EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), + "%foo.bar2.clone"); + + // Test instruction name with a dot. + auto foo_baz = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "foo.baz"); + EXPECT_EQ(foo_baz->Clone()->name(), "%foo.baz.clone"); + + // Test incrementing a large number after the suffix. + auto foo_clone234 = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "foo.clone234"); + EXPECT_EQ(foo_clone234->Clone()->name(), "%foo.clone235"); + + // Test a non-numeric string after the cloning suffix. + auto foo_clonexyz = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "foo.clonexyz"); + EXPECT_EQ(foo_clonexyz->Clone()->name(), "%foo.clonexyz.clone"); + + // Test a name with multiple appearances of the suffix. + auto foo_clone_clone3 = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "foo.clone.clone3"); + EXPECT_EQ(foo_clone_clone3->Clone()->name(), "%foo.clone.clone4"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc new file mode 100644 index 0000000000000000000000000000000000000000..e022c4836d87866925ab7e56c2250d87d0f5dfec --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace testing { + +bool HloMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + // These cases are self-explanatory from the printed value. + if (!instruction || instruction->opcode() != opcode_) { + return false; + } + // Special case: no operand matchers means don't verify. + if (operands_.empty()) { + return true; + } + const auto& operands = instruction->operands(); + if (operands.size() != operands_.size()) { + *listener << "has too " + << (operands.size() > operands_.size() ? "many" : "few") + << " operands (got " << operands.size() << ", want " + << operands_.size() << ")"; + return false; + } + for (int index = 0; index < operands.size(); index++) { + ::testing::StringMatchResultListener inner_listener; + if (!operands_[index].MatchAndExplain(operands[index], &inner_listener)) { + if (listener->IsInterested()) { + *listener << "\noperand " << index << ":\n\t" + << operands[index]->ToString() + << "\ndoesn't match expected:\n\t"; + operands_[index].DescribeTo(listener->stream()); + string explanation = inner_listener.str(); + if (!explanation.empty()) { + *listener << ", " << explanation; + } + } + return false; + } + } + return true; +} + +void HloMatcher::DescribeTo(::std::ostream* os) const { + *os << opcode_; + if (!operands_.empty()) { + *os << "("; + for (int i = 0; i < operands_.size(); i++) { + if (i > 0) { + *os << ", "; + } + operands_[i].DescribeTo(os); + } + *os << ")"; + } +} + +} // namespace testing +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h new file mode 100644 index 0000000000000000000000000000000000000000..141251011cc0b4205b6069ff90415492ead9f7a9 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -0,0 +1,142 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace testing { + +class HloMatcher : public ::testing::MatcherInterface { + public: + HloMatcher(HloOpcode opcode, + std::vector<::testing::Matcher> operands) + : opcode_(opcode), operands_(operands) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + + void DescribeTo(::std::ostream* os) const override; + + private: + HloOpcode opcode_; + std::vector<::testing::Matcher> operands_; +}; + +// HloInstruction* matchers for opcode and operands. Example: +// namespace op = xla::opcode_matchers; +// EXPECT_THAT(instruction, +// op::Add(op::Reshape(), op::Add(op::Reshape(), _))); +namespace opcode_matchers { +#define HLO_MATCHER(opcode) \ + template \ + ::testing::Matcher opcode(M... operands) { \ + return ::testing::MakeMatcher(new ::xla::testing::HloMatcher( \ + ::xla::HloOpcode::k##opcode, {operands...})); \ + } +HLO_MATCHER(Abs); +HLO_MATCHER(Add); +HLO_MATCHER(Bitcast); +HLO_MATCHER(Broadcast); +HLO_MATCHER(Call); +HLO_MATCHER(Ceil); +HLO_MATCHER(Clamp); +HLO_MATCHER(Concatenate); +HLO_MATCHER(Constant); +HLO_MATCHER(Convert); +HLO_MATCHER(Convolution); +HLO_MATCHER(Copy); +HLO_MATCHER(CrossReplicaSum); +HLO_MATCHER(CustomCall); +HLO_MATCHER(Divide); +HLO_MATCHER(Dot); +HLO_MATCHER(DynamicSlice); +HLO_MATCHER(DynamicUpdateSlice); +HLO_MATCHER(Eq); +HLO_MATCHER(Exp); +HLO_MATCHER(Floor); +HLO_MATCHER(Fusion); +HLO_MATCHER(Ge); +HLO_MATCHER(GetTupleElement); +HLO_MATCHER(Gt); +HLO_MATCHER(Index); +HLO_MATCHER(Infeed); +HLO_MATCHER(IsFinite); +HLO_MATCHER(Le); +HLO_MATCHER(Log); +HLO_MATCHER(LogicalAnd); +HLO_MATCHER(LogicalNot); +HLO_MATCHER(LogicalOr); +HLO_MATCHER(Lt); +HLO_MATCHER(Map); +HLO_MATCHER(Maximum); +HLO_MATCHER(Minimum); +HLO_MATCHER(Multiply); +HLO_MATCHER(Ne); +HLO_MATCHER(Negate); +HLO_MATCHER(Outfeed); +HLO_MATCHER(Pad); +HLO_MATCHER(Parameter); +HLO_MATCHER(Power); +HLO_MATCHER(Recv); +HLO_MATCHER(Reduce); +HLO_MATCHER(ReduceWindow); +HLO_MATCHER(Remainder); +HLO_MATCHER(Reshape); +HLO_MATCHER(Reverse); +HLO_MATCHER(Rng); +HLO_MATCHER(Select); +HLO_MATCHER(SelectAndScatter); +HLO_MATCHER(Send); +HLO_MATCHER(Sign); +HLO_MATCHER(Slice); +HLO_MATCHER(Sort); +HLO_MATCHER(Subtract); +HLO_MATCHER(Tanh); +HLO_MATCHER(Trace); +HLO_MATCHER(Transpose); +HLO_MATCHER(Tuple); +HLO_MATCHER(Update); +HLO_MATCHER(While); +#undef HLO_MATCHER +} // namespace opcode_matchers + +// Helper to convert smart to raw pointers for matching. +template +std::vector Pointers(const Container& container) { + std::vector result; + result.reserve(container.size()); + for (const auto& entry : container) result.push_back(entry.get()); + return result; +} + +} // namespace testing + +// Tell GMock to print HloInstruction* by value, so error messages are nice. +// Has to be in the same namespace as 'HloInstruction'. +void PrintTo(const HloInstruction* inst, ::std::ostream* os) { + *os << (inst ? inst->ToString() : "nullptr"); +} + +void PrintTo(HloInstruction* inst, ::std::ostream* os) { + PrintTo(const_cast(inst), os); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1465d1cacdc971a04c620bc48bed33239a67a955 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; +using ::testing::Eq; + +namespace xla { +namespace { + +template +string Explain(const T& t, const M& m) { + ::testing::StringMatchResultListener listener; + EXPECT_THAT(t, ::testing::Not(m)); // For the error message. + EXPECT_FALSE(m.MatchAndExplain(t, &listener)); + return listener.str(); +} + +TEST(HloMatchersTest, Test) { + auto shape = ShapeUtil::MakeShape(F32, {1}); + auto param = HloInstruction::CreateParameter(0, shape, "param"); + auto mul = HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, + param.get(), param.get()); + auto add = HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param.get(), + mul.get()); + + EXPECT_THAT(add.get(), op::Add()); + EXPECT_THAT(add.get(), op::Add(op::Parameter(), op::Multiply())); + EXPECT_THAT(add.get(), + op::Add(op::Parameter(), op::Multiply(_, op::Parameter()))); + + // Negative matches: check the explanation string. + EXPECT_THAT(Explain(add.get(), op::Parameter()), Eq("")); + EXPECT_THAT(Explain(add.get(), op::Add(op::Parameter())), + Eq("has too many operands (got 2, want 1)")); + EXPECT_THAT( + Explain(add.get(), op::Add(op::Parameter(), op::Parameter())), + Eq("\noperand 1:\n\t" + "%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n" + "doesn't match expected:\n\t" + "parameter")); + EXPECT_THAT( + Explain(add.get(), + op::Add(op::Parameter(), op::Multiply(op::Add(), op::Add()))), + Eq("\noperand 1:\n\t" + "%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n" + "doesn't match expected:\n\t" + "multiply(add, add), \n" + "operand 0:\n\t" + "%param = f32[1]{0} parameter(0)\n" + "doesn't match expected:\n\t" + "add")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 36064e93fe8a750d05183d78738e92768506a835..cff9a6658d73dde6fbf0c754eb6df7cc7e9d6d16 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -31,20 +32,53 @@ limitations under the License. namespace xla { -HloComputation* HloModule::AddEntryComputation( +HloModule::HloModule(const string& name, + const VersionedComputationHandle& entry_computation_handle) + : name_(name), + config_(nullptr), + entry_computation_(nullptr), + has_entry_computation_handle_(true), + entry_computation_handle_(entry_computation_handle), + computation_name_uniquer_(/*separator=*/".") {} + +HloModule::HloModule(const string& name, + const VersionedComputationHandle& entry_computation_handle, + const HloModuleConfig& config) + : name_(name), + config_(MakeUnique(config)), + entry_computation_(nullptr), + has_entry_computation_handle_(true), + entry_computation_handle_(entry_computation_handle), + computation_name_uniquer_(/*separator=*/".") {} + +HloModule::HloModule(const string& name) + : name_(name), + config_(nullptr), + entry_computation_(nullptr), + computation_name_uniquer_(/*separator=*/".") {} + +void HloModule::set_config(const HloModuleConfig& config) { + config_ = MakeUnique(config); +} + +HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation) { - CHECK_EQ(nullptr, entry_computation_); - entry_computation_ = computation.get(); + computation->UniquifyName(&computation_name_uniquer_); computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); } +HloComputation* HloModule::AddEntryComputation( + std::unique_ptr computation) { + CHECK_EQ(nullptr, entry_computation_); + entry_computation_ = computation.get(); + return AddComputationInternal(std::move(computation)); +} + HloComputation* HloModule::AddEmbeddedComputation( std::unique_ptr computation) { - computation->set_parent(this); - computations_.push_back(std::move(computation)); - return computations_.back().get(); + return AddComputationInternal(std::move(computation)); } void HloModule::ReplaceComputations( diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index d598750da657ab3d72c6c8689b6642ea5d7e602c..3efb9c72bb16249fbac5d7b84908305d003b31b4 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -41,19 +43,18 @@ namespace xla { // computations are owned by the module. class HloModule { public: - explicit HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle) - : name_(name), - entry_computation_(nullptr), - has_entry_computation_handle_(true), - entry_computation_handle_(entry_computation_handle) {} + HloModule(const string& name, + const VersionedComputationHandle& entry_computation_handle); + + HloModule(const string& name, + const VersionedComputationHandle& entry_computation_handle, + const HloModuleConfig& config); // Constructor without a versioned computation handle. This constructor should // only be used for HloModules used outside of the XLA service (eg // tests). The versioned handle is used by the service in the compilation // cache. - explicit HloModule(const string& name) - : name_(name), entry_computation_(nullptr) {} + explicit HloModule(const string& name); // Adds an entry computation to the module. A module can only have one entry // computation. Returns a pointer to the newly added computation. @@ -95,6 +96,14 @@ class HloModule { // computation B, then A will appear after B in the sort. std::list MakeComputationPostOrder() const; + bool has_config() const { return config_ != nullptr; } + + void set_config(const HloModuleConfig& config); + + const HloModuleConfig& config() const { return *config_; } + + HloModuleConfig* mutable_config() { return config_.get(); } + string ToString() const; // Outlines the given expression from the given computation. @@ -110,8 +119,17 @@ class HloModule { // Returns a randomly generated uint64. uint64 RandomNew64() const; + // Returns the unique name for a computation in this module. + string GetUniqueCompuationName(const string& prefix) { + return computation_name_uniquer_.GetUniqueName(prefix); + } + private: + HloComputation* AddComputationInternal( + std::unique_ptr computation); + const string name_; + std::unique_ptr config_; HloComputation* entry_computation_; std::vector> computations_; @@ -125,6 +143,9 @@ class HloModule { // Versioned handle of the entry computation of the module. bool has_entry_computation_handle_ = false; VersionedComputationHandle entry_computation_handle_; + + // Unique name generator for computation names, which are unique per module. + NameUniquer computation_name_uniquer_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 0f4252522d3c021ee5d95d1713167b2fb0fb1d69..1175be4f5082401483767ba02b83a8cec68605dd 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -61,7 +61,8 @@ TEST_F(HloModuleTest, OneComputationPostOrder) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(CreateConstantComputation()); - EXPECT_EQ(module->MakeComputationPostOrder().front(), computation); + EXPECT_THAT(module->MakeComputationPostOrder(), + ::testing::ElementsAre(computation)); } TEST_F(HloModuleTest, TwoComputationsPostOrder) { @@ -71,9 +72,13 @@ TEST_F(HloModuleTest, TwoComputationsPostOrder) { auto computation2 = module->AddEmbeddedComputation(CreateConstantComputation()); - EXPECT_MATCH( - testing::ListToVec(module->MakeComputationPostOrder()), - testing::UnorderedMatcher(computation1, computation2)); + EXPECT_THAT(module->MakeComputationPostOrder(), + ::testing::UnorderedElementsAre(computation1, computation2)); + + // We specified the same name for both computations, but the HloModule should + // have made the names unique. + EXPECT_EQ(computation1->name(), "Constant"); + EXPECT_EQ(computation2->name(), "Constant.1"); } TEST_F(HloModuleTest, DiamondComputationsPostOrder) { @@ -89,9 +94,9 @@ TEST_F(HloModuleTest, DiamondComputationsPostOrder) { CreateCallComputation({computation2, computation3})); auto post_order = module->MakeComputationPostOrder(); - EXPECT_MATCH(testing::ListToVec(post_order), - testing::UnorderedMatcher( - computation1, computation2, computation3, computation4)); + EXPECT_THAT(post_order, + ::testing::UnorderedElementsAre(computation1, computation2, + computation3, computation4)); EXPECT_EQ(post_order.back(), computation4); EXPECT_EQ(post_order.front(), computation1); } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 616b239a9310bc13e14c861184b7efebe7da6b2f..ceb0cdaa3169bb57e4ebb61ac1b2ea41f1ef7995 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -165,4 +165,17 @@ bool HloOpcodeIsComparison(HloOpcode opcode) { } } +bool HloOpcodeIsVariadic(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kCall: + case HloOpcode::kConcatenate: + case HloOpcode::kFusion: + case HloOpcode::kMap: + case HloOpcode::kTuple: + return true; + default: + return false; + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 978ed5e79b90c3c12f31b4d4e3d3314849fed75c..e2cdbfdfa7a4b5509dccf9a83ffbd799f9ab1374 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -104,6 +104,9 @@ inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) { // Returns true iff the given opcode is a comparison operation. bool HloOpcodeIsComparison(HloOpcode opcode); +// Returns true iff the given opcode has variadic operands. +bool HloOpcodeIsVariadic(HloOpcode opcode); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 0b64c16fdc6639a0288b4a69698a600b09ba32f7..892c89f9df209f2e39005a4901feae6699ce4d0b 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index b3168ed40ece3ea65c6b26b96250f2ea77969953..d1ef8cb6918d02287912b76b213ed2acd7940d76 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -34,15 +34,95 @@ limitations under the License. namespace xla { -PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) - : module_(module) {} +namespace { + +// Returns the nearest call graph ancestors of instructions 'a' and 'b' for +// which the ancestors are in the same computation. An instruction is an call +// graph ancestor of 'a' if the instruction calls the computation containing 'a' +// either directly or transitively. Degeneratively an instruction is an ancestor +// of itself. nullptr is returned if there is no common ancestor or if the +// caller chain of 'a' or 'b' diverges (has multiple callers) before the nearest +// common ancestor. +// +// Example: +// +// Entry computation: +// %x = Call(A, {Constant(42.0)}) +// %y = Call(B, {%x}) +// +// Computation A: +// %a = Negate(Param()) +// +// Computation B: +// %b = Exp(Param()); +// +// If called with %a and %b, this function would return (%x, %y). %x is an +// ancestor of %a, and %y is an ancestor of %b, and %x and %y are in the same +// computation. +std::pair +GetNearestCallGraphAncestorsInSameComputation(const HloInstruction* a, + const HloInstruction* b, + const CallGraph& call_graph) { + // Lambda which returns the next instruction in the callee->caller chain in + // the call graph. This is the unique instruction which calls the computation + // containing 'instruction'. If more than one instruction calls the + // computation containing 'instruction' or no instructions call the + // computation then nullptr is returned. + auto next_caller = + [&call_graph]( + const HloInstruction* instruction) -> const HloInstruction* { + const CallGraphNode& node = call_graph.GetNode(instruction->parent()); + if (node.caller_callsites().size() != 1) { + return nullptr; + } + return node.caller_callsites()[0].instruction(); + }; + + // Iterate through the callee->caller chains and find the earliest common + // element. + for (const HloInstruction* a_ancestor = a; a_ancestor != nullptr; + a_ancestor = next_caller(a_ancestor)) { + for (const HloInstruction* b_ancestor = b; b_ancestor != nullptr; + b_ancestor = next_caller(b_ancestor)) { + if (a_ancestor->parent() == b_ancestor->parent()) { + return {a_ancestor, b_ancestor}; + } + } + } + return {nullptr, nullptr}; +} + +} // namespace -bool PredecessorHloOrdering::ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const { - // Instructions in different computations are unordered. - if (a->parent() != b->parent()) { +bool HloOrdering::ExecutesBefore(const HloInstruction* a, + const HloInstruction* b) const { + // 'a' and 'b' may be in different computations. In this case, find the + // callgraph ancestor instructions which call (potentially transitively) the + // computations containing 'a' and 'b' and use these ancestor instructions to + // compare order. + const HloInstruction* a_ancestor; + const HloInstruction* b_ancestor; + std::tie(a_ancestor, b_ancestor) = + GetNearestCallGraphAncestorsInSameComputation(a, b, *call_graph_); + + if (a_ancestor == nullptr) { + // Ancestors in a common computation could not be found so consider the + // instructions 'a' and 'b' to be unordered. return false; } + // a_ancestor and b_ancestor must be either both null or both non-null. + CHECK_NE(b_ancestor, nullptr); + CHECK_EQ(a_ancestor->parent(), b_ancestor->parent()); + return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor); +} + +PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) + : HloOrdering(module) {} + +bool PredecessorHloOrdering::ExecutesBeforeInSameComputation( + const HloInstruction* a, const HloInstruction* b) const { + CHECK_EQ(a->parent(), b->parent()); + // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'. return strict_predecessors_.at(b->parent())->IsReachable(b, a); } @@ -86,7 +166,7 @@ string DependencyHloOrdering::ToString() const { SequentialHloOrdering::SequentialHloOrdering( const HloModule* module, const HloModuleSequence& module_sequence) - : module_(module), module_sequence_(module_sequence) { + : HloOrdering(module), module_sequence_(module_sequence) { // Create a map from instruction to its order position. for (auto computation_order : module_sequence_) { const std::vector& order = computation_order.second; @@ -97,12 +177,9 @@ SequentialHloOrdering::SequentialHloOrdering( } } -bool SequentialHloOrdering::ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const { - // Instructions in different computations are unordered. - if (a->parent() != b->parent()) { - return false; - } +bool SequentialHloOrdering::ExecutesBeforeInSameComputation( + const HloInstruction* a, const HloInstruction* b) const { + CHECK_EQ(a->parent(), b->parent()); // If either instruction is not in the order, then 'a' and 'b' are unordered. if (order_position_.count(a) == 0 || order_position_.count(b) == 0) { return false; @@ -144,23 +221,6 @@ string SequentialHloOrdering::ToString() const { return tensorflow::str_util::Join(pieces, "\n"); } -namespace { -StatusOr MinimumMemoryForSequence( - const HloComputation& computation, - const std::vector& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - // The absolute minimum memory required for a given sequence of instructions - // is determined by the sequence of Alloc and Free calls on a simulated heap, - // ignoring fragmentation. - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), sequence, - computation, points_to_analysis, size_function)); - return result.heap_size; -} -} // namespace - StatusOr MinimumMemoryForSequence( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function) { @@ -172,17 +232,16 @@ StatusOr MinimumMemoryForSequence( TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(module)); - int64 total_memory = 0; - for (const auto& pair : module_sequence) { - const HloComputation* computation = pair.first; - const std::vector& sequence = pair.second; - TF_ASSIGN_OR_RETURN( - const int64 memory, - MinimumMemoryForSequence(*computation, sequence, *points_to_analysis, - size_function)); - total_memory += memory; - } - return total_memory; + // The absolute minimum memory required for a given sequence of instructions + // is determined by the sequence of Alloc and Free calls on a simulated heap, + // ignoring fragmentation. We run the heap simulation on the whole module, + // rather than summing each computation, since it gives us a better lower + // bound, by minimizing the liveness of sub-computations. + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), *module, + module_sequence, *points_to_analysis, size_function)); + return result.heap_size; } namespace { @@ -284,7 +343,7 @@ class ListScheduler { return freed_bytes; } - // Construct the scheduling priority of the given instruciton. + // Construct the scheduling priority of the given instruction. Priority GetPriority(const HloInstruction* instruction) { return {BytesFreedIfScheduled(instruction), instruction->user_count()}; } @@ -439,6 +498,18 @@ StatusOr> RunDFSMemoryScheduler( return sequence; } +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), computation, + sequence, points_to_analysis, size_function)); + return result.heap_size; +} + StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, @@ -446,13 +517,17 @@ StatusOr> CreateMemoryMinimizingSequence( // We try both a list-scheduler based ordering and a DFS based ordering, and // choose whichever returns a lower min-memory, not accounting for // fragmentation. + // + // Note that this is just a heuristic. One obvious inaccuracy is that the + // memory required for sub-computations might be different when considered + // within the caller's context. But it's good enough for now. TF_ASSIGN_OR_RETURN( std::vector list_sequence, ListScheduler::Run(computation, points_to_analysis, size_function)); TF_ASSIGN_OR_RETURN( const int64 list_memory, - MinimumMemoryForSequence(computation, list_sequence, points_to_analysis, - size_function)); + MinimumMemoryForComputation(computation, list_sequence, + points_to_analysis, size_function)); VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; TF_ASSIGN_OR_RETURN( @@ -460,8 +535,8 @@ StatusOr> CreateMemoryMinimizingSequence( RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); TF_ASSIGN_OR_RETURN( const int64 dfs_memory, - MinimumMemoryForSequence(computation, dfs_sequence, points_to_analysis, - size_function)); + MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, + size_function)); VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; if (list_memory <= dfs_memory) { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index e964c4c51ae14f89d1f1b0450990cfc50c8a74be..d2db18be0009b1ca62b538d3975e1a0a105c5e83 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -36,13 +37,13 @@ namespace xla { // buffers. class HloOrdering { public: - HloOrdering() = default; + HloOrdering(const HloModule* module) + : module_(module), call_graph_(CallGraph::Build(module)) {} virtual ~HloOrdering() = default; // Returns true if instruction 'a' executes before instruction 'b'. This is // not reflexive, that is, an instruction does not execute before itself. - virtual bool ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const = 0; + bool ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const; // Returns the sequential instruction order for the given computation, or // nullptr if the computation does not have a sequential ordering. @@ -50,6 +51,21 @@ class HloOrdering { const HloComputation& computation) const = 0; virtual string ToString() const = 0; + + protected: + // Returns true if instruction 'a' executes before instruction 'b'. + // Precondition: 'a' and 'b' are in the same computation. + // + // Derived classes should implement this method for determining order of + // instructions in the same comptuation. ExecutesBefore() analyzes the + // callgraph and uses this method to determine ordering of instructions in + // different computations. + virtual bool ExecutesBeforeInSameComputation( + const HloInstruction* a, const HloInstruction* b) const = 0; + + const HloModule* module_; + + std::unique_ptr call_graph_; }; // Base class for partial orderings implemented by a map of strict predecessors @@ -58,11 +74,6 @@ class PredecessorHloOrdering : public HloOrdering { public: ~PredecessorHloOrdering() override = default; - // Returns true if instruction 'a' executes before instruction 'b'. - // Instructions in different computations are not ordered. - bool ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const override; - // Returns nullptr indicating the computation does not have a sequential // ordering. const std::vector* SequentialOrder( @@ -74,11 +85,12 @@ class PredecessorHloOrdering : public HloOrdering { explicit PredecessorHloOrdering(const HloModule* module); string ToStringHelper(const string& name) const; - const HloModule* module_; + bool ExecutesBeforeInSameComputation(const HloInstruction* a, + const HloInstruction* b) const override; - // For each each computation in the module, this is the set of the - // instruction's strict predecessors. An instruction is not an element of its - // own strict predecessor set. + // For each computation in the module, this is the set of the instruction's + // strict predecessors. An instruction is not an element of its own strict + // predecessor set. // // Subclasses should fill this in to define the desired ordering. tensorflow::gtl::FlatMap* SequentialOrder( const HloComputation& computation) const override; @@ -163,7 +169,9 @@ class SequentialHloOrdering : public HloOrdering { string ToString() const override; protected: - const HloModule* module_; + bool ExecutesBeforeInSameComputation(const HloInstruction* a, + const HloInstruction* b) const override; + const HloModuleSequence module_sequence_; // The position of every instruction in the HLO module in its respective diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 425bee601a8d6357e21d3d00f8ccf5d69af03862..c387fbb89b196c340852db057754f85e3e5435f3 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -78,6 +78,142 @@ TEST_F(HloOrderingTest, LastUseScheduledFirst) { EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); } +TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { + // Tests the ordering of instructions in different computations using the + // following HLO code: + // + // Entry computation: + // %x = Call(A, {}) + // %y = Call(B, {%x}) + // + // Computation A: + // %a = Call(C, {}) + // + // Computation B: + // %b = Call(C, {}) + // + // Computation C: + // %c = Constant(42.0f) + // + // This results in a diamond-shaped callgraph. + HloModule module(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto builder_c = HloComputation::Builder("C"); + HloInstruction* c = builder_c.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloComputation* computation_c = + module.AddEmbeddedComputation(builder_c.Build()); + + auto builder_b = HloComputation::Builder("B"); + builder_b.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* b = builder_b.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {}, computation_c)); + HloComputation* computation_b = + module.AddEmbeddedComputation(builder_b.Build()); + + auto builder_a = HloComputation::Builder("A"); + HloInstruction* a = builder_a.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {}, computation_c)); + HloComputation* computation_a = + module.AddEmbeddedComputation(builder_a.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* x = builder.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {}, computation_a)); + HloInstruction* y = builder.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {x}, computation_b)); + module.AddEntryComputation(builder.Build()); + + DependencyHloOrdering ordering(&module); + EXPECT_TRUE(ordering.ExecutesBefore(x, y)); + EXPECT_FALSE(ordering.ExecutesBefore(y, x)); + + EXPECT_TRUE(ordering.ExecutesBefore(a, b)); + EXPECT_FALSE(ordering.ExecutesBefore(b, a)); + + EXPECT_FALSE(ordering.ExecutesBefore(a, x)); + EXPECT_TRUE(ordering.ExecutesBefore(a, y)); + EXPECT_FALSE(ordering.ExecutesBefore(x, a)); + EXPECT_FALSE(ordering.ExecutesBefore(y, a)); + + EXPECT_FALSE(ordering.ExecutesBefore(b, x)); + EXPECT_FALSE(ordering.ExecutesBefore(b, y)); + EXPECT_TRUE(ordering.ExecutesBefore(x, b)); + EXPECT_FALSE(ordering.ExecutesBefore(y, b)); + + // Instruction 'c' is called from multiple callsites and should be unordered + // relative to all other instructions in the module. + EXPECT_FALSE(ordering.ExecutesBefore(c, a)); + EXPECT_FALSE(ordering.ExecutesBefore(c, b)); + EXPECT_FALSE(ordering.ExecutesBefore(c, x)); + EXPECT_FALSE(ordering.ExecutesBefore(c, y)); + EXPECT_FALSE(ordering.ExecutesBefore(a, c)); + EXPECT_FALSE(ordering.ExecutesBefore(b, c)); + EXPECT_FALSE(ordering.ExecutesBefore(x, c)); + EXPECT_FALSE(ordering.ExecutesBefore(y, c)); +} + +class MinimumMemoryForSequenceTest : public HloTestBase {}; + +TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { + HloModule module(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = + module.AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + module.AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + // Entry params: 8 bytes (4 bytes per param), TOTAL=8 + HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param_data")); + // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); + // While: 8 bytes (4 bytes per element), TOTAL=32 + // Both cond and body use a max of 24 bytes, TOTAL=56 + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, tuple)); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + + auto size_fn = [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, + cond_lt}; + module_sequence[body_computation] = {body_param}; + module_sequence[entry_computation] = {iter, data, tuple, while_op}; + EXPECT_EQ(56, + MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 6e3c983071245c548914bd9eecd0d1e86bc64d99..78aebe9c36dfb5f63099f5e2df7bffe8529b08de 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -40,11 +40,19 @@ void DumpModule(const Compiler::HloDumper& dumper_, const HloModule& module, } // namespace StatusOr HloPassPipeline::Run(HloModule* module) { + run_called_ = true; + + VLOG(1) << "Running HLO pass pipeline " << name(); + legacy_flags::HloPassPipelineFlags* flags = legacy_flags::GetHloPassPipelineFlags(); std::vector tmp = tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ','); tensorflow::gtl::FlatSet disabled_passes(tmp.begin(), tmp.end()); + if (!disabled_passes.empty()) { + VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " + << tensorflow::str_util::Join(disabled_passes, ", "); + } auto run_invariant_checkers = [this, module]() -> Status { for (auto& invariant_checker : invariant_checkers_) { @@ -60,9 +68,13 @@ StatusOr HloPassPipeline::Run(HloModule* module) { for (auto& pass : passes_) { if (!disabled_passes.empty() && disabled_passes.count(pass->name().ToString()) > 0) { + VLOG(1) << " Skipping HLO pass " << pass->name() + << ", disabled by --xla_disable_hlo_passes"; continue; } + VLOG(1) << " HLO pass " << pass->name(); + // Emit label containing: "after foo-pass, before bar-pass". message.clear(); StrAppend(&message, prefix, ", before ", pass->name()); diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index a8c2d518730b9fab8febaae35797ea4a315ab9b1..682c4b952df6aae8cb933c222772dbd823070ecc 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -47,6 +47,7 @@ class HloPassPipeline : public HloPassInterface { // Returns a reference to the added pass. template T& AddPass(Args&&... args) { + CHECK(!run_called_) << "AddPass cannot be called after Run"; auto pass = new T(std::forward(args)...); passes_.push_back(std::unique_ptr(pass)); return *pass; @@ -57,6 +58,7 @@ class HloPassPipeline : public HloPassInterface { // (it is required to always return "false" from its Run() method). template T& AddInvariantChecker(Args&&... args) { + CHECK(!run_called_) << "AddInvariantChecker cannot be called after Run"; auto pass = new T(std::forward(args)...); invariant_checkers_.push_back(std::unique_ptr(pass)); return *pass; @@ -70,6 +72,7 @@ class HloPassPipeline : public HloPassInterface { Compiler::HloDumper dumper_; std::vector> passes_; std::vector> invariant_checkers_; + bool run_called_ = false; TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline); }; diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index 1556d1772f934ea02506aff27396034814d61698..a153d73dbd838663c0d7e0d72ad54668f243f2c2 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -32,6 +32,16 @@ bool IsConstantR0F32(HloInstruction* instruction, float* out) { return false; } +bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction) { + for (const auto& operand : instruction.operands()) { + if (operand->opcode() != HloOpcode::kParameter && + operand->opcode() != HloOpcode::kConstant) { + return false; + } + } + return true; +} + bool AllOperandsAreParameters(const HloInstruction& instruction) { for (const auto& operand : instruction.operands()) { if (operand->opcode() != HloOpcode::kParameter) { @@ -41,6 +51,15 @@ bool AllOperandsAreParameters(const HloInstruction& instruction) { return true; } +bool AllOperandsAreConstants(const HloInstruction& instruction) { + for (const auto& operand : instruction.operands()) { + if (operand->opcode() != HloOpcode::kConstant) { + return false; + } + } + return true; +} + HloInstruction* GetMatchingOperand( std::function matcher, HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h index 864f892e92047e6f39b2949854190522b2f4a906..c79347bbf9d6146943b7b787f713369cb37fadee 100644 --- a/tensorflow/compiler/xla/service/hlo_query.h +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -28,9 +28,16 @@ namespace hlo_query { // Precondition: out != nullptr bool IsConstantR0F32(HloInstruction* instruction, float* out); +// Returns whether all of an instruction's operands are of the types constants +// and parameters. +bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction); + // Returns whether all of an instruction's operands are parameters. bool AllOperandsAreParameters(const HloInstruction& instruction); +// Returns whether all of an instruction's operands are constants. +bool AllOperandsAreConstants(const HloInstruction& instruction); + // Returns whether the instruction is a scalar constant. bool IsScalarConstant(const HloInstruction* instruction); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 52a0181029ddb7eb373bb9e9f91e2899c3140c71..5d4fd7c2deae7e1b03f49f123e2aff174ab34667 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -22,14 +22,15 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -45,63 +46,58 @@ namespace xla { namespace { -// Returns a vector of the operands of 'instruction' with repeated elements -// removed. -std::vector UniqueOperands(const HloInstruction* instruction) { - std::vector unique_operands; - for (HloInstruction* operand : instruction->operands()) { - if (std::find(unique_operands.begin(), unique_operands.end(), operand) == - unique_operands.end()) { - unique_operands.push_back(operand); - } - } - return unique_operands; -} - // Returns true if the given instruction is rematerializable. bool IsRematerializable(const HloInstruction* instruction) { + // Conservatively, don't rematerialize instruction with control + // dependencies. For one, control dependencies are added to prevent + // interference of aliased buffers (say, in while bodies) and + // rematerialization is ignorant of liveness and may break the intended + // ordering. + if (!instruction->control_predecessors().empty() || + !instruction->control_successors().empty()) { + return false; + } + // Don't rematerialize instructions with side effects, those with a cost that // might not be captured by HloCostAnalysis, or instructions which cannot be // cloned safely. switch (instruction->opcode()) { case HloOpcode::kCall: + case HloOpcode::kConstant: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: case HloOpcode::kOutfeed: case HloOpcode::kInfeed: + case HloOpcode::kParameter: case HloOpcode::kRecv: case HloOpcode::kSend: case HloOpcode::kTrace: case HloOpcode::kWhile: return false; default: - break; - } - - // Skip tuple shapes because we do not currently account for buffer aliasing - // properly which results in improperly accounting of rematerialization cost - // for these shapes. - if (ShapeUtil::IsTuple(instruction->shape())) { - return false; - } - for (auto* operand : instruction->operands()) { - if (ShapeUtil::IsTuple(operand->shape())) { - return false; - } + return true; } - - return true; } -// Class which maintains an ordered list of instructions with fast insertion and -// removal of arbitrary elements. +// Class which maintains an ordered list of instructions with fast insertion +// before arbitrary elements. class InstructionList { public: explicit InstructionList(const std::vector order) { + int64 position = 0; for (const HloInstruction* inst : order) { instructions_.push_back(const_cast(inst)); instruction_iterators_.insert({const_cast(inst), std::next(instructions_.end(), -1)}); + // Initially position numbers are uniquely assigned in order. Later as + // instructions are added with InsertBefore* methods, some instructions + // may have duplicate position numbers, but the values will be guaranteed + // to be monotonically increasing through the list, and so is still useful + // for quickly(-ish) determining the order of arbitrary instructions in + // the list. + position_number_[inst] = position; + first_at_position_[position] = inst; + position++; } } @@ -110,22 +106,63 @@ class InstructionList { return instructions_; } - // Insert instruction 'to_insert' before instruction 'before' in the list. - Status InsertBefore(HloInstruction* to_insert, HloInstruction* before) { + // Insert instruction 'to_insert' immediately before instruction 'before' in + // the list. + void InsertBefore(HloInstruction* to_insert, HloInstruction* before) { + VLOG(3) << "InsertBefore: " << to_insert->name() << " before " + << before->name(); auto it = instruction_iterators_.find(before); - TF_RET_CHECK(it != instruction_iterators_.end()); + CHECK(it != instruction_iterators_.end()); instruction_iterators_.insert( {to_insert, instructions_.insert(it->second, to_insert)}); - return Status::OK(); + // Assign the same position number to the newly added instruction as + // 'before'. This guarantees monotonicity of the position numbers, but not + // uniqueness. + int64 pos = position_number_.at(before); + position_number_[to_insert] = pos; + if (first_at_position_.at(pos) == before) { + first_at_position_[pos] = to_insert; + } } - // Removes instruction from the list. - Status Remove(HloInstruction* instruction) { - auto it = instruction_iterators_.find(instruction); - TF_RET_CHECK(it != instruction_iterators_.end()); - instructions_.erase(it->second); - instruction_iterators_.erase(it); - return Status::OK(); + // Insert instruction 'to_insert' immediately before the earliest instruction + // in 'before_instructions'. + void InsertBeforeInstructions( + HloInstruction* to_insert, + tensorflow::gtl::ArraySlice before_instructions) { + VLOG(3) << "InsertBeforeInstructions: " << to_insert->name() << " before {" + << tensorflow::str_util::Join( + before_instructions, ", ", + [](string* out, HloInstruction* inst) { + tensorflow::strings::StrAppend(out, inst->name()); + }) + << "}"; + + // Find the minimal position number of any instruction in + // 'before_instructions'. + CHECK(!before_instructions.empty()); + int64 min_position_number = std::numeric_limits::max(); + for (const HloInstruction* instruction : before_instructions) { + min_position_number = + std::min(min_position_number, position_number_.at(instruction)); + } + + // Because more than one instruction in 'before_instructions' may have a + // position number of 'min_position_number', find the first such instruction + // with position number 'min_position_number'. + for (auto it = instruction_iterators_.at( + first_at_position_.at(min_position_number)); + it != instructions_.end() && + position_number_.at(*it) == min_position_number; + ++it) { + if (std::find(before_instructions.begin(), before_instructions.end(), + *it) != before_instructions.end()) { + return InsertBefore(to_insert, *it); + } + } + LOG(FATAL) << "Expected to find instruction in before_instructions with " + "position number " + << min_position_number; } private: @@ -136,283 +173,630 @@ class InstructionList { tensorflow::gtl::FlatMap::iterator> instruction_iterators_; + + // A number assigned to each instruction which increases monotonically through + // 'instructions_'. Used to facilitate fast insertion of an instruction before + // the earliest instruction in a set of instructions + // (InsertBeforeInstructions) by enabling fast-ish ordering queries between + // instructions. If position_number_[a] < position_number_[b] then 'a' comes + // before 'b' in the list. If the position numbers are the same then nothing + // can be said about their order without examining the list. + // + // On object construction this value is precisely the instruction's ordinal + // position in the list. Instructions inserted via InsertBefore receive + // duplicate values. However, monotonicity is preserved. + tensorflow::gtl::FlatMap position_number_; + + // The first instruction in the list assigned a particular position number. + tensorflow::gtl::FlatMap first_at_position_; }; +// Return the HloInstructions which use the given LogicalBuffer. Sets +// has_indirect_users to whether any of the uses is indirect. A use is indirect +// if the instruction defining logical_buffer is not an operand of the use. This +// can happen via buffer aliasing (eg, tuples). +std::vector GetUsers( + const LogicalBuffer* logical_buffer, + const TuplePointsToAnalysis& points_to_analysis, bool* has_indirect_users) { + std::vector users; + // To identify uses iterate through all HloInstruction users of the + // BufferAliases of the logical buffer. + *has_indirect_users = false; + for (const BufferAlias& buffer_alias : + points_to_analysis.GetBufferAliases(*logical_buffer)) { + for (const HloInstruction* user : buffer_alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(buffer_alias.instruction(), + buffer_alias.index(), user, + points_to_analysis)) { + // The alias may be an operand of 'user', but the LogicalBuffer cannot + // possibly be used by the instruction so ignore 'user'. This is the + // case, for example, for the tuple element buffers in a GetTupleElement + // instruction (the GTE instruction only uses the pointer vector). + continue; + } + if (buffer_alias.instruction() != logical_buffer->instruction()) { + *has_indirect_users = true; + } + // A buffer may be used by the instruction via more than one alias. For + // example, a buffer which appears in more than one element of a tuple. + if (std::find(users.begin(), users.end(), user) == users.end()) { + users.push_back(user); + } + } + } + return users; +} + // Class for tracking memory usage of a computation as the instructions are -// placed sequentially. Memory usage is the sum of live values at the current -// point in the instruction sequence. +// placed sequentially. Memory usage is the sum of the sizes of live values +// (LogicalBuffers) at the current point in the instruction sequence. class MemoryUsageTracker { public: MemoryUsageTracker( const HloComputation* computation, - const HloRematerialization::ShapeSizeFunction& size_function) - : computation_(computation), size_function_(size_function) { - for (const std::unique_ptr& instruction : - computation->instructions()) { - // Initially only live-in values occupy memory. - if (IsLiveIn(instruction.get())) { - memory_usage_ += TotalSizeBytes(instruction->shape()); - } + const HloRematerialization::ShapeSizeFunction& size_function, + const TuplePointsToAnalysis& points_to_analysis, + const InstructionList& instruction_list); + + // Starts the placement of the given instruction. This adds the sizes of the + // LogicalBuffers defined by the instruction to the current memory + // usage. Placement is broken into two steps (BeginInstruction and + // EndInstruction) to accurately model memory usage. At BeginInstruction the + // memory for the output value(s) of the current instruction is allocated. At + // EndInstruction memory for dead operand(s) is freed. + Status BeginInstruction(const HloInstruction* instruction); + + // Finishes the placement of the current instruction. This frees any dead + // operands or dead result of the instruction. This must be called after + // each call to BeginInstruction. + Status EndInstruction(); + + // Returns the number of bytes that the current memory usage will be reduced + // if the given instruction is rematerialized. + int64 MemoryReducedIfRematerialized(const HloInstruction* instruction) const; + + // Adjusts memory usage to account for the rematerialization of + // original_instruction for all remaining unplaced uses. The rematerialization + // is remat_instruction. This method should be called after the HLO graph has + // been transformed (rematerialization instruction created and connected to + // uses). + Status AddRematerializedInstruction(HloInstruction* original_instruction, + HloInstruction* remat_instruction); + + // Returns whether the given instruction has been placed (BeginInstruction + // has been called with 'instruction' as the argument). + bool IsPlaced(const HloInstruction* instruction) const { + return ContainsKey(placed_instructions_, instruction); + } + + // Returns the current memory usage. This is the sum of sizes of all live + // values. + int64 memory_usage() const { return memory_usage_; } + + // Returns the current instruction being placed. + const HloInstruction* in_progress_instruction() const { + return in_progress_instruction_; + } + + // Check invariants of the data structure. This is expensive to call. + bool Check() const; + + string ToString() const; + + private: + // Type holding a unique identifier for each Buffer object. + using BufferId = int64; + + // A Buffer represents a single LogicalBuffer in the computation including + // various metadata useful for tracking liveness of the value. A LogicalBuffer + // is not used directly because the HLO graph is transformed and + // TuplePointsToAnalysis which owns all LogicalBuffers cannot be updated after + // HLO graph transformations. + struct Buffer { + // The unique id of this Buffer. This value is equal to the buffer's index + // in the vector buffers_. + const BufferId id; + + // The instruction which defines this buffer. + const HloInstruction* defining_instruction; + + // The materialized size of the buffer in bytes. + const int64 size; + + // Whether this buffer is live-out of the computation. + bool live_out; + + // Whether this buffer has indirect uses. Ie, an instruction which is not a + // user of defining_instruction uses this buffer. This can occur due to + // buffer aliasing (eg, tuples). + bool has_indirect_uses; + + // The instructions which use this buffer. + std::vector users; + + // The number of users (HloInstructions) of this buffer which have not yet + // been placed in the sequence. + int64 unfinished_user_count; + + string ToString() const { + return tensorflow::strings::StrCat("Buffer ", id, " (defined by ", + defining_instruction->name(), + ", size ", size, " bytes)"); } + }; + + // Creates a Buffer representing the given logical buffer. The buffer is added + // to buffers_ and a reference is returned. + Buffer& CreateBufferFromLogicalBuffer( + const LogicalBuffer* logical_buffer, + const TuplePointsToAnalysis& points_to_analysis, + const HloRematerialization::ShapeSizeFunction& size_function, + bool live_out) { + bool has_indirect_uses = false; + std::vector users = + GetUsers(logical_buffer, points_to_analysis, &has_indirect_uses); + return NewBuffer(logical_buffer->instruction(), + size_function(logical_buffer->shape()), std::move(users), + live_out, has_indirect_uses); } - // Starts the placement of the given instruction. This adds the output size of - // the instruction to the current memory usage. Placement is broken into two - // steps (BeginInstruction and EndInstruction) to accurately model memory - // usage. At BeginInstruction the memory for the output value of the current - // instruction is allocated. At EndInstruction memory for dead operands is - // freed. - Status BeginInstruction(const HloInstruction* instruction) { - VLOG(3) << "BeginInstruction " << instruction->name(); - TF_RET_CHECK(in_progress_instruction_ == nullptr); - in_progress_instruction_ = instruction; - - // Add instruction to remaining_uses_. - TF_RET_CHECK(!ContainsKey(remaining_uses_, instruction)); - std::vector& instruction_uses = - remaining_uses_[instruction]; - instruction_uses.insert(instruction_uses.begin(), - instruction->users().begin(), - instruction->users().end()); - - if (!IsLiveIn(instruction)) { - // Instruction was not previously live so add output size to memory usage. - memory_usage_ += TotalSizeBytes(instruction->shape()); + // Create a new buffer representing a rematerialization of given buffer for + // the given uses. + Buffer& RematerializeBuffer( + const Buffer& original_buffer, const HloInstruction* remat_instruction, + std::vector&& rematerialized_uses) { + CHECK(IsPlaced(original_buffer.defining_instruction)); + CHECK(!original_buffer.has_indirect_uses); + CHECK(!original_buffer.live_out); + for (const HloInstruction* use : rematerialized_uses) { + CHECK(!IsPlaced(use)); } + return NewBuffer(remat_instruction, original_buffer.size, + std::move(rematerialized_uses), /*live_out=*/false, + /*has_indirect_uses=*/false); + } + + // Return number of bytes allocated for the buffer with the given id. Buffers + // allocated by the calling computation (eg, parameter and output buffers) are + // considered to have zero bytes because the memory is accounted for in a + // different computation. + int64 AllocatedSize(BufferId buffer_id) const { + const Buffer& buffer = buffers_.at(buffer_id); + HloOpcode def_opcode = buffer.defining_instruction->opcode(); + if (buffer.live_out || def_opcode == HloOpcode::kParameter) { + return 0; + } else { + return buffer.size; + } + } - VLOG(3) << " memory usage = " << memory_usage_; - VLOG(10) << ToString(); - return Status::OK(); + // Returns true if BeginInstruction and EndInstruction has been called for the + // given instruction. + bool IsFinished(const HloInstruction* instruction) const { + return IsPlaced(instruction) && instruction != in_progress_instruction_; } - // Finishes the placement of the current instruction. This frees any dead - // operands or dead result of the instruction. This must be called after each - // call to BeginInstruction. - Status EndInstruction() { - TF_RET_CHECK(in_progress_instruction_ != nullptr); - VLOG(3) << "EndInstruction " << in_progress_instruction_->name(); - - for (HloInstruction* operand : UniqueOperands(in_progress_instruction_)) { - TF_RET_CHECK(ContainsKey(remaining_uses_, operand)); - std::vector& uses = remaining_uses_.at(operand); - auto it = std::find(uses.begin(), uses.end(), in_progress_instruction_); - TF_RET_CHECK(it != uses.end()); - uses.erase(it); - - if (uses.empty()) { - // Operand is dead. - int64 operand_size = TotalSizeBytes(operand->shape()); - if (!IsLiveOut(operand)) { - VLOG(4) << operand->name() << " (" - << HumanReadableNumBytes(operand_size) << ") is dead"; - memory_usage_ -= operand_size; - TF_RET_CHECK(memory_usage_ >= 0); + // Returns whether the given buffer is being used by the in-progress + // instruction. + bool IsInUse(BufferId buffer_id) const { + if (in_progress_instruction_ == nullptr) { + return false; + } + const std::vector& in_progress_uses = + buffers_used_by_instruction_.at(in_progress_instruction_); + return std::find(in_progress_uses.begin(), in_progress_uses.end(), + buffer_id) != in_progress_uses.end(); + } + + // Returns whether the given instruction is live at the current program + // point. + bool IsCurrentlyLive(BufferId buffer_id) const { + const Buffer& buffer = buffers_[buffer_id]; + return (IsPlaced(buffer.defining_instruction) && + buffer.unfinished_user_count > 0); + } + + // Create a new buffer, add it to buffers_, and return a reference. + Buffer& NewBuffer(const HloInstruction* defining_instruction, int64 size, + std::vector&& users, bool live_out, + bool has_indirect_uses) { + int buffer_id = buffers_.size(); + buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out, + has_indirect_uses, users, + static_cast(users.size())}); + return buffers_.back(); + } + + const HloComputation* computation_; + + // Instruction list containing the ordering of instructions in + // computation_. This is the order in which instructions are placed + // (BeginInstruction/EndInstruction calls). + const InstructionList& instruction_list_; + + // Memory usage at the currently placed instruction. + int64 memory_usage_ = 0; + + // The instruction currently being placed. This value is non-null only + // between the calling of BeginInstruction and EndInstruction. + const HloInstruction* in_progress_instruction_ = nullptr; + + // The buffers defined by each instruction. + std::unordered_map> + buffers_defined_by_instruction_; + + // The buffers used by each instruction. + std::unordered_map> + buffers_used_by_instruction_; + + // The set of instructions which have been placed. That is, BeginInstruction + // has been called with the instruction as an argument. + tensorflow::gtl::FlatSet placed_instructions_; + + // All buffers in the computation. + std::vector buffers_; +}; + +MemoryUsageTracker::MemoryUsageTracker( + const HloComputation* computation, + const HloRematerialization::ShapeSizeFunction& size_function, + const TuplePointsToAnalysis& points_to_analysis, + const InstructionList& instruction_list) + : computation_(computation), instruction_list_(instruction_list) { + // Iterate through all LogicalBuffers in the computation and gather the + // instructions which define them in buffers_defined_by_instruction_ and the + // instructions which use them in buffers_used_by_instruction_. + for (auto& instruction : computation_->instructions()) { + // Initialize empty vectors for defs and uses of each instruction. + buffers_used_by_instruction_[instruction.get()]; + buffers_defined_by_instruction_[instruction.get()]; + } + + tensorflow::gtl::FlatSet live_out_set = + points_to_analysis.GetPointsToSet(computation_->root_instruction()) + .CreateFlattenedSet(); + tensorflow::gtl::FlatMap + logical_buffer_to_buffer_id; + + for (const HloInstruction* instruction : instruction_list_.instructions()) { + for (const LogicalBuffer* logical_buffer : + points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { + Buffer* buffer; + if (instruction->opcode() == HloOpcode::kWhile) { + // The while instruction defines no new buffers. Instead it reuses the + // buffers of its operand. Find the Buffer of its operand at the + // proper ShapeIndex. + const PointsToSet& operand_points_to = + points_to_analysis.GetPointsToSet(instruction->operand(0)); + CHECK_EQ(operand_points_to.element(logical_buffer->index()).size(), 1); + const LogicalBuffer* source_logical_buffer = + operand_points_to.element(logical_buffer->index())[0]; + buffer = + &buffers_.at(logical_buffer_to_buffer_id.at(source_logical_buffer)); + + // Mark buffer as has indirect use and live out. + buffer->has_indirect_uses = true; + buffer->live_out = + buffer->live_out || ContainsKey(live_out_set, logical_buffer); + + // Add users of while to Buffer users. + bool unused; + for (const HloInstruction* user : + GetUsers(logical_buffer, points_to_analysis, &unused)) { + if (std::find(buffer->users.begin(), buffer->users.end(), user) == + buffer->users.end()) { + buffer->users.push_back(user); + buffer->unfinished_user_count++; + buffers_used_by_instruction_.at(user).push_back(buffer->id); + } + } + } else { + buffer = &CreateBufferFromLogicalBuffer( + logical_buffer, points_to_analysis, size_function, + ContainsKey(live_out_set, logical_buffer)); + buffers_defined_by_instruction_.at(instruction).push_back(buffer->id); + for (const HloInstruction* user : buffer->users) { + buffers_used_by_instruction_.at(user).push_back(buffer->id); } } - } - // Value is dead if the instruction has no uses and is not live out. - if (in_progress_instruction_->users().empty() && - !IsLiveOut(in_progress_instruction_)) { - memory_usage_ -= TotalSizeBytes(in_progress_instruction_->shape()); - TF_RET_CHECK(memory_usage_ >= 0); + logical_buffer_to_buffer_id[logical_buffer] = buffer->id; } + } + XLA_VLOG_LINES(10, ToString()); + DCHECK(Check()); +} + +Status MemoryUsageTracker::BeginInstruction(const HloInstruction* instruction) { + VLOG(3) << "BeginInstruction " << instruction->name(); + TF_RET_CHECK(in_progress_instruction_ == nullptr); + in_progress_instruction_ = instruction; - in_progress_instruction_ = nullptr; + placed_instructions_.insert(in_progress_instruction_); - VLOG(3) << " memory usage = " << memory_usage_; - VLOG(10) << ToString(); - return Status::OK(); + // All buffers defined by this instruction need memory. + for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + VLOG(3) << " Buffer " << buffers_.at(buffer_id).ToString() + << " is now live."; + memory_usage_ += AllocatedSize(buffer_id); } - // Adjusts memory usage to account for the rematerialization of - // original_instruction for the given use. The rematerialization is - // remat_instruction. This method should be called after the HLO graph has - // been transformed (rematerialization instruction created and connected to - // its use). - Status RematerializeInstructionForUse(HloInstruction* original_instruction, - HloInstruction* remat_instruction, - HloInstruction* use) { - VLOG(3) << "RematerializeInstructionForUse: original_instruction = " - << original_instruction->name() - << ", remat_instruction = " << remat_instruction->name() - << ", use = " << use->name(); - - TF_RET_CHECK(in_progress_instruction_ != nullptr); - TF_RET_CHECK(IsPlaced(original_instruction)); - TF_RET_CHECK(!IsPlaced(remat_instruction)); - TF_RET_CHECK(!IsPlaced(use)); - TF_RET_CHECK(IsCurrentlyLive(original_instruction)); - - // Remove 'use' from remaining uses of original_instruction. - auto it = std::find(remaining_uses_[original_instruction].begin(), - remaining_uses_[original_instruction].end(), use); - TF_RET_CHECK(it != remaining_uses_[original_instruction].end()); - remaining_uses_[original_instruction].erase(it); - - // If original_instruction is no longer live ('use' was its last use) then - // deduct original_instruction's memory usage. - if (!IsCurrentlyLive(original_instruction)) { - memory_usage_ -= TotalSizeBytes(original_instruction->shape()); - TF_RET_CHECK(memory_usage_ >= 0); + // TODO(b/37686934): Elementwise instructions can share the buffer of a (dead) + // operand. Account for this potential reuse here. + + VLOG(3) << " memory usage = " << memory_usage_; + VLOG(10) << ToString(); + + DCHECK(Check()); + return Status::OK(); +} + +Status MemoryUsageTracker::EndInstruction() { + TF_RET_CHECK(in_progress_instruction_ != nullptr); + VLOG(3) << "EndInstruction " << in_progress_instruction_->name(); + + for (BufferId buffer_id : + buffers_used_by_instruction_.at(in_progress_instruction_)) { + Buffer& buffer = buffers_.at(buffer_id); + buffer.unfinished_user_count--; + CHECK_GE(buffer.unfinished_user_count, 0) + << buffer.ToString() << " has negative unfinished use count."; + if (buffer.unfinished_user_count == 0) { + // Buffer is now dead. + VLOG(3) << " " << buffer.ToString() << " is now dead."; + memory_usage_ -= AllocatedSize(buffer_id); + CHECK_GE(memory_usage_, 0); } + } - // Add the new remat_instruction to the remaining uses of its operands. - for (auto* operand : UniqueOperands(remat_instruction)) { - // Rematerialization may extend the lifetime of the operand so account for - // this in memory_usage_. - TF_RET_CHECK(IsPlaced(operand)); - if (!IsCurrentlyLive(operand)) { - memory_usage_ += TotalSizeBytes(operand->shape()); - } - remaining_uses_.at(operand).push_back(remat_instruction); + // If any buffer defined by this instruction has no uses, then memory can be + // reclaimed immediately. + for (BufferId buffer_id : + buffers_defined_by_instruction_.at(in_progress_instruction_)) { + const Buffer& buffer = buffers_.at(buffer_id); + if (buffer.unfinished_user_count == 0) { + VLOG(3) << " " << buffer.ToString() << " is immediately dead."; + memory_usage_ -= AllocatedSize(buffer_id); + CHECK_GE(memory_usage_, 0); } + } + + in_progress_instruction_ = nullptr; + + VLOG(3) << " memory usage = " << memory_usage_; + VLOG(10) << ToString(); + + DCHECK(Check()); - VLOG(3) << " memory usage = " << memory_usage_; - VLOG(10) << ToString(); - return Status::OK(); + return Status::OK(); +} + +int64 MemoryUsageTracker::MemoryReducedIfRematerialized( + const HloInstruction* instruction) const { + CHECK_NE(in_progress_instruction_, nullptr); + if (!IsPlaced(instruction) || instruction == in_progress_instruction_) { + return 0; } - // Returns the number of bytes that the current memory usage will be reduced - // if the given instruction is rematerialized. - int64 MemoryReducedIfRematerialized(const HloInstruction* instruction) const { - // To reduce memory consumption 'instruction' must be currently live and - // rematerialization must make 'instruction' not live. - if (IsLiveIn(instruction) || IsLiveOut(instruction) || - !IsCurrentlyLive(instruction)) { + // TODO(b/37687140): Rematerialization can increase peak memory consumption at + // an earlier point in the program if rematerialization extends the live range + // of the operand of the instruction being rematerialized across the live + // range of the value of instruction being rematerialized. Don't rematerialize + // in this case (ie, return 0 here). + + // Compute the amount of memory reduced (if any) by rematerializing + // 'instruction'. The LogicalBuffers defined by 'instruction' will no longer + // be live at this program point, so initially set memory_reduced to the + // size of its defined values. + int64 memory_reduced = 0; + for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + // Avoid rematerializing instructions with indirect uses as it is difficult + // to reason about liveness after rematerializing the instruction. + // TODO(b/37714814): Consider rematerialzing instructions with indirect + // uses. + if (buffers_.at(buffer_id).has_indirect_uses) { return 0; } - // If the in-progress instruction is a user of 'instruction' (or - // 'instruction' itself) then rematerializing 'instruction' cannot reduce - // memory usage because the value is required to be live at this program - // point. - if (in_progress_instruction_ == instruction || - in_progress_instruction_->IsUserOf(instruction)) { - return 0; + if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) { + memory_reduced += AllocatedSize(buffer_id); } + } - // Compute the amount of memory reduced (if any) by rematerializing - // 'instruction'. 'instruction' will no longer be live at this program - // point, so initially set memory_reduced to the size of its output value. - int64 memory_reduced = TotalSizeBytes(instruction->shape()); - - // Account for any operands whose live range must be extended across this - // program point. - for (const HloInstruction* operand : UniqueOperands(instruction)) { - if (!IsCurrentlyLive(operand)) { - // This operand of candidate is not live at this program - // point. Rematerializing 'instruction' will extend the operand's live - // range across this program point. - memory_reduced -= TotalSizeBytes(operand->shape()); - } + // Account for any logical buffers whose live range must be extended across + // this program point. + for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) { + if (!IsCurrentlyLive(buffer_id)) { + // This logical buffer is used by 'instruction' but is not live at this + // program point. Rematerializing 'instruction' will extend the buffer's + // live range across this program point. + memory_reduced -= AllocatedSize(buffer_id); } - return memory_reduced; } - // Returns the remaining unplaced uses of the given instruction. - const std::vector& RemainingUses( - const HloInstruction* instruction) const { - return remaining_uses_.at(instruction); + return memory_reduced; +} + +Status MemoryUsageTracker::AddRematerializedInstruction( + HloInstruction* original_instruction, HloInstruction* remat_instruction) { + VLOG(3) << "AddRematerializedInstruction: original_instruction = " + << original_instruction->name() + << ", remat_instruction = " << remat_instruction->name(); + + TF_RET_CHECK(in_progress_instruction_ != nullptr); + TF_RET_CHECK(IsPlaced(original_instruction)); + TF_RET_CHECK(!IsPlaced(remat_instruction)); + CHECK(!ContainsKey(buffers_defined_by_instruction_, remat_instruction)); + CHECK(!ContainsKey(buffers_used_by_instruction_, remat_instruction)); + + // Construct the list of buffers used and defined by the rematerialization. + buffers_defined_by_instruction_[remat_instruction]; + buffers_used_by_instruction_[remat_instruction] = + buffers_used_by_instruction_.at(original_instruction); + + // Account for the additional buffer uses created by the new rematerialization + // instruction. Update memory usage if the rematerialization makes a dead + // buffer live again. + for (BufferId buffer_id : + buffers_used_by_instruction_.at(original_instruction)) { + Buffer& buffer = buffers_.at(buffer_id); + if (buffer.unfinished_user_count == 0) { + // Buffer used by this instruction was dead, now is alive. + memory_usage_ += AllocatedSize(buffer.id); + } + + buffer.unfinished_user_count++; + buffer.users.push_back(remat_instruction); } - // Returns whether the given instruction has been placed (BeginInstruction has - // been called with 'instruction' as the argument). - bool IsPlaced(const HloInstruction* instruction) const { - return ContainsKey(remaining_uses_, instruction); - } - - // Returns whether the given instruction is live at the current program point. - bool IsCurrentlyLive(const HloInstruction* instruction) const { - return (!IsPlaced(instruction) && IsLiveIn(instruction)) || - (IsPlaced(instruction) && - (!RemainingUses(instruction).empty() || IsLiveOut(instruction))); - } - - string ToString() const { - string output = tensorflow::strings::StrCat("MemoryUsageTracker for ", - computation_->name(), "\n"); - tensorflow::strings::StrAppend(&output, "memory usage = ", memory_usage(), - "\n"); - tensorflow::strings::StrAppend(&output, "Live values:\n"); - for (const auto& pair : remaining_uses_) { - const HloInstruction* instruction = pair.first; - const std::vector& uses = pair.second; - tensorflow::strings::StrAppend( - &output, " ", instruction->name(), "; remaining uses: ", - tensorflow::str_util::Join(uses, ", ", - [](string* out, HloInstruction* use) { - tensorflow::strings::StrAppend( - out, use->name()); - }), - "\n"); + // Create a new set of Buffers defined by the new rematerialization + // instruction. Update the internal data structures and memory use to account + // for them. + for (BufferId old_buffer_id : + buffers_defined_by_instruction_.at(original_instruction)) { + Buffer& old_buffer = buffers_.at(old_buffer_id); + + std::vector placed_users; + std::vector unplaced_users; + for (const HloInstruction* user : old_buffer.users) { + if (IsPlaced(user)) { + CHECK(IsFinished(user)); + placed_users.push_back(user); + } else { + unplaced_users.push_back(user); + } + } + old_buffer.users = std::move(placed_users); + old_buffer.unfinished_user_count = 0; + + // Buffer is now dead. + memory_usage_ -= AllocatedSize(old_buffer.id); + + Buffer& new_buffer = RematerializeBuffer(old_buffer, remat_instruction, + std::move(unplaced_users)); + + buffers_defined_by_instruction_.at(remat_instruction) + .push_back(new_buffer.id); + for (const HloInstruction* user : new_buffer.users) { + std::vector& buffers_used = + buffers_used_by_instruction_.at(user); + std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id, + new_buffer.id); } - return output; } - // Returns the current memory usage. This is the sum of sizes of all live - // values. - int64 memory_usage() const { return memory_usage_; } + VLOG(3) << " memory usage = " << memory_usage_; + XLA_VLOG_LINES(10, ToString()); - // Returns the current instruction being placed. - const HloInstruction* in_progress_instruction() const { - return in_progress_instruction_; - } + DCHECK(Check()); - private: - // Returns the total size of the shape (including nested elements) in bytes. - int64 TotalSizeBytes(const Shape& shape) const { - int64 total_size = 0; - ShapeUtil::ForEachSubshape( - shape, - [this, &total_size](const Shape& subshape, - const ShapeIndex& /*index*/) { - total_size += size_function_(subshape); - return Status::OK(); - }) - .IgnoreError(); - return total_size; - } - - // Returns true if the value of given instruction is live into the - // computation. - bool IsLiveIn(const HloInstruction* instruction) const { - return instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kParameter; - } - - // Returns true if the value of given instruction is live out of the - // computation. - bool IsLiveOut(const HloInstruction* instruction) const { - return instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kParameter || - instruction == instruction->parent()->root_instruction(); + return Status::OK(); +} + +string MemoryUsageTracker::ToString() const { + string output = tensorflow::strings::StrCat("MemoryUsageTracker for ", + computation_->name(), "\n"); + tensorflow::strings::StrAppend( + &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", + memory_usage(), " bytes)"); + for (const HloInstruction* instruction : instruction_list_.instructions()) { + string inprogress = + instruction == in_progress_instruction_ ? " in-progress" : ""; + string placed = IsPlaced(instruction) ? " placed" : ""; + tensorflow::strings::StrAppend(&output, " ", instruction->name(), + inprogress, placed, "\n Defines:\n"); + for (BufferId buffer_id : buffers_defined_by_instruction_.at(instruction)) { + const Buffer& buffer = buffers_[buffer_id]; + string live = IsCurrentlyLive(buffer_id) ? " live" : ""; + tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live, + ", ", buffer.unfinished_user_count, + " unfinished uses\n"); + } + tensorflow::strings::StrAppend(&output, " Uses:\n"); + for (BufferId buffer_id : buffers_used_by_instruction_.at(instruction)) { + tensorflow::strings::StrAppend(&output, " ", + buffers_[buffer_id].ToString(), "\n"); + } } + return output; +} - const HloComputation* computation_; +bool MemoryUsageTracker::Check() const { + auto elements_are_unique = [](const std::vector& vec) { + return vec.size() == std::set(vec.begin(), vec.end()).size(); + }; + + // Verify buffers_defined_by_instruction_. + for (auto& instruction : computation_->instructions()) { + const std::vector& defined_buffers = + buffers_defined_by_instruction_.at(instruction.get()); + CHECK(elements_are_unique(defined_buffers)) + << "Instruction " << instruction->name() + << " does not have unique defined buffers: " + << tensorflow::str_util::Join( + defined_buffers, ", ", [this](string* out, BufferId buffer_id) { + tensorflow::strings::StrAppend( + out, buffers_.at(buffer_id).ToString()); + }); - // Function which computes the size of the top-level buffer of a shape. - const HloRematerialization::ShapeSizeFunction size_function_; + for (const Buffer& buffer : buffers_) { + if (buffer.defining_instruction == instruction.get()) { + CHECK(std::find(defined_buffers.begin(), defined_buffers.end(), + buffer.id) != defined_buffers.end()) + << "Instruction " << instruction->name() + << " defined buffers is missing: " << buffer.ToString(); + } + } + } - // Memory usage at the currently placed instruction. - int64 memory_usage_ = 0; + // Verify buffers_used_by_instruction_. + for (auto& instruction : computation_->instructions()) { + const std::vector& used_buffers = + buffers_used_by_instruction_.at(instruction.get()); + CHECK(elements_are_unique(used_buffers)) + << "Instruction " << instruction->name() + << " does not have unique used buffers: " + << tensorflow::str_util::Join( + used_buffers, ", ", [this](string* out, BufferId buffer_id) { + tensorflow::strings::StrAppend( + out, buffers_.at(buffer_id).ToString()); + }); + } + for (const Buffer& buffer : buffers_) { + int64 unfinished_uses = 0; + for (const HloInstruction* user : buffer.users) { + const std::vector& used_buffers = + buffers_used_by_instruction_.at(user); + CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) != + used_buffers.end()) + << "Instruction " << user->name() << " used buffers is missing " + << buffer.ToString(); + if (!IsFinished(user)) { + unfinished_uses++; + } + } + CHECK_EQ(buffer.unfinished_user_count, unfinished_uses) + << "Incorrect unplaced use count for " << buffer.ToString(); + } - // The instruction currently being placed. This value is non-null only between - // the calling of BeginInstruction and EndInstruction. - const HloInstruction* in_progress_instruction_ = nullptr; + // Verify live set size against memory_usage_. + int64 live_size = 0; + for (const Buffer& buffer : buffers_) { + // The while instruction reuses its input buffers as output buffers so + // don't double count its buffers if it is currently executing. + if (IsCurrentlyLive(buffer.id) && + !(buffer.defining_instruction == in_progress_instruction_ && + in_progress_instruction_->opcode() == HloOpcode::kWhile)) { + live_size += AllocatedSize(buffer.id); + } + } + CHECK_EQ(live_size, memory_usage_); - // remaining_uses is a vector of uses of the HLO instruction's value which - // have not yet been visited by in the rematerialization loop. Use to track - // liveness of HLO instructions. - // TODO(b/35212854): Track values using logical buffers rather than HLO - // instructions. Using HLO instructions over-estimates memory usage because - // buffer aliasing is ignored. - tensorflow::gtl::FlatMap> - remaining_uses_; -}; + return true; +} -// Computes and returns the cost of rematerializing the given instruction. Cost -// per rematerialized instruction is defined as: +// Computes and returns the cost of rematerializing the given instruction. +// Cost per rematerialized instruction is defined as: // // (flop_count + transcendental_count + element_count) / memory_reduced // @@ -424,33 +808,36 @@ class MemoryUsageTracker { // instruction. // // This is a rough estimate of the extra execution time per byte saved by -// rematerializing this instruction for its remaining uses. In general, we want -// the most memory saving for the least latency penalty which is captured by -// this heuristic. +// rematerializing this instruction for its remaining uses. In general, we +// want the most memory saving for the least latency penalty which is captured +// by this heuristic. int64 RematerializationCost(const HloInstruction* instruction, const MemoryUsageTracker& memory_tracker, const HloCostAnalysis& cost_analysis, int64 memory_reduced) { - const int64 bytes_accessed = cost_analysis.bytes_accessed(*instruction); - const int64 elements_accessed = - bytes_accessed / - ShapeUtil::ByteSizeOfPrimitiveType(instruction->shape().element_type()); - - // A duplicate of the rematerialized instruction will be created at each - // remaining use. - int64 duplication = memory_tracker.RemainingUses(instruction).size(); - if (duplication == instruction->users().size()) { - // All remaining uses of instruction are after this point so we can remove - // the original instruciton after rematerialization. - duplication -= 1; + // If none of the users of 'instruction' have been placed in the sequence (as + // tracked by memory_tracker), then rematerialization of 'instruction' is a + // zero-cost move of 'instruction' in the sequence. + if (!std::any_of(instruction->users().begin(), instruction->users().end(), + [&memory_tracker](const HloInstruction* inst) { + return memory_tracker.IsPlaced(inst); + })) { + return 0; } + CHECK_GT(memory_reduced, 0); + const int64 bytes_accessed = cost_analysis.bytes_accessed(*instruction); + const int64 elements_accessed = + ShapeUtil::IsTuple(instruction->shape()) + ? bytes_accessed + : bytes_accessed / ShapeUtil::ByteSizeOfPrimitiveType( + instruction->shape().element_type()); // Multiply by 256 to improve precision of cost. Without this factor, // many instructions such as many elementwise instructions would have // zero cost because the bytes reduced can be several times greater than // the element count. - return 256 * duplication * + return 256 * (cost_analysis.flop_count(*instruction) + cost_analysis.transcendental_count(*instruction) + elements_accessed) / @@ -466,7 +853,7 @@ HloInstruction* PickRematerializationCandidate( const MemoryUsageTracker& memory_tracker, const InstructionList& instruction_list, const HloCostAnalysis& cost_analysis, - const tensorflow::gtl::FlatSet& remat_instructions) { + const tensorflow::gtl::FlatSet& blacklist) { HloInstruction* best = nullptr; int64 best_cost = 0; @@ -481,11 +868,11 @@ HloInstruction* PickRematerializationCandidate( } VLOG(5) << "considering rematerialization candidate " << candidate->name(); - if (ContainsKey(remat_instructions, candidate)) { - // Skip instructions which are rematerialization clones to avoid infinite - // loops of rematerializing the same instruction(s) repeatedly. + if (ContainsKey(blacklist, candidate)) { + // Skip instructions on the blacklist to avoid infinite loops of + // rematerializing the same instruction(s) repeatedly. VLOG(5) << "candidate " << candidate->name() - << " not viable: is a rematerialized instruction"; + << " is excluded from rematerialization"; continue; } @@ -524,7 +911,9 @@ HloInstruction* PickRematerializationCandidate( StatusOr HloRematerialization::ComputePeakMemory( const HloComputation* computation, const std::vector& order) const { - MemoryUsageTracker tracker(computation, size_function_); + InstructionList instruction_list(order); + MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_, + instruction_list); int64 peak_memory = tracker.memory_usage(); for (const HloInstruction* instruction : order) { TF_RETURN_IF_ERROR(tracker.BeginInstruction(instruction)); @@ -541,9 +930,8 @@ StatusOr HloRematerialization::ComputePeakMemory( StatusOr HloRematerialization::CalledComputationsMemoryUsage( const HloInstruction* instruction) const { - TF_ASSIGN_OR_RETURN(const CallGraphNode* node, - call_graph_->GetNode(instruction->parent())); - const CallSite* callsite = node->GetCallSite(instruction); + const CallSite* callsite = + call_graph_->GetNode(instruction->parent()).GetCallSite(instruction); if (callsite == nullptr || callsite->context() == CallContext::kParallel) { return 0; } @@ -563,15 +951,24 @@ StatusOr HloRematerialization::RematerializeComputation( << " with limit " << HumanReadableNumBytes(memory_limit_bytes); VLOG(1) << "peak memory usage is " << HumanReadableNumBytes(computation_peak_memory_.at(computation)); + CHECK(!ContainsKey(rematerialized_computations_, computation)); InstructionList instruction_list(sequence->at(computation)); - MemoryUsageTracker memory_tracker(computation, size_function_); + MemoryUsageTracker memory_tracker(computation, size_function_, + *points_to_analysis_, instruction_list); bool changed = false; - // Set of instruction clones (not the originals) created during - // rematerialization. A record is kept to avoid rematerializing an instruction - // more than once to avoid looping infinitely during rematerialization. - tensorflow::gtl::FlatSet remat_instructions; + // To avoid an infinite loop rematerializing the same set of instructions ad + // infinitum, keep a blacklist of instructions which should not be + // rematerialized. + tensorflow::gtl::FlatSet blacklist; + + // If the rematerialization makes the source instruction dead, then the + // rematerialization is added to 'remat_move_instructions' (the + // rematerialization is essentially a move). If the next rematerialization of + // the instruction is also a move then the rematerialization is added to the + // blacklist. + tensorflow::gtl::FlatSet remat_move_instructions; // The peak memory of the computation at any point in the instruction // sequence. @@ -583,12 +980,12 @@ StatusOr HloRematerialization::RematerializeComputation( // instructions which are dead. int64 net_instructions_added = 0; - TF_ASSIGN_OR_RETURN(const CallGraphNode* call_graph_node, - call_graph_->GetNode(computation)); + const CallGraphNode& call_graph_node = call_graph_->GetNode(computation); // Iterate through all instructions in the sequence. At each instruction // (program point) if memory_usage exceeds the specified limit then // rematerialize HLO instructions until memory_usage is reduced. + int64 instruction_index = 0; for (auto list_it = instruction_list.instructions().begin(); list_it != instruction_list.instructions().end(); ++list_it) { HloInstruction* instruction = *list_it; @@ -598,7 +995,9 @@ StatusOr HloRematerialization::RematerializeComputation( VLOG(2) << "Program point at " << instruction->name() << ", memory usage = " << memory_tracker.memory_usage() - << ", callee usage = " << callee_usage; + << ", callee usage = " << callee_usage << ", [" << instruction_index + << "/" << instruction_list.instructions().size() << "]"; + instruction_index++; while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) { VLOG(2) << "Over memory limit at instruction " << instruction->name() @@ -608,7 +1007,7 @@ StatusOr HloRematerialization::RematerializeComputation( << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); HloInstruction* best = PickRematerializationCandidate( - memory_tracker, instruction_list, cost_analysis_, remat_instructions); + memory_tracker, instruction_list, cost_analysis_, blacklist); if (best == nullptr) { VLOG(3) << "Unable to find rematerialization candidate at program " @@ -619,44 +1018,42 @@ StatusOr HloRematerialization::RematerializeComputation( break; } - VLOG(1) << "Rematerializing instruction " << best->name(); + VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " + << memory_tracker.MemoryReducedIfRematerialized(best) << ")"; changed = true; remat_count++; - // Create a rematerialized copy of the candidate at each remaining use. - // Make a copy of remaining uses because RematerializeInstructionForUse - // modifies the remaining uses vector in memory_tracker. - // TODO(b/35213652): It may be profitable to share one rematerialized copy - // amongst more than one use. - std::vector remaining_uses_copy = - memory_tracker.RemainingUses(best); - for (HloInstruction* use : remaining_uses_copy) { - // Create a new rematerialized instruction in the HLO graph. - HloInstruction* remat = - computation->AddInstruction(best->Clone(/*suffix=*/"remat")); - - VLOG(3) << "Replacing use of " << best->name() << " in " << use->name() - << " with rematerialization " << remat->name(); - - TF_RETURN_IF_ERROR(best->ReplaceUseWith(use, remat)); + HloInstruction* remat = + computation->AddInstruction(best->Clone(/*suffix=*/"remat")); - // Account for the rematerialization in the memory tracker. - TF_RETURN_IF_ERROR( - memory_tracker.RematerializeInstructionForUse(best, remat, use)); - - // Insert rematerialized instruction right before its use. - TF_RETURN_IF_ERROR(instruction_list.InsertBefore(remat, use)); - - // Add rematerialized instruction to remat_instructions so the - // rematerialized instruction is not rematerialized again. - remat_instructions.insert(remat); - - net_instructions_added++; + // Replace each remaining use of 'best' with the rematerialization. + std::vector best_users_copy = best->users(); + for (HloInstruction* user : best_users_copy) { + if (!memory_tracker.IsPlaced(user)) { + VLOG(2) << " Replacing use of " << best->name() << " in " + << user->name() << " with " << remat->name(); + TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat)); + } } - // Original instruction should no longer be live at this point. All - // of its remaining uses are fed by rematerialized instructions. - TF_RET_CHECK(!memory_tracker.IsCurrentlyLive(best)); + // Account for the rematerialization in the memory tracker. + TF_RETURN_IF_ERROR( + memory_tracker.AddRematerializedInstruction(best, remat)); + + // Insert rematerialized instruction right before the earliest unplaced + // use of the instruction *and* the earliest unplaced last use of any + // operands of remat. Unplaced uses of the remat's operands are included + // because we don't want to extend the live range of remat's operands as + // this could increase memory usage. + std::vector place_before = remat->users(); + for (auto* operand : remat->operands()) { + for (auto* operand_user : operand->users()) { + if (!memory_tracker.IsPlaced(operand_user) && operand_user != remat) { + place_before.push_back(operand_user); + } + } + } + instruction_list.InsertBeforeInstructions(remat, place_before); // If the rematerialized instruction is dead then rematerialization is // essentially a move. Don't delete the instruction now because we don't @@ -664,15 +1061,24 @@ StatusOr HloRematerialization::RematerializeComputation( // transformation because we keep maps with HloInstruction* values as // keys. if (best->users().empty()) { - VLOG(3) << best->name() << " is now dead"; - net_instructions_added--; + VLOG(2) << best->name() << " is now dead"; + if (ContainsKey(remat_move_instructions, best)) { + // Previously, 'best' was a rematerialization which killed the + // instruction it was a copying of. Now 'remat' is a rematerialization + // of 'best' and kills 'best'. Stop rematerializing this instruction + // to avoid an infinite loop. + blacklist.insert(remat); + } + remat_move_instructions.insert(remat); + } else { + net_instructions_added++; } VLOG(3) << "memory_usage after rematerialization = " << memory_tracker.memory_usage(); } - const CallSite* callsite = call_graph_node->GetCallSite(instruction); + const CallSite* callsite = call_graph_node.GetCallSite(instruction); if (callsite != nullptr && callsite->context() == CallContext::kSequential && memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) { @@ -686,21 +1092,22 @@ StatusOr HloRematerialization::RematerializeComputation( // Recompute callee usage to account for any rematerialization performed // in the callee computations. - callee_usage = 0; for (HloComputation* called_computation : callsite->called_computations()) { - // Memory limit for the subcomputation is the memory limit less the - // amount of memory used at this point in the computation. - int64 subcomputation_memory_limit_bytes = std::max( - 0, memory_limit_bytes - memory_tracker.memory_usage()); - TF_ASSIGN_OR_RETURN( - bool subcomputation_changed, - RematerializeComputation(called_computation, sequence, - subcomputation_memory_limit_bytes)); - changed |= subcomputation_changed; - - callee_usage += computation_peak_memory_.at(called_computation); + if (!ContainsKey(rematerialized_computations_, called_computation)) { + // Memory limit for the subcomputation is the memory limit less the + // amount of memory used at this point in the computation. + int64 subcomputation_memory_limit_bytes = std::max( + 0, memory_limit_bytes - memory_tracker.memory_usage()); + TF_ASSIGN_OR_RETURN( + bool subcomputation_changed, + RematerializeComputation(called_computation, sequence, + subcomputation_memory_limit_bytes)); + changed |= subcomputation_changed; + } } + TF_ASSIGN_OR_RETURN(callee_usage, + CalledComputationsMemoryUsage(instruction)); } peak_memory = std::max(peak_memory, @@ -710,37 +1117,33 @@ StatusOr HloRematerialization::RematerializeComputation( TF_RETURN_IF_ERROR(memory_tracker.EndInstruction()); } - if (peak_memory > memory_limit_bytes) { - LOG(WARNING) << "Can't reduce memory use of computation " - << computation->name() << " below " - << HumanReadableNumBytes(memory_limit_bytes) - << " by rematerialization (only reduced to " - << HumanReadableNumBytes(peak_memory) << ")"; - } - - // Verify that there are no more remaining uses. + // Verify some invariants on the memory tracker. + CHECK_EQ(memory_tracker.memory_usage(), 0); for (auto& instruction : computation->instructions()) { - auto& remaining_uses = memory_tracker.RemainingUses(instruction.get()); - CHECK(remaining_uses.empty()) - << instruction->name() << " has remaining uses: " - << tensorflow::str_util::Join( - remaining_uses, ", ", [](string* out, HloInstruction* inst) { - tensorflow::strings::StrAppend(out, inst->name()); - }); + CHECK(memory_tracker.IsPlaced(instruction.get())); } - VLOG(1) << "Rematerialized " << remat_count << " instructions; " - << net_instructions_added << " net instructions added"; - VLOG(1) << "peak memory usage now " << HumanReadableNumBytes(peak_memory); + VLOG(1) << "In computation " << computation->name() << " rematerialized " + << remat_count << " instructions; " << net_instructions_added + << " net instructions added"; + VLOG(1) << " peak memory usage now " << HumanReadableNumBytes(peak_memory) + << " (was " + << HumanReadableNumBytes(computation_peak_memory_.at(computation)) + << ")"; // Update peak memory used by computation. - computation_peak_memory_[computation] = peak_memory; + computation_peak_memory_.at(computation) = peak_memory; // Update order to include rematerialized instructions. sequence->at(computation) .assign(instruction_list.instructions().begin(), instruction_list.instructions().end()); + rematerialized_computations_.insert(computation); + + instructions_rematerialized_ += remat_count; + net_instructions_added_ += net_instructions_added; + return changed; } @@ -753,18 +1156,38 @@ StatusOr HloRematerialization::Run( VLOG(1) << "HloRematerialization() with memory limit of " << HumanReadableNumBytes(memory_limit_bytes); - XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); + TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); + + // Adjust memory limit to account for the output of the entry + // computation. This is necessary because the per-computation accounting in + // MemoryUsageTracker do not include output as these are typically allocated + // by the caller. + int64 module_output_size = 0; + ShapeUtil::ForEachSubshape( + module->entry_computation()->root_instruction()->shape(), + [&module_output_size, this](const Shape& subshape, + const ShapeIndex& /*index*/) { + module_output_size += size_function_(subshape); + return Status::OK(); + }) + .IgnoreError(); + const int64 adjusted_memory_limit_bytes = + memory_limit_bytes - module_output_size; + VLOG(1) << "Adjusted memory limit accounting for output (" + << HumanReadableNumBytes(module_output_size) + << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); + + XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Create initial sequence of HLO instructions. TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence( *module, [this](const LogicalBuffer& buffer) { return size_function_(buffer.shape()); })); - // Compute peak memory usage of all computations in the module called in a // sequential context. - TF_ASSIGN_OR_RETURN(call_graph_, CallGraph::Build(module)); + call_graph_ = CallGraph::Build(module); TF_RETURN_IF_ERROR(call_graph_->VisitNodes( [this, sequence](const CallGraphNode& node) -> Status { if (node.context() == CallContext::kSequential) { @@ -776,9 +1199,15 @@ StatusOr HloRematerialization::Run( return Status::OK(); })); + // The peak memory usage of the module equals the peak memory use of the entry + // computation plus the output size of the computation. This is because the + // peak memory for a computation does not include the output as this is + // typically accounted for in the caller. + const int64 before_peak_memory = + computation_peak_memory_.at(module->entry_computation()) + + module_output_size; VLOG(1) << "Peak memory usage of module (before): " - << HumanReadableNumBytes( - computation_peak_memory_[module->entry_computation()]); + << HumanReadableNumBytes(before_peak_memory); // Run cost analysis. Operation cost is used in the heuristic for selecting // instructions for rematerialization. @@ -787,9 +1216,9 @@ StatusOr HloRematerialization::Run( // Subcomputations called by the entry computation will also be // rematerialized. - TF_ASSIGN_OR_RETURN(bool changed, - RematerializeComputation(module->entry_computation(), - sequence, memory_limit_bytes)); + TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( + module->entry_computation(), sequence, + adjusted_memory_limit_bytes)); // Rematerialization can introduce dead code. This occurs if all uses of an // instruction are replaced with rematerializations of the instruction. @@ -824,19 +1253,38 @@ StatusOr HloRematerialization::Run( computation->instruction_count()); } } - - VLOG(1) << "Peak memory usage of module (after): " - << HumanReadableNumBytes( - computation_peak_memory_[module->entry_computation()]); + VLOG(1) << "Rematerialized " << instructions_rematerialized_ + << " instructions in module " << module->name() << "; " + << net_instructions_added_ << " net instructions added"; + const int64 current_peak_memory = + computation_peak_memory_.at(module->entry_computation()) + + module_output_size; + VLOG(1) << "Peak memory usage of module now " + << HumanReadableNumBytes(current_peak_memory) << " (" + << current_peak_memory << " bytes), was " + << HumanReadableNumBytes(before_peak_memory) << " (" + << before_peak_memory << " bytes)"; + const int64 reduced_peak_memory = before_peak_memory - current_peak_memory; + VLOG(1) << "Reduced peak memory by " + << HumanReadableNumBytes(reduced_peak_memory) << " (" + << reduced_peak_memory << " bytes)"; XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); + if (current_peak_memory > memory_limit_bytes) { + LOG(WARNING) << "Can't reduce memory use below " + << HumanReadableNumBytes(memory_limit_bytes) + << " by rematerialization (only reduced to " + << HumanReadableNumBytes(current_peak_memory) << ")"; + } + return changed; } /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( - const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, SequentialHloOrdering::HloModuleSequence* sequence) { + const HloRematerialization::ShapeSizeFunction& size_function, + int64 memory_limit_bytes, HloModule* hlo_module, + SequentialHloOrdering::HloModuleSequence* sequence) { HloRematerialization remat(size_function); return remat.Run(hlo_module, sequence, memory_limit_bytes); } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 86e1998b89454f75b1c10d0de2118fd1034c134d..1693f93183bc59c343e3c765cb4051566d4377ef 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -21,6 +21,7 @@ #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" namespace xla { @@ -108,6 +109,23 @@ class HloRematerialization { // occurs. tensorflow::gtl::FlatMap computation_peak_memory_; + + std::unique_ptr points_to_analysis_; + + // Set of computations which have had rematerialization + // applied. Rematerialization is only applied once per computation. + tensorflow::gtl::FlatSet rematerialized_computations_; + + // Count of the total instructions rematerialized. + int64 instructions_rematerialized_ = 0; + + // Count of the net instructions added to the HLO module by + // rematerialization. This can be different than instructions_rematerialized_ + // because some rematerializations are effectively moves in the HLO + // schedule. In these cases, the rematerialization instruction replaces all + // uses of the original instruction and the original instruction is + // dead. Hence, no net instructions were added. + int64 net_instructions_added_ = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 0a4f2776891cfc932b4fc0627daaa9b5408f420a..2a1d728bc84067e6ad7f1f622216ab39b2b474d3 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -30,12 +31,16 @@ limitations under the License. namespace xla { namespace { -class HloOrderingTest : public HloTestBase { +namespace op = xla::testing::opcode_matchers; + +using ::testing::_; + +class HloRematerializationTest : public HloTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: // - // F32[1] %param = {...} + // F32[] %param = {...} // F32[1024] %bcast = broadcast(%param) // F32[1024] %negate = negate(%bcast) // F32[2048] %concat_1 = concat({%negate, %negate}) @@ -52,7 +57,7 @@ class HloOrderingTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, vec1_shape_, "param")); + HloInstruction::CreateParameter(0, scalar_shape_, "param")); auto bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); auto negate = builder.AddInstruction( @@ -77,7 +82,7 @@ class HloOrderingTest : public HloTestBase { // Creates and returns a computation which includes a while and can benefit // from rematerialization. The computation looks like: // - // F32[1] %param = {...} + // F32[] %param = {...} // F32[1024] %bcast = broadcast(%param) // F32[1] %slice_1 = slice(%bcast, {0:1}) // F32[1] %while = while(%slice_1, while_body, while_cond) @@ -93,7 +98,7 @@ class HloOrderingTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, vec1_shape_, "param")); + HloInstruction::CreateParameter(0, scalar_shape_, "param")); auto bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); auto slice_1 = builder.AddInstruction( @@ -127,13 +132,14 @@ class HloOrderingTest : public HloTestBase { } // Various shapes used in the canned computations. + const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {}); const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1}); const Shape vec1024_shape_ = ShapeUtil::MakeShape(xla::F32, {1024}); }; // Test rematerialization of a single computation produced by // MakeRematerializableComputation. -TEST_F(HloOrderingTest, SingleComputation) { +TEST_F(HloRematerializationTest, SingleComputation) { HloModule module(TestName()); HloComputation* computation = module.AddEntryComputation(MakeRematerializableComputation()); @@ -141,11 +147,9 @@ TEST_F(HloOrderingTest, SingleComputation) { // Find and save the original broadcast instruction which should be // rematerialized. const HloInstruction* slice = computation->root_instruction(); - ASSERT_EQ(HloOpcode::kSlice, slice->opcode()); + ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _))); const HloInstruction* concat = slice->operand(0); - ASSERT_EQ(HloOpcode::kConcatenate, concat->opcode()); const HloInstruction* bcast = concat->operand(0); - ASSERT_EQ(HloOpcode::kBroadcast, bcast->opcode()); SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB @@ -161,8 +165,7 @@ TEST_F(HloOrderingTest, SingleComputation) { // The broadcast should have been rematerialized. const HloInstruction* remat_bcast = concat->operand(0); - EXPECT_EQ(HloOpcode::kBroadcast, remat_bcast->opcode()); - EXPECT_NE(bcast, remat_bcast); + EXPECT_THAT(remat_bcast, op::Broadcast(::testing::Ne(bcast))); // The rematerialized broadcast should be immediate before the concat in the // sequence. @@ -175,7 +178,7 @@ TEST_F(HloOrderingTest, SingleComputation) { // Test rematerialization of a single computation produced by // MakeRematerializableComputation but with a sufficiently high memory limit // such that no instructions are rematerialized. -TEST_F(HloOrderingTest, SingleComputationNoRematerialization) { +TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { HloModule module(TestName()); HloComputation* computation = module.AddEntryComputation(MakeRematerializableComputation()); @@ -199,7 +202,7 @@ TEST_F(HloOrderingTest, SingleComputationNoRematerialization) { // only one computation needs to have an instruction rematerialized. The entry // computation should be the one chosen because rematerialization in the while // will presumably be more expensive. -TEST_F(HloOrderingTest, RematerializeAroundWhile) { +TEST_F(HloRematerializationTest, RematerializeAroundWhile) { HloModule module(TestName()); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); @@ -237,7 +240,7 @@ TEST_F(HloOrderingTest, RematerializeAroundWhile) { // Test rematerialization of a computation which calls another computation via a // while. Both the entry computation and while body computation should have // computations rematerialized. -TEST_F(HloOrderingTest, RematerializeEntryAndWhileBody) { +TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { HloModule module(TestName()); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); @@ -271,7 +274,7 @@ TEST_F(HloOrderingTest, RematerializeEntryAndWhileBody) { // Test rematerialization of a doubly nested computation. All computations // should have an instruction rematerialized. -TEST_F(HloOrderingTest, RematerializeNestedComputations) { +TEST_F(HloRematerializationTest, RematerializeNestedComputations) { HloModule module(TestName()); auto cond_builder = HloComputation::Builder(TestName() + ".cond"); @@ -311,6 +314,203 @@ TEST_F(HloOrderingTest, RematerializeNestedComputations) { EXPECT_EQ(inner_computation->instruction_count(), 8); } +TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { + // Test that a single instruction is rematerialized several times. Module: + // + // Entry computation: + // F32[] %param = {...} + // F32[1024] %bcast = broadcast(%param) + // F32[1024] %add_1 = add(%bcast, bcast) + // F32[1024] %call_1 = call(Subcomputation, {%add_1}) + // F32[1024] %add_2 = add(%bcast, call_1) + // F32[1024] %call_2 = call(SubComputation, {%add_2}) + // F32[1024] %add_3 = add(%bcast, call_2) + // F32[1024] %call_3 = call(Subcomputation, {%add_3}) + // F32[1024] %add_4 = add(%bcast, call_3) + // + // Subcomputation: + // F32[1024] %param = {...} + // F32[2048] %concat = concat({%param, %param}) + // F32[1024] %slice = slice(%concat) + // + // The value %bcast is live across each call of Subcomputation (which requires + // 8KB) though the value is not used in the calls. Rematerializing %bcast + // across these calls reduces peak memory use from ~20KB down to ~16KB. + HloModule module(TestName()); + + HloComputation* subcomputation = nullptr; + { + auto builder = HloComputation::Builder(TestName() + ".subcomputation"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vec1024_shape_, "param")); + auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(xla::F32, {2048}), {param, param}, + /*dimension=*/0)); + builder.AddInstruction(HloInstruction::CreateSlice( + vec1024_shape_, concat, /*start_indices=*/{0}, + /*limit_indices=*/{1024})); + subcomputation = module.AddEmbeddedComputation(builder.Build()); + } + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, bcast)); + auto call_1 = builder.AddInstruction( + HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation)); + auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, call_1)); + auto call_2 = builder.AddInstruction( + HloInstruction::CreateCall(vec1024_shape_, {add_2}, subcomputation)); + auto add_3 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, call_2)); + auto call_3 = builder.AddInstruction( + HloInstruction::CreateCall(vec1024_shape_, {add_3}, subcomputation)); + auto add_4 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, call_3)); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + + auto count_broadcasts = [](const HloComputation* computation) { + int64 bcast_count = 0; + for (auto& instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kBroadcast) { + bcast_count++; + } + } + return bcast_count; + }; + + // Before rematerialization there should be a single broadcast instruction in + // the graph. + EXPECT_EQ(count_broadcasts(entry_computation), 1); + EXPECT_EQ(entry_computation->instruction_count(), 9); + + EXPECT_EQ(add_2->operand(0), bcast); + EXPECT_EQ(add_3->operand(0), bcast); + EXPECT_EQ(add_4->operand(0), bcast); + + SequentialHloOrdering::HloModuleSequence sequence; + // Pick a memory limit some where between 24KB (initial peak memory including + // parameter and output) and 20KB (peak memory possible with + // rematerialization). + TF_ASSIGN_OR_ASSERT_OK( + bool changed, HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/22 * 1024, &module, &sequence)); + EXPECT_TRUE(changed); + + // The broadcast should have been rematerialized 3 times. + EXPECT_EQ(count_broadcasts(entry_computation), 4); + EXPECT_EQ(entry_computation->instruction_count(), 12); + + // The operands of add_2, add_3, and add_4 should all be rematerialized + // broadcasts. + EXPECT_NE(add_2->operand(0), bcast); + EXPECT_THAT(add_2->operand(0), op::Broadcast(param)); + EXPECT_NE(add_3->operand(0), bcast); + EXPECT_THAT(add_3->operand(0), op::Broadcast(param)); + EXPECT_NE(add_4->operand(0), bcast); + EXPECT_THAT(add_4->operand(0), op::Broadcast(param)); +} + +class IndirectUseTest : public HloRematerializationTest, + public ::testing::WithParamInterface {}; + +TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { + // Test that an rematerializable instruction is not rematerialized if it has + // an indirect use. Test is parameterized on whether the value has an indirect + // use, and the instruction should be rematerialized iff the value has no + // indirect use. Module: + // + // Entry computation: + // F32[] %param = {...} + // F32[1024] %bcast = broadcast(%param) + // F32[1024] %add_1 = add(%bcast, bcast) + // F32[1024] %call = call(Subcomputation, {%add_1}) + // F32[1024] %add_2 = add(%bcast, call) + // {F32[1024], F32[1024]} %tuple = tuple(%bcast, %add_2) + // F32[1024] %gte = GetTupleElememt(%tuple, 0) + // F32[1024] %negate = negate(%gte) + // + // Subcomputation: + // F32[1024] %param = {...} + // F32[2048] %concat = concat({%param, %param}) + // F32[1024] %slice = slice(%concat) + // + // The value %bcast is live across the call and rematerialization of %bcast + // across that point would reduce peak memory use by 4KB. However, %bcast is + // used indirectly in the %negate so rematerialization should not happen. + // + // This test is parameterized on whether the broadcast has an indirect use or + // not. The indirect use is controlled by the index of the GetTupleElement + // instruction. If the element is 0, then the %negate operand aliases %bcast + // (ie %bcast is used indirectly by %negate), otherwise the %negate operand + // aliases %add_2. + const bool indirectly_used = GetParam(); + HloModule module(TestName()); + + HloComputation* subcomputation = nullptr; + { + auto builder = HloComputation::Builder(TestName() + ".subcomputation"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vec1024_shape_, "param")); + auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(xla::F32, {2048}), {param, param}, + /*dimension=*/0)); + builder.AddInstruction(HloInstruction::CreateSlice( + vec1024_shape_, concat, /*start_indices=*/{0}, + /*limit_indices=*/{1024})); + subcomputation = module.AddEmbeddedComputation(builder.Build()); + } + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, bcast)); + auto call_1 = builder.AddInstruction( + HloInstruction::CreateCall(vec1024_shape_, {add_1}, subcomputation)); + auto add_2 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, bcast, call_1)); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({bcast, add_2})); + auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + vec1024_shape_, tuple, indirectly_used ? 0 : 1)); + builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, gte)); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(entry_computation->instruction_count(), 8); + + SequentialHloOrdering::HloModuleSequence sequence; + // Pick a memory limit some where between 24KB (initial peak memory including + // parameter and output) and 20KB (peak memory possible with + // rematerialization). + TF_ASSIGN_OR_ASSERT_OK( + bool changed, HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/22 * 1024, &module, &sequence)); + // Rematerialization should only occur if the rematerializable instruction has + // no indirect uses. + if (indirectly_used) { + EXPECT_FALSE(changed); + EXPECT_EQ(entry_computation->instruction_count(), 8); + } else { + EXPECT_TRUE(changed); + EXPECT_EQ(entry_computation->instruction_count(), 9); + } +} + +INSTANTIATE_TEST_CASE_P(IndirectUseTestInstantiation, IndirectUseTest, + ::testing::Values(true, false)); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..2b14eca5d1b36fbe8b863cb32d64c79fb56ce761 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -0,0 +1,213 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +LIcensed under the Apache License, Version 2.0 (the "License"); +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +using ::tensorflow::GraphDef; +using ::tensorflow::NodeDef; +using ::tensorflow::TensorShapeProto; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; +using ::tensorflow::str_util::Join; + +namespace xla { +namespace hlo_graph_dumper { +namespace { + +string GetOpDefName(const HloInstruction* instruction) { + string name = StrCat("hlo-", HloOpcodeString(instruction->opcode())); + tensorflow::str_util::TitlecaseString(&name, "-"); + name.erase(std::remove(name.begin(), name.end(), '-'), name.end()); + + if (instruction->opcode() == HloOpcode::kFusion) { + string fusion_name = ToString(instruction->fusion_kind()); + StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1)); + } + return name; +} + +TensorShapeProto GetTensorShape(const HloInstruction* instruction) { + TensorShapeProto tensor_shape; + const Shape& shape = instruction->shape(); + for (auto dim : shape.dimensions()) { + tensor_shape.add_dim()->set_size(dim); + } + return tensor_shape; +} + +} // namespace + +void CleanNodeName(string* name) { + name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); + const string chars_to_replace = "<>[]"; + auto pred = [&](char c) { + return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) != + chars_to_replace.end(); + }; + std::replace_if(name->begin(), name->end(), pred, '_'); +} + +Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { + VLOG(2) << "Adding computation " << computation.name(); + for (auto embedded : computation.MakeEmbeddedComputationsList()) { + for (auto& instruction : embedded->instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(instruction.get())); + } + } + for (auto& instruction : computation.instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(instruction.get())); + } + return Status::OK(); +} + +const GraphDef& HloTfGraphBuilder::GetGraphDef() const { return graph_def_; } + +const string& HloTfGraphBuilder::GetNodeNameForInstruction( + const HloInstruction* instruction) { + if (ContainsKey(instruction_to_node_name_, instruction)) { + return instruction_to_node_name_[instruction]; + } + string node_name; + // If an instruction is fused, put it in the subgraph of the fusion; + // otherwise, put it in the computation subgraph. + if (instruction->IsFused()) { + node_name = GetNodeNameForInstruction(instruction->fusion_instruction()); + } else { + node_name = instruction->parent()->name(); + if (!instruction->metadata().op_name().empty()) { + // Always make computations contain TF ops but not the other way around. + StrAppend(&node_name, "/", instruction->metadata().op_name()); + } + } + string instruction_name = instruction->name(); + if (instruction->opcode() == HloOpcode::kParameter) { + StrAppend(&instruction_name, ".", instruction->parameter_number()); + } + StrAppend(&node_name, "/", instruction_name); + CleanNodeName(&node_name); + auto ret = + instruction_to_node_name_.insert(std::make_pair(instruction, node_name)); + CHECK(ret.second); + return ret.first->second; +} + +void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, + NodeDef* node_def) const { + auto& attrs = *node_def->mutable_attr(); + + // Set the number of arguments for instructions that have variadic operands. + if (HloOpcodeIsVariadic(instruction->opcode())) { + tensorflow::AttrValue attr_value; + attr_value.set_i(instruction->operands().size()); + attrs["arg_num"] = attr_value; + } + + // Set the node type. + attrs["type"].set_s( + xla::PrimitiveType_Name(instruction->shape().element_type())); + + // Set the framework op (e.g. Tensorflow op) that generated this XLA op. + attrs["tf_op_type"].set_s(instruction->metadata().op_type()); + attrs["tf_op_name"].set_s(instruction->metadata().op_name()); + + // Set the shape of the output tensor. "_output_shapes" is a special attribute + // name used by Tensorboard for shapes of output tensors. + tensorflow::AttrValue shapes; + *shapes.mutable_list()->add_shape() = GetTensorShape(instruction); + attrs["_output_shapes"] = shapes; + + // Set the layout. + if (LayoutUtil::HasLayout(instruction->shape())) { + string layout_string; + if (ShapeUtil::IsTuple(instruction->shape())) { + // For tuples, emit the full shape because the layout of a tuple is not + // represented in a single Layout field. + layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); + } else { + layout_string = StrCat( + "{", Join(instruction->shape().layout().minor_to_major(), ","), "}"); + } + attrs["layout"].set_s(layout_string); + } + + // Set op-specific attributes. + switch (instruction->opcode()) { + case HloOpcode::kConcatenate: + case HloOpcode::kBroadcast: + case HloOpcode::kReduce: + case HloOpcode::kReverse: + case HloOpcode::kTranspose: + for (auto dim : instruction->dimensions()) { + attrs["dims"].mutable_list()->add_i(dim); + } + break; + case HloOpcode::kGetTupleElement: + attrs["index"].set_i(instruction->tuple_index()); + break; + case HloOpcode::kRng: + attrs["dist"].set_s( + RandomDistribution_Name(instruction->random_distribution())); + break; + case HloOpcode::kConstant: + if (ShapeUtil::IsScalar(instruction->shape())) { + attrs["value"].set_s( + LiteralUtil::GetAsString(instruction->literal(), {})); + } + break; + case HloOpcode::kCustomCall: + attrs["custom_call_target"].set_s(instruction->custom_call_target()); + break; + default: + break; + } +} + +Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { + if (!visited_instructions_.insert(instruction).second) { + // Skip instructions that have already been added. + return Status::OK(); + } + + NodeDef* node_def = graph_def_.add_node(); + node_def->set_name(GetNodeNameForInstruction(instruction)); + node_def->set_op(GetOpDefName(instruction)); + SetNodeAttrs(instruction, node_def); + if (instruction->opcode() == HloOpcode::kFusion) { + for (auto& fused_instruction : instruction->fused_instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(fused_instruction.get())); + } + } + // Add all edges including control edges. + for (unsigned i = 0; i < instruction->operands().size(); ++i) { + *node_def->add_input() = GetNodeNameForInstruction(instruction->operand(i)); + } + // Called computations are control dependencies. + for (const auto* called_computation : instruction->called_computations()) { + *node_def->add_input() = StrCat( + "^", GetNodeNameForInstruction(called_computation->root_instruction())); + } + return Status::OK(); +} + +} // namespace hlo_graph_dumper +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..b2c578af912ac0b777d1bc72a198504735a6b845 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace xla { +namespace hlo_graph_dumper { + +// This constructs a tensorflow graph for HLO computations. +class HloTfGraphBuilder { + public: + // Adds a computation to the graph. + Status AddComputation(const HloComputation& computation); + + const tensorflow::GraphDef& GetGraphDef() const; + + private: + // Gets the node name of an instruction. The node name is hierarchical. For + // example, if an instruction is fused, it will be put in a subgraph of the + // fusion instruction. + const string& GetNodeNameForInstruction(const HloInstruction* instruction); + + void SetNodeAttrs(const HloInstruction* instruction, + tensorflow::NodeDef* node_def) const; + + Status AddInstruction(const HloInstruction* instruction); + + tensorflow::GraphDef graph_def_; + // This records instructions that have been visited. + std::unordered_set visited_instructions_; + // A cache that maps instruction to the node name. + std::unordered_map instruction_to_node_name_; +}; + +// Cleans the node name to make it a valid name in a tensorflow graph. +void CleanNodeName(string* name); + +} // namespace hlo_graph_dumper +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6041debc4ae0ccbaad99bec9a461b640aeffbccf --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" + +namespace xla { +namespace hlo_graph_dumper { +namespace { + +using ::tensorflow::GraphDef; + +class HloTfGraphBuilderTest : public HloTestBase { + protected: + HloTfGraphBuilderTest() {} + HloTfGraphBuilder generator_; + + // Create a computation which takes a scalar and returns its negation. + std::unique_ptr CreateNegateComputation() { + auto builder = HloComputation::Builder("Negate"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); + return builder.Build(); + } + + // Creates a computation which calls map with the given computation. + std::unique_ptr CreateMapComputation( + HloComputation *map_computation) { + auto builder = HloComputation::Builder("Map"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + builder.AddInstruction( + HloInstruction::CreateMap(r0f32_, {param}, map_computation)); + return builder.Build(); + } + Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {}); +}; + +static const tensorflow::AttrValue &GetNodeAttr(const tensorflow::NodeDef &node, + const string &attr_name) { + auto attr = node.attr().find(attr_name); + CHECK(attr != node.attr().end()); + return attr->second; +} + +TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { + auto builder = HloComputation::Builder("Concatenate"); + Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}); + auto param_1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + auto param_2 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), {param_1, param_2}, 1)); + TF_CHECK_OK(generator_.AddComputation(*builder.Build())); + GraphDef graph_def = generator_.GetGraphDef(); + EXPECT_EQ(graph_def.node_size(), 3); + const auto &node = graph_def.node(2); + EXPECT_EQ(node.name(), "Concatenate/concatenate"); + + // Check dimensions. + auto dims_value = GetNodeAttr(node, "dims"); + EXPECT_EQ(dims_value.list().i_size(), 1); + EXPECT_EQ(dims_value.list().i(0), 1); + + // Check shapes. + auto shape_value = GetNodeAttr(node, "_output_shapes"); + EXPECT_EQ(shape_value.list().shape_size(), 1); + EXPECT_EQ(shape_value.list().shape(0).dim_size(), 2); + EXPECT_EQ(shape_value.list().shape(0).dim(0).size(), 2); + EXPECT_EQ(shape_value.list().shape(0).dim(1).size(), 4); +} + +TEST_F(HloTfGraphBuilderTest, CheckScalarValue) { + auto builder = HloComputation::Builder("Const"); + HloInstruction *instruction = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); + OpMetadata metadata; + metadata.set_op_name("x"); + metadata.set_op_type("y"); + instruction->set_metadata(metadata); + TF_CHECK_OK(generator_.AddComputation(*builder.Build())); + GraphDef graph_def = generator_.GetGraphDef(); + EXPECT_EQ(graph_def.node_size(), 1); + const auto &node = graph_def.node(0); + EXPECT_EQ(GetNodeAttr(node, "value").s(), "123"); + EXPECT_EQ(GetNodeAttr(node, "type").s(), "S32"); + EXPECT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x"); + EXPECT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y"); +} + +TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) { + auto negate_computation = CreateNegateComputation(); + TF_CHECK_OK(generator_.AddComputation(*negate_computation)); + GraphDef graph_def = generator_.GetGraphDef(); + EXPECT_EQ(graph_def.node_size(), 2); + EXPECT_EQ(graph_def.node(0).name(), "Negate/param0.0"); + EXPECT_EQ(graph_def.node(0).op(), "HloParameter"); + EXPECT_EQ(graph_def.node(1).name(), "Negate/negate"); + EXPECT_EQ(graph_def.node(1).op(), "HloNegate"); + EXPECT_EQ(graph_def.node(1).input_size(), 1); + EXPECT_EQ(graph_def.node(1).input(0), "Negate/param0.0"); +} + +TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) { + auto builder = HloComputation::Builder("GE"); + auto param_1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + auto param_2 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32_, "param1")); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2)); + TF_CHECK_OK(generator_.AddComputation(*builder.Build())); + GraphDef graph_def = generator_.GetGraphDef(); + EXPECT_EQ(graph_def.node_size(), 3); + EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); + EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); + EXPECT_EQ(graph_def.node(2).input_size(), 2); + EXPECT_EQ(graph_def.node(2).name(), "GE/greater-than-or-equal-to"); + EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); +} + +TEST_F(HloTfGraphBuilderTest, IncorparateTfOpsStructure) { + auto builder = HloComputation::Builder("GE"); + auto param_1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + auto param_2 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32_, "param1")); + auto ge = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2)); + OpMetadata metadata; + metadata.set_op_name("x/y"); + metadata.set_op_type("Y"); + ge->set_metadata(metadata); + TF_CHECK_OK(generator_.AddComputation(*builder.Build())); + GraphDef graph_def = generator_.GetGraphDef(); + EXPECT_EQ(graph_def.node_size(), 3); + EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); + EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); + EXPECT_EQ(graph_def.node(2).input_size(), 2); + EXPECT_EQ(graph_def.node(2).name(), "GE/x/y/greater-than-or-equal-to"); + EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); +} + +TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) { + // Create computations with a diamond-shaped callgraph. + auto negate_computation = CreateNegateComputation(); + auto map1_computation = CreateMapComputation(negate_computation.get()); + auto map2_computation = CreateMapComputation(negate_computation.get()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + auto map1 = builder.AddInstruction( + HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get())); + auto map2 = builder.AddInstruction( + HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get())); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2)); + auto computation = builder.Build(); + TF_CHECK_OK(generator_.AddComputation(*computation)); + EXPECT_GT(generator_.GetGraphDef().node_size(), 0); +} + +} // namespace +} // namespace hlo_graph_dumper +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 035b570ed3419503ad2325c5fdb46118b5076187..de6081e57e7f27a07b314692c6935ecf3e3c54a9 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -23,7 +23,8 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RET_CHECK(instruction->parent() == computation.get()); if (instruction->opcode() == HloOpcode::kFusion) { for (const auto& fused : instruction->fused_instructions()) { - TF_RET_CHECK(fused->parent() == computation.get()) + TF_RET_CHECK(fused->parent() == + instruction->fused_instructions_computation()) << "Fused HLO was missing a parent: " << fused->ToString() << " parent: " << fused->parent() << " computation: " << computation.get(); diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 0054edcf6ab3b5134abbc43a8b326d56919364bc..a8d4ecf2614809d73f7c31eeab29b9e765bdeb4c 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -22,13 +22,16 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { @@ -59,11 +62,11 @@ TEST_F(InlinerTest, MapMax) { auto hlo_module = MakeUnique("test_module"); hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); - HloInstruction* root = hlo_module->entry_computation()->root_instruction(); + Inliner inliner; EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); - root = hlo_module->entry_computation()->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kMaximum); + EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), + op::Maximum(lhs, rhs)); // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); @@ -97,7 +100,7 @@ TEST_F(InlinerTest, MapConstant) { Inliner inliner; EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); root = hlo_module->entry_computation()->root_instruction(); - EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + EXPECT_THAT(root, op::Broadcast(op::Constant())); // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index c162945bcae33f4e94b8fbd3a7e48bacce802925..5069215031bac496967cb446ee27dc3f44297df0 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -130,6 +130,33 @@ StatusOr InstructionFusion::Run(HloModule* module) { computation_->MakeInstructionPostOrder(); std::vector post_order(post_order_list.begin(), post_order_list.end()); + + std::set all_consumers_fusable; + // Find which ops can be fused into all of their operands. We would rather + // not fuse an op into only some of its users, as that offers no benefit in + // terms of memory bandwidth, but forces us to keep more live values around. + for (auto* hlo : post_order) { + auto user_fusable_into_hlo = [this, &hlo](HloInstruction* consumer) { + if (!consumer->IsFusable()) { + return false; + } + for (int operand_number = 0; + operand_number < consumer->operands().size(); ++operand_number) { + if (consumer->operand(operand_number) == hlo) { + if (!ShouldFuse(consumer, operand_number)) { + return false; + } + } + } + return true; + }; + + if (std::all_of(hlo->users().begin(), hlo->users().end(), + user_fusable_into_hlo)) { + all_consumers_fusable.insert(hlo); + } + } + tensorflow::gtl::FlatMap post_order_index; for (size_t i = 0; i < post_order.size(); ++i) { InsertOrDie(&post_order_index, post_order[i], i); @@ -216,6 +243,12 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); + + if (FusionWouldDuplicate(*operand, *instruction) && + (all_consumers_fusable.count(operand) == 0)) { + continue; + } + if (operand->IsFusable() && ShouldFuse(instruction, i)) { HloInstruction* fusion_instruction = Fuse(operand, instruction); diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index a4c269f0ebd40b2a1ab46619fec24e76ffd73ff0..9a79e4c38249323b1384192dfed81647d00b77b8 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -15,8 +15,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { using InstructionFusionTest = HloTestBase; @@ -60,7 +63,7 @@ TEST_F(InstructionFusionTest, InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) .Run(module.get()) .ValueOrDie()); - EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); + EXPECT_THAT(computation->root_instruction(), op::Fusion()); } TEST_F(InstructionFusionTest, @@ -80,7 +83,7 @@ TEST_F(InstructionFusionTest, InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) .Run(module.get()) .ValueOrDie()); - EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); + EXPECT_THAT(computation->root_instruction(), op::Fusion()); } TEST_F(InstructionFusionTest, @@ -100,7 +103,7 @@ TEST_F(InstructionFusionTest, InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) .Run(module.get()) .ValueOrDie()); - EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode()); + EXPECT_THAT(computation->root_instruction(), op::Fusion()); } TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { @@ -151,4 +154,23 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {16, 16}), "0")); + HloInstruction* unary1 = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kFloor, param0)); + builder.AddInstruction(HloInstruction::CreateSend(unary1, 0)); + HloInstruction* unary2 = builder.AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kAbs, unary1)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_EQ(unary2, computation->root_instruction()); + EXPECT_FALSE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 5e7bd4a7ce8a1152973979d4a8fdb790a7fbd219..a8366ae794932464d11e9a44a8282c5b9a8a9013 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -60,8 +60,9 @@ std::ostream& operator<<(std::ostream& out, } BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout, - const LogicalBuffer& buffer) - : layout_(layout), buffer_(&buffer) { + const LogicalBuffer& buffer, + bool mandatory) + : LayoutConstraint(mandatory), layout_(layout), buffer_(&buffer) { CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok()); } @@ -73,8 +74,9 @@ string BufferLayoutConstraint::ToString() const { OperandLayoutConstraint::OperandLayoutConstraint( const ShapeLayout& shape_layout, const HloInstruction* instruction, - int64 operand_no) - : shape_layout_(shape_layout), + int64 operand_no, bool mandatory) + : LayoutConstraint(mandatory), + shape_layout_(shape_layout), instruction_(instruction), operand_no_(operand_no) { CHECK(shape_layout_.LayoutIsSet()); @@ -124,7 +126,8 @@ bool LayoutConstraints::OperandBufferForwarded( } Status LayoutConstraints::SetBufferLayout(const Layout& layout, - const LogicalBuffer& buffer) { + const LogicalBuffer& buffer, + bool mandatory) { VLOG(3) << "SetBufferLayout : " << buffer << " : " << LayoutUtil::HumanString(layout); @@ -139,26 +142,38 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, TF_RETURN_IF_ERROR( LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); - const Layout* curr_layout = BufferLayout(buffer); - if (curr_layout != nullptr) { - if (!LayoutUtil::Equal(*curr_layout, layout)) { + const BufferLayoutConstraint* curr_constraint = + GetBufferLayoutConstraint(buffer); + if (curr_constraint != nullptr) { + if (LayoutUtil::Equal(curr_constraint->layout(), layout)) { + // New constraint matches existing constraint. Nothing to do. + return Status::OK(); + } + if (curr_constraint->mandatory()) { return FailedPrecondition( "Buffer %s already has the layout constraint %s, cannot add " "incompatible constraint %s", buffer.ToString().c_str(), - LayoutUtil::HumanString(*curr_layout).c_str(), + LayoutUtil::HumanString(curr_constraint->layout()).c_str(), LayoutUtil::HumanString(layout).c_str()); } - // New constraint matches existing constraint. Nothing to do. - return Status::OK(); } - auto new_constraint_it = buffer_constraints_.insert( - {&buffer, BufferLayoutConstraint(layout, buffer)}); - added_constraints_.push_back(&new_constraint_it.first->second); + auto iter = buffer_constraints_.find(&buffer); + bool overwrite = iter != buffer_constraints_.end(); + if (!overwrite) { + iter = buffer_constraints_ + .insert(std::make_pair( + &buffer, BufferLayoutConstraint(layout, buffer, mandatory))) + .first; + } else { + iter->second = BufferLayoutConstraint(layout, buffer, /*mandatory=*/true); + } + added_constraints_.push_back(&iter->second); // Remove buffer from the set of unconstrained buffers. - TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) == 1); + TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) == + static_cast(!overwrite)); unconstrained_buffer_ids_.erase(buffer.id()); return Status::OK(); @@ -166,23 +181,27 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, const HloInstruction* instruction, - int64 operand_no) { + int64 operand_no, bool mandatory) { VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand " << operand_no << " : " << ShapeUtil::HumanStringWithLayout(shape_with_layout); - const ShapeLayout* curr_shape_layout = OperandLayout(instruction, operand_no); + const OperandLayoutConstraint* curr_shape_layout = + GetOperandLayoutConstraint(instruction, operand_no); if (curr_shape_layout != nullptr) { - if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) { + if (curr_shape_layout->shape_layout().MatchesLayoutInShape( + shape_with_layout)) { + // New constraint matches existing constraint. Nothing to do. + return Status::OK(); + } + if (curr_shape_layout->mandatory()) { return FailedPrecondition( "Operand %lld of instruction %s already has a layout constraint " "%s, cannot add incompatible constraint %s", operand_no, instruction->name().c_str(), - curr_shape_layout->ToString().c_str(), + curr_shape_layout->shape_layout().ToString().c_str(), ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); } - // New constraint matches existing constraint. Nothing to do. - return Status::OK(); } // If any buffers in the operand occur in the output of the instruction, then @@ -196,22 +215,31 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, } auto key = std::make_pair(instruction, operand_no); - auto new_constraint_it = operand_constraints_.insert( - {key, OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction, - operand_no)}); - added_constraints_.push_back(&new_constraint_it.first->second); + auto iter = operand_constraints_.find(key); + if (iter == operand_constraints_.end()) { + auto pair = std::make_pair( + key, OperandLayoutConstraint(ShapeLayout(shape_with_layout), + instruction, operand_no, mandatory)); + iter = operand_constraints_.insert(pair).first; + } else { + iter->second = + OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction, + operand_no, /*mandatory=*/true); + } + added_constraints_.push_back(&iter->second); return Status::OK(); } Status LayoutConstraints::SetArrayOperandLayout( - const Layout& layout, const HloInstruction* instruction, int64 operand_no) { + const Layout& layout, const HloInstruction* instruction, int64 operand_no, + bool mandatory) { const HloInstruction* operand = instruction->operand(operand_no); TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); Shape shape(operand->shape()); *shape.mutable_layout() = layout; TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape)); - return SetOperandLayout(shape, instruction, operand_no); + return SetOperandLayout(shape, instruction, operand_no, mandatory); } Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) { @@ -274,15 +302,29 @@ Status LayoutConstraints::SetInstructionLayout( const Layout* LayoutConstraints::BufferLayout( const LogicalBuffer& buffer) const { + if (const auto* constraint = GetBufferLayoutConstraint(buffer)) { + return &constraint->layout(); + } + return nullptr; +} +const BufferLayoutConstraint* LayoutConstraints::GetBufferLayoutConstraint( + const LogicalBuffer& buffer) const { auto it = buffer_constraints_.find(&buffer); - return it == buffer_constraints_.end() ? nullptr : &it->second.layout(); + return it == buffer_constraints_.end() ? nullptr : &it->second; } const ShapeLayout* LayoutConstraints::OperandLayout( const HloInstruction* instruction, int64 operand_no) const { + if (const auto* constraint = + GetOperandLayoutConstraint(instruction, operand_no)) { + return &constraint->shape_layout(); + } + return nullptr; +} +const OperandLayoutConstraint* LayoutConstraints::GetOperandLayoutConstraint( + const HloInstruction* instruction, int64 operand_no) const { auto it = operand_constraints_.find(std::make_pair(instruction, operand_no)); - return it == operand_constraints_.end() ? nullptr - : &it->second.shape_layout(); + return it == operand_constraints_.end() ? nullptr : &it->second; } const ShapeLayout* LayoutConstraints::ResultLayout() const { @@ -343,7 +385,8 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain the input to the Outfeed instruction to be the expected // layout of the Outfeed. TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - instruction->outfeed_shape(), instruction.get(), 0)); + instruction->outfeed_shape(), instruction.get(), 0, + /*mandatory=*/true)); } else if (instruction->opcode() == HloOpcode::kParameter) { // Parameter layouts must match the respective layout in // ComputationLayout. @@ -375,7 +418,7 @@ Status LayoutAssignment::AddMandatoryConstraints( for (int64 i = 0; i < instruction->operand_count(); ++i) { TF_RETURN_IF_ERROR(constraints->SetOperandLayout( called_computation_layout.parameter_layout(i).shape(), - instruction.get(), i)); + instruction.get(), i, /*mandatory=*/true)); } } else if (instruction->opcode() == HloOpcode::kWhile) { // Layout of input and output of kWhile instruction must be equal and must @@ -426,7 +469,8 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( body_layout.result_shape(), instruction.get())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - body_layout.result_shape(), instruction.get(), 0)); + body_layout.result_shape(), instruction.get(), 0, + /*mandatory=*/true)); } else if (instruction->opcode() == HloOpcode::kCustomCall) { // Add constraints for kCustomCall instruction operands and instructions. // For now we only support row major layouts for all inputs and outputs. @@ -450,7 +494,7 @@ Status LayoutAssignment::AddMandatoryConstraints( Shape row_major_operand_shape(row_major_shape(operand_shape)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - row_major_operand_shape, instruction.get(), i)); + row_major_operand_shape, instruction.get(), i, /*mandatory=*/true)); } } } @@ -659,44 +703,6 @@ LayoutAssignment::LayoutAssignment(ComputationLayout* entry_computation_layout) } } -namespace { - -// Given a pemutation of `{0, 1, ..., n}` `indices`, returns a permutation of -// `{0, 1, ..., n - to_delete.size() + to_insert.size()}` by deleting the -// indices `to_delete` wherever in `indices` they are, and inserting the indices -// `to_insert` arbitrarily at the back. -tensorflow::protobuf::RepeatedField -DeleteAndInsertIndices( - std::vector to_delete, std::vector to_insert, - tensorflow::protobuf::RepeatedField indices) { - std::sort(to_delete.begin(), to_delete.end(), std::greater()); - std::sort(to_insert.begin(), to_insert.end(), std::less()); - for (auto index : to_delete) { - auto i = indices.begin(); - while (i != indices.end()) { - if (*i == index) { - i = indices.erase(i); - } else { - if (*i > index) { - (*i)--; - } - ++i; - } - } - } - for (auto index : to_insert) { - for (auto i = indices.begin(); i != indices.end(); ++i) { - if (*i >= index) { - (*i)++; - } - } - indices.Add(index); - } - return indices; -} - -} // namespace - std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( const Layout& output_layout, const HloInstruction* instruction, int64 operand_no) { @@ -705,7 +711,8 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( CHECK(ShapeUtil::IsArray(instruction->shape()) && ShapeUtil::IsArray(operand->shape())); - if (instruction->IsElementwiseOnOperand(operand_no) && + if ((instruction->IsElementwiseOnOperand(operand_no) || + InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) && !ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(instruction->shape())) { @@ -719,21 +726,32 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( } if (instruction->opcode() == HloOpcode::kReshape) { - // Pick the operand layout that makes the reshape a bitcast. If the reshape - // only inserts or deletes degenerate dimensions, we can easily compute the - // desired layout by accordingly inserting and deleting the elements in the - // minor-to-major list. - bool merely_inserts_or_deletes_1_sized_dims; - std::vector inserted_indices, deleted_indices; - std::tie(merely_inserts_or_deletes_1_sized_dims, deleted_indices, - inserted_indices) = - instruction->ReshapeMerelyInsertsOrDeletes1SizedDimensions(); - if (merely_inserts_or_deletes_1_sized_dims) { - Layout operand_layout = LayoutUtil::MakeLayout( - AsInt64Slice(DeleteAndInsertIndices(inserted_indices, deleted_indices, - output_layout.minor_to_major()))); + // Prefer the operand layout that makes the reshape an bitcast. If any + // dimension bound is 1 in the operand shape, there may be several such + // layouts. So if 'output_layout' is a MajorToMinor layout, try if the + // reshape is a bitcast when using the same layout. This may avoid copy + // operations. + const Shape& output_shape = instruction->shape(); + Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout( + output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), + AsInt64Slice(output_layout.minor_to_major())); + const Shape& operand_shape = operand->shape(); + if (LayoutUtil::IsMonotonicWithDim0Major(output_layout)) { + Shape operand_shape_with_layout = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + operand_shape.element_type(), + AsInt64Slice(operand_shape.dimensions())); + if (ShapeUtil::ReshapeIsBitcast(operand_shape_with_layout, + output_shape_with_layout)) { + return MakeUnique(operand_shape_with_layout.layout()); + } + } + auto aligned_operand_shape = + ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape); + if (aligned_operand_shape) { + auto operand_layout = aligned_operand_shape.value().layout(); TF_CHECK_OK( - LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape())); + LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape)); return MakeUnique(operand_layout); } } @@ -768,18 +786,32 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( } if (user->opcode() == HloOpcode::kReshape) { - // Pick the user layout that makes the reshape a bitcast. - bool merely_inserts_or_deletes_1_sized_dims; - std::vector inserted_indices, deleted_indices; - std::tie(merely_inserts_or_deletes_1_sized_dims, deleted_indices, - inserted_indices) = - user->ReshapeMerelyInsertsOrDeletes1SizedDimensions(); - if (merely_inserts_or_deletes_1_sized_dims) { - Layout user_layout = LayoutUtil::MakeLayout(AsInt64Slice( - DeleteAndInsertIndices(deleted_indices, inserted_indices, - operand_layout.minor_to_major()))); + // Prefer the user layout that makes the reshape an bitcast. If any + // dimension bound is 1 in the user shape, there may be several such + // layouts. So if 'operand_layout' is a MajorToMinor layout, try if the + // reshape is a bitcast when using the same layout. This may avoid copy + // operations. + Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout( + operand->shape().element_type(), + AsInt64Slice(operand->shape().dimensions()), + AsInt64Slice(operand_layout.minor_to_major())); + const Shape& output_shape = user->shape(); + if (LayoutUtil::IsMonotonicWithDim0Major(operand_layout)) { + Shape output_shape_with_layout = + ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + output_shape.element_type(), + AsInt64Slice(output_shape.dimensions())); + if (ShapeUtil::ReshapeIsBitcast(output_shape_with_layout, + operand_shape_with_layout)) { + return MakeUnique(output_shape_with_layout.layout()); + } + } + auto aligned_user_shape = + ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape); + if (aligned_user_shape) { + auto user_layout = aligned_user_shape.value().layout(); TF_CHECK_OK( - LayoutUtil::ValidateLayoutForShape(user_layout, user->shape())); + LayoutUtil::ValidateLayoutForShape(user_layout, output_shape)); return MakeUnique(user_layout); } } @@ -936,7 +968,8 @@ Status LayoutAssignment::PropagateOperandConstraint( operand_constraint.shape_layout().layout(), user, operand_constraint.operand_no()); if (layout != nullptr) { - TF_RETURN_IF_ERROR(constraints->SetBufferLayout(*layout, *buffer)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(*layout, *buffer, /*mandatory=*/false)); } } return Status::OK(); @@ -966,11 +999,19 @@ Status LayoutAssignment::PropagateBufferConstraint( instruction, operand_no); if (operand_layout != nullptr) { TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( - *operand_layout, instruction, operand_no)); + *operand_layout, instruction, operand_no, /*mandatory=*/true)); } } } } + return PropagateBufferConstraintToUses(buffer_constraint, constraints); +} + +Status LayoutAssignment::PropagateBufferConstraintToUses( + const BufferLayoutConstraint& buffer_constraint, + LayoutConstraints* constraints) { + const LogicalBuffer& buffer = buffer_constraint.buffer(); + TF_RET_CHECK(buffer.IsArray()); // Propagate the layout to all array uses of the logical buffer. This skips // uses of the buffer where the buffer is the element of a tuple. @@ -983,7 +1024,7 @@ Status LayoutAssignment::PropagateBufferConstraint( if (constraints->OperandLayout(user, operand_no) == nullptr && !constraints->OperandBufferForwarded(user, operand_no)) { TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( - buffer_constraint.layout(), user, operand_no)); + buffer_constraint.layout(), user, operand_no, /*mandatory=*/false)); } } @@ -1040,7 +1081,7 @@ StatusOr InferArrayLayout( *first_buffer_layout)) { // The points-to set is ambiguous for this index and the different source // buffers have different layouts. This case is possible in valid XLA - // computations because we do not propagate BufferLayoutConstaints to all + // computations because we do not propagate BufferLayoutConstraints to all // LogicalBuffers which may alias the constrained LogicalBuffer at some // point in the computation. return FailedPrecondition( @@ -1253,7 +1294,7 @@ Status LayoutAssignment::RunOnComputation( TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(computation->parent())); - // Construct LayoutConstaints with all layout constraints of the computation. + // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(*points_to_analysis, computation); // Add constraints required for correctness on all backends (eg, entry @@ -1278,7 +1319,8 @@ Status LayoutAssignment::RunOnComputation( const LogicalBuffer& buffer = points_to_analysis->GetBuffer( *constraints.unconstrained_buffer_ids().begin()); TF_RETURN_IF_ERROR(constraints.SetBufferLayout( - LayoutUtil::GetDefaultLayoutForShape(buffer.shape()), buffer)); + LayoutUtil::GetDefaultLayoutForShape(buffer.shape()), buffer, + /*mandatory=*/false)); TF_RETURN_IF_ERROR(PropagateConstraints(&constraints)); diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 61dc7b120752d57cf09423f38546441de2fc8dd9..689e4510ed2e0c32a194b8488d09c4d7af522d2b 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -46,10 +46,16 @@ namespace xla { // gathered together in LayoutConstraints object. class LayoutConstraint { public: - LayoutConstraint() = default; + LayoutConstraint(bool mandatory) : mandatory_(mandatory) {} virtual ~LayoutConstraint() = default; virtual string ToString() const = 0; + + // True if this constraint cannot be overwritten by a different constraint. + bool mandatory() const { return mandatory_; } + + private: + bool mandatory_; }; std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint); @@ -58,7 +64,8 @@ std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint); // array produced by a particular instruction. class BufferLayoutConstraint : public LayoutConstraint { public: - BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer); + BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer, + bool mandatory); const LogicalBuffer& buffer() const { return *buffer_; } const Layout& layout() const { return layout_; } @@ -66,7 +73,7 @@ class BufferLayoutConstraint : public LayoutConstraint { string ToString() const override; private: - const Layout layout_; + Layout layout_; const LogicalBuffer* buffer_; }; @@ -78,7 +85,8 @@ class BufferLayoutConstraint : public LayoutConstraint { class OperandLayoutConstraint : public LayoutConstraint { public: OperandLayoutConstraint(const ShapeLayout& shape_layout, - const HloInstruction* instruction, int64 operand_no); + const HloInstruction* instruction, int64 operand_no, + bool mandatory); const ShapeLayout& shape_layout() const { return shape_layout_; } const HloInstruction* instruction() const { return instruction_; } @@ -90,7 +98,7 @@ class OperandLayoutConstraint : public LayoutConstraint { string ToString() const override; private: - const ShapeLayout shape_layout_; + ShapeLayout shape_layout_; const HloInstruction* instruction_; int64 operand_no_; }; @@ -99,7 +107,7 @@ class OperandLayoutConstraint : public LayoutConstraint { class ResultLayoutConstraint : public LayoutConstraint { public: explicit ResultLayoutConstraint(const ShapeLayout& shape_layout) - : shape_layout_(shape_layout) {} + : LayoutConstraint(/*mandatory=*/true), shape_layout_(shape_layout) {} const ShapeLayout& shape_layout() const { return shape_layout_; } string ToString() const override; @@ -124,8 +132,7 @@ class LayoutConstraints { // Return a vector containing the constraints which have been added to the // LayoutConstraints object since the construction of the object or since the // last time ConsumeAddedConstraints() has been called. This is used to - // identify - // newly added constraints when propagating layouts. + // identify newly added constraints when propagating layouts. std::vector ConsumeAddedConstraints() { std::vector ret_vec(std::move(added_constraints_)); added_constraints_.clear(); @@ -137,23 +144,29 @@ class LayoutConstraints { // instruction, or the layout of the result of the computation, respectively, // if it has been constrained. Otherwise return nullptr. const Layout* BufferLayout(const LogicalBuffer& buffer) const; + const BufferLayoutConstraint* GetBufferLayoutConstraint( + const LogicalBuffer& buffer) const; const ShapeLayout* OperandLayout(const HloInstruction* instruction, int64 operand_no) const; + const OperandLayoutConstraint* GetOperandLayoutConstraint( + const HloInstruction* instruction, int64 operand_no) const; const ShapeLayout* ResultLayout() const; // Add a constraint on the layout of a LogicalBuffer, the layout of the // operand of the instruction, or the layout of the result of the computation, // respectively. - Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer); + Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer, + bool mandatory = true); Status SetOperandLayout(const Shape& shape_with_layout, - const HloInstruction* instruction, int64 operand_no); + const HloInstruction* instruction, int64 operand_no, + bool mandatory = true); Status SetResultLayout(const Shape& shape_with_layout); // Convenience wrapper around SetOperandLayout for setting the layout of a // operand using a Layout object. The operand must be array-shaped. Status SetArrayOperandLayout(const Layout& layout, const HloInstruction* instruction, - int64 operand_no); + int64 operand_no, bool mandatory = true); // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers // created by the instruction to the layouts in the given shape. The @@ -233,6 +246,18 @@ class LayoutAssignment : public HloPassInterface { const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints); + // Propagates a buffer layout constraint into the operands that use it. + Status PropagateBufferConstraintToUses( + const BufferLayoutConstraint& layout_constraint, + LayoutConstraints* constraints); + + // Propagates a layout constraint on the use of the result of the given + // instruction to the definitions of the LogicalBuffers which make up the + // result. + Status PropagateUseConstraintToDefs(const ShapeLayout& shape_layout, + const HloInstruction* instruction, + LayoutConstraints* constraints); + private: // Adds constraints which must be satisfied for correctness on all // backends. Called once prior to propagating constraints. @@ -248,6 +273,15 @@ class LayoutAssignment : public HloPassInterface { return Status::OK(); } + // This method can be overridden to mark instructions as requiring the operands + // to have the same layout as the result, for performance or correctness. This + // will propagate constraints through the instruction from the result into the + // operands. + virtual bool InstructionRequiresInputLayoutEqualToOutputLayout( + const HloInstruction* instruction) { + return false; + } + // Construct contraints and assign layouts to all instructions in the // computation satisfying the given ComputationLayout. Layouts constraints are // added, then propagated until all LogicalBuffers in the computation are @@ -267,13 +301,6 @@ class LayoutAssignment : public HloPassInterface { // required for correctness. Status PropagateConstraints(LayoutConstraints* constraints); - // Propagates a layout constraint on the use of the result of the given - // instruction to the definitions of the LogicalBuffers which make up the - // result. - Status PropagateUseConstraintToDefs(const ShapeLayout& shape_layout, - const HloInstruction* instruction, - LayoutConstraints* constraints); - // Chooses a layout of operand `operand_no` of `instruction` that minimizes // the cost of `instruction`. `output_layout` is the layout of `instruction`. // Returns null if it can't decide the best layout. diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 6361907b0e4ad8e21baec88b975f88fc65e42b38..bfb9e4ac2ee707233a82c9cd8dc5e3cc0e5ff8e7 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -26,10 +26,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -38,9 +40,13 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { +using ::testing::ElementsAre; + class LayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* module, @@ -304,18 +310,16 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}), root->operand(1)->shape())); - // Verify some of the structure of the HLO graph. - EXPECT_EQ(constant, root->operand(0)->operand(0)); - EXPECT_EQ(HloOpcode::kCopy, root->operand(1)->operand(0)->opcode()); - EXPECT_EQ(HloOpcode::kConstant, - root->operand(1)->operand(0)->operand(0)->opcode()); + // Verify the structure of the HLO graph. + EXPECT_THAT(root, + op::Tuple(op::Tuple(constant), op::Tuple(op::Copy(constant)))); } TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { // param -> log -> reshape -> tanh auto builder = HloComputation::Builder(TestName()); Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1}); - Shape bshape = ShapeUtil::MakeShape(F32, {2, 1, 3}); + Shape bshape = ShapeUtil::MakeShape(F32, {3, 1, 2}); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, ashape, "param")); auto log = builder.AddInstruction( @@ -330,8 +334,8 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { Shape ashape_with_layout(ashape); Shape bshape_with_layout(bshape); - *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); - *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2}); + *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3}); + *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0}); ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_parameter_layout(0) = @@ -341,12 +345,12 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { auto log_minor_to_major = AsInt64Slice(log->shape().layout().minor_to_major()); - EXPECT_LT(PositionInContainer(log_minor_to_major, 1), + EXPECT_GT(PositionInContainer(log_minor_to_major, 1), PositionInContainer(log_minor_to_major, 2)); auto reshape_minor_to_major = AsInt64Slice(reshape->shape().layout().minor_to_major()); - EXPECT_LT(PositionInContainer(reshape_minor_to_major, 0), + EXPECT_GT(PositionInContainer(reshape_minor_to_major, 0), PositionInContainer(reshape_minor_to_major, 2)); } @@ -419,8 +423,8 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { ShapeLayout(output_shape_with_layout); AssignLayouts(&module, &computation_layout); - EXPECT_TRUE(ContainersEqual(broadcast->shape().layout().minor_to_major(), - tensorflow::gtl::ArraySlice{0, 1, 2})); + EXPECT_THAT(broadcast->shape().layout().minor_to_major(), + ElementsAre(0, 1, 2)); } TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { @@ -472,15 +476,80 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { {transpose_shape_with_layout, broadcast2_shape_with_layout})); AssignLayouts(&module, &computation_layout); - EXPECT_TRUE(ContainersEqual(broadcast->shape().layout().minor_to_major(), - tensorflow::gtl::ArraySlice{0, 1})); - EXPECT_TRUE(ContainersEqual(transpose->shape().layout().minor_to_major(), - tensorflow::gtl::ArraySlice{1, 0})); - EXPECT_TRUE(ContainersEqual(tanh->shape().layout().minor_to_major(), - tensorflow::gtl::ArraySlice{0, 1})); + EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1)); + EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0)); + EXPECT_THAT(tanh->shape().layout().minor_to_major(), ElementsAre(0, 1)); } -// Add test which fails due to copy tuple. +class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment { + public: + explicit OperandsMustBeTheSameLayoutAssignment( + ComputationLayout* entry_computation_layout) + : LayoutAssignment(entry_computation_layout) {} + + protected: + Status PropagateBufferConstraint( + const BufferLayoutConstraint& buffer_constraint, + LayoutConstraints* constraints) override { + const LogicalBuffer& buffer = buffer_constraint.buffer(); + const HloInstruction* instruction = buffer.instruction(); + + // Force the operands' layout to the output layout. + for (int64 operand_no = 0; operand_no < instruction->operand_count(); + ++operand_no) { + const HloInstruction* operand = instruction->operand(operand_no); + if (ShapeUtil::Rank(instruction->shape()) != + ShapeUtil::Rank(operand->shape())) { + continue; + } + TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout( + buffer_constraint.layout(), instruction, operand_no, + /*mandatory=*/true)); + } + return PropagateBufferConstraintToUses(buffer_constraint, constraints); + } +}; + +TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { + // param0 -> concatenate -> reshape + // param1 -^ + auto builder = HloComputation::Builder(TestName()); + Shape ashape = ShapeUtil::MakeShape(F32, {50, 1}); + Shape bshape = ShapeUtil::MakeShape(F32, {50, 2}); + Shape cshape = ShapeUtil::MakeShape(F32, {100}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, ashape, "param")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, ashape, "param")); + auto concatenate = builder.AddInstruction( + HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1)); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(cshape, concatenate)); + HloModule module(TestName()); + HloComputation* computation = + module.AddEntryComputation(builder.Build(reshape)); + + Shape param0_shape_with_layout(ashape); + Shape param1_shape_with_layout(ashape); + *param0_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + *param1_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(param0_shape_with_layout); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(param1_shape_with_layout); + OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout); + EXPECT_IS_OK(layout_assignment.Run(&module).status()); + + EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode()); + EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(), + ElementsAre(1, 0)); + EXPECT_THAT(concatenate->operand(1)->shape().layout().minor_to_major(), + ElementsAre(1, 0)); + EXPECT_THAT(concatenate->shape().layout().minor_to_major(), + ElementsAre(1, 0)); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc index caaf56a5516fcf9f21d8754feec04db23381809e..e1004256ff63bd7ef58fbba8e4144b7c1a71d32b 100644 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -28,8 +28,9 @@ limitations under the License. namespace xla { -bool DoesNotUseOperandBuffer(HloInstruction* operand, const ShapeIndex& index, - HloInstruction* user, +bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user, const TuplePointsToAnalysis& points_to_analysis) { CHECK(user->IsUserOf(operand)) << "user: " << user->ToString() << " operand: " << operand->ToString(); @@ -98,15 +99,53 @@ std::vector> GetAllUsesOfInstructionAtIndex( return uses; } +// Returns true if there is exactly one use of 'operand' at 'operand_index' +// in 'fusion.fused_instructions', where the singleton use is the fused +// root at operand index 'use_operand_index'. Returns false otherwise. +// +// REQUIRES: 'fusion' opcode is a kFusion instruction. +bool HasUniqueFusedUseOfOperandAt( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* fusion, const int64 use_operand_index, + const TuplePointsToAnalysis& points_to_analysis) { + CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); + // Check that 'operand' is unique in the operand list of 'fusion'. + if (fusion->OperandIndices(operand).size() > 1) { + return false; + } + // Find fusion parameter associated with 'operand'. + const auto& fused_params = fusion->fused_parameters(); + auto fused_param_it = std::find_if( + fused_params.begin(), fused_params.end(), + [&](HloInstruction* fused_param) { + return fusion->operand(fused_param->parameter_number()) == operand; + }); + if (fused_param_it == fused_params.end()) { + return false; + } + auto* fused_param = *fused_param_it; + // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. + auto fused_param_uses = GetAllUsesOfInstructionAtIndex( + fused_param, operand_index, points_to_analysis); + // Return true iff there is exactly one use of 'operand' at 'index', and + // this singleton use is the fused root (at index in 'use_operand_indices'). + return fused_param_uses.size() == 1 && + fused_param_uses[0].first == fusion->fused_expression_root() && + fused_param_uses[0].second == use_operand_index; +} + } // namespace // User and operand can share buffers iff both instructions emit the same shape -// and layout, and 'user' meets one of the following two qualifications: -// *) Is element-wise. +// and layout, and 'user' meets one of the following qualifications: +// *) Is element-wise. Or... // *) Is a loop fusion instruction where the only use of 'operand' at 'index' // in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root -// at operand 0. -// *) Use of 'operand' is DynamicUpdateSlice at operand index 0. +// at operand 0. Or... +// *) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion +// instruction where the only use of 'operand' at 'index' in the set +// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... +// *) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index 0. bool CanShareOperandBufferWithUser( HloInstruction* operand, const ShapeIndex& operand_index, HloInstruction* user, const ShapeIndex& user_index, @@ -120,31 +159,49 @@ bool CanShareOperandBufferWithUser( if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { return false; } - // Check if 'user' is a loop fusion instruction with a kDynamicUpdateSlice - // fused root instruction. - if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - for (auto& fused_param : user->fused_parameters()) { - // Find fusion parameter associated with 'operand'. - if (user->operand(fused_param->parameter_number()) != operand) { - continue; - } - // Get all uses of 'operand' at 'index' from 'user.fused_instructions'. - auto fused_param_uses = GetAllUsesOfInstructionAtIndex( - fused_param, operand_index, points_to_analysis); - // Return true iff there is exactly one use of 'operand' at 'index', and - // this singleton use is the fused root at operand index 0. - if (fused_param_uses.size() == 1 && - fused_param_uses[0].first == user->fused_expression_root() && - fused_param_uses[0].second == 0) { - return true; + if (user->opcode() == HloOpcode::kFusion) { + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && + user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0, + points_to_analysis); + } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + // Output fusion with kAdd fused root. + + // Check if one operand of kAdd fused root is either kDot, or nested + // kFusion of kind kTransposeDot. + auto* add = user->fused_expression_root(); + auto add_operand_it = + std::find_if(add->operands().begin(), add->operands().end(), + [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kDot || + (operand->opcode() == HloOpcode::kFusion && + operand->fusion_kind() == + HloInstruction::FusionKind::kTransposeDot); + }); + if (add_operand_it == add->operands().end()) { + return false; } - break; + auto* matched_add_operand = *add_operand_it; + // Calculate operand index of 'add' operand which was not matched above. + const int64 other_add_operand_index = + matched_add_operand == add->operand(0) ? 1 : 0; + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root (at operand + // index 'other_add_operand_index'). + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, + other_add_operand_index, + points_to_analysis); } - return false; - } else if (user->opcode() == HloOpcode::kDynamicUpdateSlice) { + } + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kWhile) { // We eliminated other users in BufferLiveness::live_range_strictly_before, // so here we just need to check that the use is at operand index 0. std::vector operand_indices = user->OperandIndices(operand); diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h index 410a7b1b519e117f21c01938cb8e4a5b1c358ad2..52de282ca6b444867c865f845ce794196c98b277 100644 --- a/tensorflow/compiler/xla/service/liveness_util.h +++ b/tensorflow/compiler/xla/service/liveness_util.h @@ -32,8 +32,9 @@ namespace xla { // 'operand'. Returns false otherwise. // // REQUIRES: 'operand' is an operand of 'user'. -bool DoesNotUseOperandBuffer(HloInstruction* operand, const ShapeIndex& index, - HloInstruction* user, +bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user, const TuplePointsToAnalysis& points_to_analysis); // Returns true if 'user' (at 'user_index') can share a buffer with its operand diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index 2ff71d6f3c8eff58b83783fc867d5874c6c700a3..ac670069b499eadd452f7faf3a56aa00d808d77f 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -34,9 +34,7 @@ class PointsToAnalysisTestBase : public HloTestBase { void RunAnalysis() { CHECK_NOTNULL(module_.get()); points_to_analysis_ = - TuplePointsToAnalysis::Run(module_.get(), - /*include_loop_fusion_instructions=*/true) - .ConsumeValueOrDie(); + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); } void BuildModuleAndRunAnalysis(std::unique_ptr computation) { @@ -150,6 +148,25 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { *points_to_analysis_)); } +TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); + EXPECT_TRUE( + CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_)); +} + TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { auto builder = HloComputation::Builder(TestName()); @@ -185,5 +202,167 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { *points_to_analysis_)); } +TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + auto starts = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, data, update, starts)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The DynamicUpdateSlice instruction can share with the data operand, but not + // with update or starts. + EXPECT_TRUE( + CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_)); + EXPECT_FALSE( + CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_)); + EXPECT_FALSE( + CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + auto dot = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, dot}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused dot add should be able to share buffer with 'add_operand'. + EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, + *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + auto b_t = builder.AddInstruction( + HloInstruction::CreateTranspose(data_shape, b, {1, 0})); + + auto dot = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b_t)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + + auto nested_fusion = computation_->CreateFusionInstruction( + {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); + + auto fusion = computation_->CreateFusionInstruction( + {add, nested_fusion}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused transpose-dot-add should be share buffer with 'add_operand'. + EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, + *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, operand, {0, 1})); + + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, reverse}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused operand->reverse->add cannot alias operand buffer 'operand'. + EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, + *points_to_analysis_)); +} + +TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + + auto make_cond = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Cond"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + return builder.Build(); + }; + + auto make_body = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); + return builder.Build(); + }; + + module_ = MakeUnique(TestName()); + HloComputation* cond_computation = + module_->AddEmbeddedComputation(make_cond()); + HloComputation* body_computation = + module_->AddEmbeddedComputation(make_body()); + + auto builder = HloComputation::Builder(TestName()); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto whil = builder.AddInstruction(HloInstruction::CreateWhile( + data_shape, cond_computation, body_computation, data)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + // The While instruction can share with the data operand. + EXPECT_TRUE( + CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 28488ca99912335a4ead43c9c7cd227f85f7db68..964b359bb094b43a1a8b126a217293567c5fc865 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -130,7 +130,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder, int alignment = 0); -// Creates a basic block with the same context and funtion as for the +// Creates a basic block with the same context and function as for the // builder. Inserts at the end of the function if insert_before is // null. llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 17d7b97b21bd3296711295e0779b0a273c9917e0..78d21233c765ec8f18a865f55b752d418ad126d6 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -60,9 +60,12 @@ namespace xla { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } - TF_ASSIGN_OR_RETURN( - std::unique_ptr backend, - Backend::CreateBackend(platform, options.number_of_replicas())); + BackendOptions backend_options; + backend_options.set_platform(platform) + .set_number_of_replicas(options.number_of_replicas()) + .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads()); + TF_ASSIGN_OR_RETURN(std::unique_ptr backend, + Backend::CreateBackend(backend_options)); TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); @@ -77,21 +80,6 @@ LocalService::LocalService(std::unique_ptr execute_backend, runs_in_client_process_ = true; } -tensorflow::Status LocalService::ResolveArguments( - const tensorflow::gtl::ArraySlice arguments, - int device_ordinal, - std::vector* argument_ptrs) { - TF_ASSIGN_OR_RETURN(std::vector arg_allocations, - ResolveAndValidateArguments( - arguments, execute_backend_.get(), device_ordinal)); - argument_ptrs->resize(arg_allocations.size()); - for (int i = 0; i < arguments.size(); ++i) { - const Allocation& allocation = *arg_allocations[i]; - (*argument_ptrs)[i] = allocation.device_memory(); - } - return tensorflow::Status::OK(); -} - namespace { // Returns the space required to allocate a shape. If // allocate_space_for_deep_copy the space includes all sub-buffers of @@ -128,70 +116,6 @@ StatusOr LocalService::AllocateBufferOnDevice( allocation_size)); } -StatusOr>> -LocalService::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice - computations, - const AotCompilationOptions& options) { - std::vector> hlo_modules; - std::vector> module_configs; - for (const AheadOfTimeComputationInstance& instance : computations) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(instance.computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - // Dump computation proto state if flag is set. - legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); - const string& directory_path = flags->xla_dump_computations_to; - if (!directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); - string filename = tensorflow::strings::StrCat( - "computation_", versioned_handle.handle.handle(), "__", - session_module->entry().name(), "__version_", - versioned_handle.version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); - } - - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - computation_tracker_.BuildHloModule( - versioned_handle, - /*include_unreachable_instructions=*/true)); - hlo_modules.push_back(std::move(hlo_module)); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - module_configs.push_back(MakeUnique(*program_shape)); - HloModuleConfig* module_config = module_configs.back().get(); - auto* computation_layout = - module_config->mutable_entry_computation_layout(); - if (flags->xla_hlo_profile) { - module_config->enable_hlo_profiling(true); - } - for (int i = 0; i < instance.argument_layouts.size(); ++i) { - const Shape& argument_layout = *instance.argument_layouts[i]; - if (ShapeUtil::IsTuple(argument_layout)) { - return Unimplemented("tuple arguments not supported yet"); - } - TF_RETURN_IF_ERROR( - computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( - argument_layout)); - } - TF_RETURN_IF_ERROR( - computation_layout->mutable_result_layout()->CopyLayoutFromShape( - *instance.result_layout)); - } - - return execute_backend_->compiler()->CompileAheadOfTime( - std::move(hlo_modules), std::move(module_configs), MakeHloDumper(), - options); -} - StatusOr> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index df27f0a7a60dca99caf09994f417f1bc45ec15de..767a3ab697febb283af448b25369445152381a5e 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -43,14 +43,6 @@ class LocalService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // For an array of arguments, validate that each is placed on the - // specified device_ordinal, and return the DeviceMemoryBase - // corresponding to each argument. - tensorflow::Status ResolveArguments( - const tensorflow::gtl::ArraySlice arguments, - int device_ordinal, - std::vector* argument_ptrs); - // Return a handle to a buffer large enough to hold shape, allocated // on device_ordinal. If allocate_space_for_deep_copy, the buffer is // large enough to hold all sub-buffers of a tuple shape, otherwise @@ -59,22 +51,6 @@ class LocalService : public Service { const Shape& shape, int device_ordinal, bool allocate_space_for_deep_copy); - // A description of a computation to compile using CompileAheadOfTime. - struct AheadOfTimeComputationInstance { - ComputationHandle computation; - std::vector argument_layouts; - const Shape* result_layout = nullptr; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. See - // |LocalClient::CompileAheadOfTime| for additional details. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice - computations, - const AotCompilationOptions& Options); - // Builds an Executable with the given argument layouts and options. If // result_layout is non-null, then the executable is compiled to produce a // result of the given layout. diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 3bff35544c83b09557e5623b10304348a41ec336..768977ba6bba2f9af55fcd467aa3d91488e4bf0f 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -13,17 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/reshape_mover.h" - -#include -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/util.h" - -namespace xla { - -namespace { - +// Implementation note: +// // The general idea behind this pass is that we're converting from this: // %param.A = OldShape // %param.B = OldShape @@ -44,6 +35,19 @@ namespace { // only implicit scalar broadcast is on Pred, not on A or B. Since reshapes or // transposes to a scalar should be cheap, we simply never move them. +#include "tensorflow/compiler/xla/service/reshape_mover.h" + +#include +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + // Finds the first non-scalar operand of an instruction that is a reshape or // transpose and returns the operand if it is found or nullptr if not found. HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) { @@ -51,6 +55,9 @@ HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) { if (!ShapeUtil::IsScalar(operand->shape()) && (operand->opcode() == HloOpcode::kReshape || operand->opcode() == HloOpcode::kTranspose)) { + VLOG(5) << "Found first non-scalar reshape operand of " + << hlo->ToStringNoMetadata() << ":\n\t" + << operand->ToStringNoMetadata(); return operand; } } @@ -70,6 +77,9 @@ bool OperandCanTrivallyChangeShape(const HloInstruction* instruction, // A constant can trivially reshape the literal it holds. if (operand->opcode() == HloOpcode::kConstant && ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { + VLOG(5) << "Constant had same dimensions as instruction:\n\toperand: " + << operand->ToStringNoMetadata() + << "\n\tinstruction: " << instruction->ToStringNoMetadata(); return true; } @@ -116,135 +126,173 @@ bool IsElementwiseOfEquivalentReshapesOrTransposes( if (!first_reshape_operand) { return false; } - return (instruction->user_count() > 0 || - instruction == instruction->parent()->root_instruction()) && - instruction->IsElementwise() && !operands.empty() && - // Check whether all operands: - // 1. are all reshapes or transposes that have the same input and - // output shapes as all other reshaped or transposed operands. - // or - // 2. can be any shape like kConstant, kRng, and scalars. - std::all_of( - operands.begin(), operands.end(), - [instruction, - first_reshape_operand](const HloInstruction* operand) { - return AreEquivalentReshapes(first_reshape_operand, operand) || - OperandCanTrivallyChangeShape(instruction, operand); - }); + VLOG(3) << "** Checking whether instruction is an elementwise operation of " + "equivalent reshapes/transposes: " + << instruction->ToStringNoMetadata(); + bool result = + (instruction->user_count() > 0 || + instruction == instruction->parent()->root_instruction()) && + instruction->IsElementwise() && !operands.empty() && + // Check whether all operands: + // 0. Have the same dimensions as the output -- if not, it may be + // implicitly broadcast, which can confound the movement's + // correctness. + // 1. Are all reshapes or transposes that have the same input and + // output shapes as all other reshaped or transposed operands. + // or + // 2. Can be any shape like kConstant, kRng, and scalars. + std::all_of( + operands.begin(), operands.end(), + [instruction, first_reshape_operand](const HloInstruction* operand) { + if (!ShapeUtil::SameDimensions(operand->shape(), + instruction->shape())) { + VLOG(5) << "Operand shape differs from output shape; may be " + "implicitly broadcast, so preventing " + "movement\n\toperand: " + << operand->ToStringNoMetadata() << "\n\tinstruction: " + << instruction->ToStringNoMetadata(); + return false; + } + if (AreEquivalentReshapes(first_reshape_operand, operand)) { + VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " + << first_reshape_operand->ToStringNoMetadata() + << "\n\toperand: " << operand->ToStringNoMetadata(); + return true; + } + if (OperandCanTrivallyChangeShape(instruction, operand)) { + VLOG(5) << "Operand can trivially change shape: " + << operand->ToStringNoMetadata(); + return true; + } + return false; + }); + VLOG(3) << "ElementwiseOfEquivalentReshapesOrTransposes result for " + << instruction->ToStringNoMetadata() << ": " << result; + return result; } // Try to sink any reshape or transpose operands of `instruction` across it. We // do so if `instruction` is elementwise and all operands are equivalent // reshapes or transposes. -bool TrySinkReshapeOrTranspose(HloComputation* computation, - HloInstruction* instruction) { - if (IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) { - std::vector operands = instruction->operands(); - HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction); - CHECK(old_reshape != nullptr); - Shape new_elementwise_shape = old_reshape->operand(0)->shape(); - for (size_t i = 0; i < operands.size(); ++i) { - // All scalar operands remain as-is, even if they're reshape or transpose, - // to simplify handling wrt special scalar broadcast rules for ops like - // Select. Scalar reshapes should be cheap anyways. - if (ShapeUtil::IsScalar(operands[i]->shape())) { - continue; - } - auto element_type = operands[i]->shape().element_type(); - switch (operands[i]->opcode()) { - case HloOpcode::kConstant: { - if (old_reshape->opcode() == HloOpcode::kReshape) { - operands[i] = instruction->parent()->AddInstruction( - HloInstruction::CreateReshape( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i])); - } else { - CHECK_EQ(old_reshape->opcode(), HloOpcode::kTranspose); - std::vector inverse_permutation = - InversePermutation(old_reshape->dimensions()); - operands[i] = instruction->parent()->AddInstruction( - HloInstruction::CreateTranspose( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i], inverse_permutation)); - } - break; - } - case HloOpcode::kRng: { - CHECK_EQ(operands[i]->user_count(), 1); +StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, + HloInstruction* instruction) { + if (!IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) { + return false; + } + + std::vector operands = instruction->operands(); + HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction); + TF_RET_CHECK(old_reshape != nullptr); + Shape new_elementwise_shape = old_reshape->operand(0)->shape(); + + VLOG(3) << "** Trying to sink reshape or transpose: " + << instruction->ToStringNoMetadata() + << "\n\told reshape: " << old_reshape->ToStringNoMetadata() + << "\n\tnew elementwise shape: " + << ShapeUtil::HumanString(new_elementwise_shape); + for (size_t i = 0; i < operands.size(); ++i) { + // All scalar operands remain as-is, even if they're reshape or transpose, + // to simplify handling wrt special scalar broadcast rules for ops like + // Select. Scalar reshapes should be cheap anyways. + if (ShapeUtil::IsScalar(operands[i]->shape())) { + continue; + } + PrimitiveType element_type = operands[i]->shape().element_type(); + switch (operands[i]->opcode()) { + case HloOpcode::kConstant: { + if (old_reshape->opcode() == HloOpcode::kReshape) { + VLOG(3) << "Creating reshape for kConstant operand " << i << ": " + << operands[i]->ToStringNoMetadata(); + operands[i] = instruction->parent()->AddInstruction( + HloInstruction::CreateReshape( + ShapeUtil::ChangeElementType(new_elementwise_shape, + element_type), + operands[i])); + } else { + TF_RET_CHECK(old_reshape->opcode() == HloOpcode::kTranspose); + std::vector inverse_permutation = + InversePermutation(old_reshape->dimensions()); operands[i] = instruction->parent()->AddInstruction( - operands[i]->CloneWithNewOperands( + HloInstruction::CreateTranspose( ShapeUtil::ChangeElementType(new_elementwise_shape, element_type), - operands[i]->operands())); - break; + operands[i], inverse_permutation)); } - case HloOpcode::kReshape: - case HloOpcode::kTranspose: - operands[i] = operands[i]->mutable_operand(0); - break; - default: - LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or " - "transposes."; + break; } - } - if (HloOpcode::kFusion == instruction->opcode()) { - // Here we already know `instruction` is elementwise, and no operand is - // implicit broadcast as if it were the operands would not be equivalent - // reshapes, so all the fused instructions have the same dimensions. - for (const auto& fused_instruction : instruction->fused_instructions()) { - Shape* shape = fused_instruction->mutable_shape(); - *shape->mutable_dimensions() = new_elementwise_shape.dimensions(); - *shape->mutable_layout() = new_elementwise_shape.layout(); + case HloOpcode::kRng: { + CHECK_EQ(operands[i]->user_count(), 1); + operands[i] = instruction->parent()->AddInstruction( + operands[i]->CloneWithNewOperands( + ShapeUtil::ChangeElementType(new_elementwise_shape, + element_type), + operands[i]->operands())); + break; } - } - auto new_elementwise = - computation->AddInstruction(instruction->CloneWithNewOperands( - // `instruction` may change the element type, e.g., from - // operands[0] -> reshape -> convert (`instruction`) - // to - // operands[0] -> convert' -> reshape' - // - // In this case, convert' should have the same element type as - // `convert` and the same dimensions as operands[0]. - ShapeUtil::ChangeElementType(new_elementwise_shape, - instruction->shape().element_type()), - operands)); - std::unique_ptr new_reshape; - switch (old_reshape->opcode()) { case HloOpcode::kReshape: - new_reshape = HloInstruction::CreateReshape(instruction->shape(), - new_elementwise); - break; case HloOpcode::kTranspose: - new_reshape = HloInstruction::CreateTranspose( - instruction->shape(), new_elementwise, old_reshape->dimensions()); + operands[i] = operands[i]->mutable_operand(0); break; default: - LOG(FATAL) << "Bad opcode"; + LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or " + "transposes."; } - TF_CHECK_OK(computation->ReplaceWithNewInstruction(instruction, - std::move(new_reshape))); - return true; } - return false; + if (HloOpcode::kFusion == instruction->opcode()) { + // Here we already know `instruction` is elementwise, and no operand is + // implicit broadcast as if it were the operands would not be equivalent + // reshapes, so all the fused instructions have the same dimensions. + for (const auto& fused_instruction : instruction->fused_instructions()) { + Shape* shape = fused_instruction->mutable_shape(); + *shape->mutable_dimensions() = new_elementwise_shape.dimensions(); + *shape->mutable_layout() = new_elementwise_shape.layout(); + } + } + HloInstruction* new_elementwise = + computation->AddInstruction(instruction->CloneWithNewOperands( + // `instruction` may change the element type, e.g., from + // operands[0] -> reshape -> convert (`instruction`) + // to + // operands[0] -> convert' -> reshape' + // + // In this case, convert' should have the same element type as + // `convert` and the same dimensions as operands[0]. + ShapeUtil::ChangeElementType(new_elementwise_shape, + instruction->shape().element_type()), + operands)); + + std::unique_ptr new_reshape; + switch (old_reshape->opcode()) { + case HloOpcode::kReshape: + VLOG(3) << "Creating new reshape for new elementwise op: " + << new_elementwise->ToStringNoMetadata(); + new_reshape = + HloInstruction::CreateReshape(instruction->shape(), new_elementwise); + break; + case HloOpcode::kTranspose: + new_reshape = HloInstruction::CreateTranspose( + instruction->shape(), new_elementwise, old_reshape->dimensions()); + break; + default: + LOG(FATAL) << "Bad opcode"; + } + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + instruction, std::move(new_reshape))); + return true; } } // namespace StatusOr ReshapeMover::Run(HloModule* module) { - return std::any_of( - module->computations().begin(), module->computations().end(), - [](const std::unique_ptr& computation) { - std::list postorder = - computation->MakeInstructionPostOrder(); - return std::any_of(postorder.begin(), postorder.end(), - [&computation](HloInstruction* instruction) { - return TrySinkReshapeOrTranspose(computation.get(), - instruction); - }); - }); + bool changed = false; + for (const auto& comp : module->computations()) { + for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { + TF_ASSIGN_OR_RETURN(bool did_change, + TrySinkReshapeOrTranspose(comp.get(), instruction)); + changed |= did_change; + } + } + return changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 1862e2e992ec7ca9fac7444e6b83018fd1f17372..5217e85d4fc12e2adc412644b8f11fd11a58039a 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -20,14 +20,18 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/str_util.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { using ReshapeMoverTest = HloTestBase; @@ -43,14 +47,19 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1)); - auto add = builder.AddInstruction(HloInstruction::CreateBinary( + builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(add, computation->root_instruction()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), op::Reshape(param1))); + EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); - EXPECT_EQ(add, computation->root_instruction()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), op::Reshape(param1))); } TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { @@ -64,14 +73,20 @@ TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1)); - auto add = builder.AddInstruction(HloInstruction::CreateBinary( + builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(add, computation->root_instruction()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), op::Reshape(param1))); + EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); - EXPECT_EQ(add, computation->root_instruction()); + + EXPECT_THAT( + computation->root_instruction(), + op::Add(op::Reshape(op::Parameter()), op::Reshape(op::Parameter()))); } TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { @@ -85,18 +100,20 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1)); - auto add = builder.AddInstruction(HloInstruction::CreateBinary( + builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(add, computation->root_instruction()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), op::Reshape(param1))); EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); - auto new_root = computation->root_instruction(); - EXPECT_NE(add, new_root); - EXPECT_EQ(HloOpcode::kReshape, new_root->opcode()); - EXPECT_EQ(root_shape.DebugString(), new_root->shape().DebugString()); + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Add(param0, param1))); + EXPECT_EQ(root_shape.DebugString(), + computation->root_instruction()->shape().DebugString()); } TEST_F(ReshapeMoverTest, ConstantAndReshapeMoved) { @@ -108,18 +125,21 @@ TEST_F(ReshapeMoverTest, ConstantAndReshapeMoved) { LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); - auto add = builder.AddInstruction(HloInstruction::CreateBinary( + builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(add, computation->root_instruction()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), const1)); + EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); - auto new_root = computation->root_instruction(); - EXPECT_NE(add, new_root); - EXPECT_EQ(HloOpcode::kReshape, new_root->opcode()); - EXPECT_EQ(root_shape.DebugString(), new_root->shape().DebugString()); + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Add(param0, op::Reshape(const1)))); + EXPECT_EQ(root_shape.DebugString(), + computation->root_instruction()->shape().DebugString()); } TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { @@ -141,13 +161,16 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { auto fusion = computation->AddInstruction(HloInstruction::CreateFusion( add->shape(), HloInstruction::FusionKind::kLoop, add)); TF_CHECK_OK(computation->ReplaceInstruction(add, fusion)); - EXPECT_EQ(fusion, computation->root_instruction()); + + EXPECT_THAT(computation->root_instruction(), + op::Fusion(op::Reshape(param0), op::Reshape(param1))); + EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); - auto new_root = computation->root_instruction(); - EXPECT_NE(fusion, new_root); - EXPECT_EQ(HloOpcode::kReshape, new_root->opcode()); - EXPECT_EQ(root_shape.DebugString(), new_root->shape().DebugString()); + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Fusion(param0, param1))); + EXPECT_EQ(root_shape.DebugString(), + computation->root_instruction()->shape().DebugString()); } TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { @@ -166,18 +189,22 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1)); auto reshape_pred = builder.AddInstruction(HloInstruction::CreateReshape(pred_shape, pred)); - auto select = builder.AddInstruction(HloInstruction::CreateTernary( + builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, reshape0, reshape1)); auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(select, computation->root_instruction()); + + EXPECT_THAT( + computation->root_instruction(), + op::Select(op::Reshape(pred), op::Reshape(param0), op::Reshape(param1))); + EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); - auto new_root = computation->root_instruction(); - EXPECT_NE(select, new_root); - EXPECT_EQ(HloOpcode::kReshape, new_root->opcode()); - EXPECT_EQ(root_shape.DebugString(), new_root->shape().DebugString()); + EXPECT_THAT(computation->root_instruction(), + op::Reshape(op::Select(pred, param0, param1))); + EXPECT_EQ(root_shape.DebugString(), + computation->root_instruction()->shape().DebugString()); } TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { @@ -197,10 +224,119 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { auto module = MakeUnique(TestName()); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(select, computation->root_instruction()); + EXPECT_THAT(computation->root_instruction(), + op::Select(op::Reshape(pred), param0, param1)); + EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Select(op::Reshape(pred), param0, param1)); EXPECT_EQ(select, computation->root_instruction()); } +// Tree looks like: +// +// param0 [1,128,1] +// | +// reshape [128,1] constant [128,1024] +// \ / +// multiply w/implicit broadcast [128,1024] +// +// The reshape mover would like to sink the reshape below the multiply. +// +// Previously we would attempt to insert a reshape of the constant to [1,128,1] +// (which is unsound, because it has a different number of elements) as +// preparation for sinking the reshape. +// +// To eliminate the unsoundness, we outlaw reshape sinking when one of the +// operands is implicitly broadcast in the elementwise consumer. +// +// TODO(b/37799338) However, it would be possible in this case to do a more +// in-depth analysis to get reshape movement to occur: +// +// 1. Note that the broadcast dimension (logical dimension 1) in the operands +// would map back to logical dimension 2 in the param0 node. +// 2. Match rank of the constant to the param0 node (by prepending a trivial 1 +// dimension). +// 3. Reshape to [128,1024] at the root. +// +// But this is not currently done. +TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) { + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 128, 1}), "param0")); + auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {128, 1}), param0)); + Array2D a(128, 1024); + auto literal = LiteralUtil::CreateR2FromArray2D(a); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + auto multiply = builder.AddInstruction(HloInstruction::CreateBinary( + constant->shape(), HloOpcode::kMultiply, constant, reshape)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Constant(), op::Reshape(param0))); + + EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Constant(), op::Reshape(param0))); + EXPECT_EQ(multiply, computation->root_instruction()); +} + +// Tree looks like this: +// +// add1 +// | +// +- reshape2 - param2 +// | +// +- reshape3 - add0 +// | +// + reshape0 - param0 +// | +// + reshape1 - param1 +// +// We expect reshape{0,1} AND reshape{2,3} to be lifted. +TEST_F(ReshapeMoverTest, MultiplePasses) { + auto shape1 = ShapeUtil::MakeShape(F32, {1, 8, 1, 7}); + auto shape2 = ShapeUtil::MakeShape(F32, {8, 7, 1}); + auto shape3 = ShapeUtil::MakeShape(F32, {8, 7}); + HloComputation::Builder builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape1, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape1, "param1")); + auto param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, shape2, "param2")); + auto reshape0 = + builder.AddInstruction(HloInstruction::CreateReshape(shape2, param0)); + auto reshape1 = + builder.AddInstruction(HloInstruction::CreateReshape(shape2, param1)); + auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( + shape2, HloOpcode::kAdd, reshape0, reshape1)); + auto reshape2 = + builder.AddInstruction(HloInstruction::CreateReshape(shape3, param2)); + auto reshape3 = + builder.AddInstruction(HloInstruction::CreateReshape(shape3, add0)); + builder.AddInstruction(HloInstruction::CreateBinary(shape3, HloOpcode::kAdd, + reshape2, reshape3)); + + auto module = MakeUnique(TestName()); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT( + computation->root_instruction(), + op::Add(op::Reshape(param2), + op::Reshape(op::Add(op::Reshape(param0), op::Reshape(param1))))); + + EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + + EXPECT_THAT( + computation->root_instruction(), + op::Reshape(op::Add(param2, op::Reshape(op::Add(param0, param1))))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 451bb8c7eadf3e2210788a722d8f75aa3050e30f..8b373ab09623a930a7a15aab0e39d37af0995250 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -112,6 +112,16 @@ ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) { int ServiceOptions::number_of_replicas() const { return number_of_replicas_; } +ServiceOptions& ServiceOptions::set_intra_op_parallelism_threads( + int num_threads) { + intra_op_parallelism_threads_ = num_threads; + return *this; +} + +int ServiceOptions::intra_op_parallelism_threads() const { + return intra_op_parallelism_threads_; +} + /* static */ StatusOr> Service::NewService( perftools::gputools::Platform* platform) { ServiceOptions default_options; @@ -126,9 +136,10 @@ int ServiceOptions::number_of_replicas() const { return number_of_replicas_; } if (platform == nullptr) { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } - TF_ASSIGN_OR_RETURN( - execute_backend, - Backend::CreateBackend(platform, options.number_of_replicas())); + BackendOptions backend_options; + backend_options.set_platform(platform); + backend_options.set_number_of_replicas(options.number_of_replicas()); + TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options)); TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); std::unique_ptr service(new Service( @@ -142,7 +153,10 @@ Service::CreateComputeConstantBackend() { PlatformUtil::GetSupportedPlatforms()); for (auto* platform : platforms) { if (platform->id() == se::host::kHostPlatformId) { - return Backend::CreateBackend(platform, /*replica_count=*/1); + BackendOptions backend_options; + backend_options.set_platform(platform); + backend_options.set_number_of_replicas(1); + return Backend::CreateBackend(backend_options); } } return NotFound("CPU platform not found"); @@ -180,20 +194,24 @@ Service::Service(std::unique_ptr execute_backend, std::unique_ptr compute_constant_backend) : execute_backend_(std::move(execute_backend)), compute_constant_backend_(std::move(compute_constant_backend)) { - LOG(INFO) << Printf( - "XLA service %p executing computations on platform %s. Devices:", this, - execute_backend_->platform()->Name().c_str()); - for (int i = 0; i < execute_backend_->device_count(); ++i) { - if (execute_backend_->device_ordinal_supported(i)) { - se::StreamExecutor* executor = - execute_backend_->stream_executor(i).ValueOrDie(); - const auto& description = executor->GetDeviceDescription(); - LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i, - description.name().c_str(), - description.platform_version().c_str()); - } else { - LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i); + if (execute_backend_) { + LOG(INFO) << Printf( + "XLA service %p executing computations on platform %s. Devices:", this, + execute_backend_->platform()->Name().c_str()); + for (int i = 0; i < execute_backend_->device_count(); ++i) { + if (execute_backend_->device_ordinal_supported(i)) { + se::StreamExecutor* executor = + execute_backend_->stream_executor(i).ValueOrDie(); + const auto& description = executor->GetDeviceDescription(); + LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i, + description.name().c_str(), + description.platform_version().c_str()); + } else { + LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i); + } } + } else { + VLOG(1) << "XLA compile-only service constructed"; } } @@ -286,7 +304,7 @@ StatusOr> Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options) { + const ExecutionOptions& execution_options, Backend* backend) { auto module_config = MakeUnique(program_shape); auto* computation_layout = module_config->mutable_entry_computation_layout(); @@ -326,7 +344,7 @@ StatusOr> Service::CreateModuleConfig( module_config->enable_hlo_profiling(true); } - module_config->set_replica_count(execute_backend_->Replicas().size()); + module_config->set_replica_count(backend->Replicas().size()); module_config->set_fast_math_disabled(execution_options.disable_fast_math()); module_config->set_seed(execution_options.seed()); @@ -367,20 +385,23 @@ StatusOr>> Service::BuildExecutables( VLOG(1) << versioned_handle; } + CHECK_EQ(versioned_handles.size(), module_configs.size()); std::vector> modules; - for (const VersionedComputationHandle& versioned_handle : versioned_handles) { + for (int64 i = 0; i < versioned_handles.size(); ++i) { + const VersionedComputationHandle& versioned_handle = versioned_handles[i]; + const HloModuleConfig& config = *module_configs[i]; TF_ASSIGN_OR_RETURN(auto module, computation_tracker_.BuildHloModule( - versioned_handle, + versioned_handle, &config, /*include_unreachable_instructions=*/true)); modules.push_back(std::move(module)); } Compiler::HloDumper hlo_dumper = MakeHloDumper(); - TF_ASSIGN_OR_RETURN(std::vector> executables, - backend->compiler()->Compile( - std::move(modules), std::move(module_configs), - hlo_dumper, std::move(executors))); + TF_ASSIGN_OR_RETURN( + std::vector> executables, + backend->compiler()->Compile(std::move(modules), hlo_dumper, + std::move(executors))); if (!other_directory_path.empty()) { for (size_t i = 0; i < versioned_handles.size(); ++i) { @@ -423,7 +444,7 @@ StatusOr> Service::BuildExecutable( TF_ASSIGN_OR_RETURN( std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, + computation_tracker_.BuildHloModule(versioned_handle, module_config.get(), /*include_unreachable_instructions=*/ !executable_for_compute_constant)); @@ -435,8 +456,7 @@ StatusOr> Service::BuildExecutable( TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - backend->compiler()->Compile(std::move(module), std::move(module_config), - hlo_dumper, executor)); + backend->compiler()->Compile(std::move(module), hlo_dumper, executor)); if (!other_directory_path.empty()) { executable->set_session_module(std::move(session_module)); @@ -474,7 +494,7 @@ StatusOr> Service::BuildAndCacheExecutable( std::unique_ptr executable_unique_ptr, BuildExecutable(versioned_handle, std::move(module_config), /*executable_for_compute_constant=*/false, arguments, - execute_backend_.get(), executor)); + backend, executor)); if (profile != nullptr) { uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -569,21 +589,21 @@ StatusOr Service::ExecuteAndRegisterResult( options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); - run_options.emplace_back(options, backend->StreamBorrower()); + run_options.emplace_back(options, backend->StreamBorrower(), + backend->inter_op_thread_pool()); } perftools::gputools::DeviceMemoryBase result; if (backend->Replicas().size() == 1) { TF_ASSIGN_OR_RETURN( - result, - ExecuteOnStreamWrapper>( - executable, &run_options[0], profile, execute_backend_.get(), - [&arguments](Executable* executable, - const ServiceExecutableRunOptions* run_options, - HloExecutionProfile* hlo_execution_profile) { - return executable->ExecuteOnStream(run_options, arguments, - hlo_execution_profile); - })); + result, ExecuteOnStreamWrapper>( + executable, &run_options[0], profile, backend, + [&arguments](Executable* executable, + const ServiceExecutableRunOptions* run_options, + HloExecutionProfile* hlo_execution_profile) { + return executable->ExecuteOnStream(run_options, arguments, + hlo_execution_profile); + })); } else { std::vector< tensorflow::gtl::ArraySlice> @@ -666,7 +686,8 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // the program and the argument allocations. TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, CreateModuleConfig(*program_shape, arg_allocations, - request.execution_options())); + request.execution_options(), + execute_backend_.get())); VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -751,9 +772,10 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), execute_backend_->default_device_ordinal())); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - arg->execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arg_allocations, + arg->execution_options(), execute_backend_.get())); VLOG(3) << "Execute created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -818,9 +840,10 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), execute_backend_->default_device_ordinal())); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - arg->execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arg_allocations, + arg->execution_options(), execute_backend_.get())); VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -1141,7 +1164,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, } TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(program_shape, {}, execution_options)); + CreateModuleConfig(program_shape, {}, execution_options, + compute_constant_backend_.get())); TF_ASSIGN_OR_RETURN( std::shared_ptr executable, @@ -1202,7 +1226,8 @@ tensorflow::Status Service::GetComputationStats( user_computation->GetVersionedHandle(); TF_ASSIGN_OR_RETURN(std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle)); + computation_tracker_.BuildHloModule(versioned_handle, + /*config=*/nullptr)); MakeHloDumper()(*module, "computation statistics subject"); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 9600f6989a40c9180d00ccabbeb29cb37a28900a..05a955137f8dfe7aa085058c5a6673ce8f2f77f1 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -63,9 +63,14 @@ class ServiceOptions { ServiceOptions& set_number_of_replicas(int number_of_replicas); int number_of_replicas() const; + // Sets the thread pool size for parallel execution of an individual operator. + ServiceOptions& set_intra_op_parallelism_threads(int num_threads); + int intra_op_parallelism_threads() const; + private: perftools::gputools::Platform* platform_ = nullptr; int number_of_replicas_ = -1; + int intra_op_parallelism_threads_ = -1; }; // The XLA service object, which is the same across all @@ -265,11 +270,11 @@ class Service : public ServiceInterface { tensorflow::gtl::ArraySlice arguments, const Backend* backend, int device_ordinal); - // Create a Hlo module config foe the given program shape and arguments. + // Create a Hlo module config for the given program shape and arguments. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options); + const ExecutionOptions& execution_options, Backend* backend); // Builds an Executable for the given parameters. If // executable_for_compute_constant is true, then the executable is intended to diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 0d4b214f5f3624971ae68e23f0f4fdba846f9178..017e5ef09ed2f52b862821e9408540d188a1edf5 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -30,10 +30,12 @@ class ServiceExecutableRunOptions { using StreamBorrower = std::function::SmartPtr>(int)>; - explicit ServiceExecutableRunOptions(ExecutableRunOptions run_options, - StreamBorrower borrow_stream = nullptr) + explicit ServiceExecutableRunOptions( + ExecutableRunOptions run_options, StreamBorrower borrow_stream = nullptr, + tensorflow::thread::ThreadPool* xla_intra_op_thread_pool = nullptr) : run_options_(std::move(run_options)), - borrow_stream_(std::move(borrow_stream)) {} + borrow_stream_(std::move(borrow_stream)), + xla_intra_op_thread_pool_(xla_intra_op_thread_pool) {} // Returns reference or pointer to `ExecutableRunOptions` member. const ExecutableRunOptions& run_options() const { return run_options_; } @@ -53,9 +55,15 @@ class ServiceExecutableRunOptions { : Status(tensorflow::error::UNIMPLEMENTED, "No stream cache"); } + // Returns reference to thread pool for execution of XLA ops on CPU backend. + tensorflow::thread::ThreadPool* xla_intra_op_thread_pool() const { + return xla_intra_op_thread_pool_; + } + private: ExecutableRunOptions run_options_; StreamBorrower borrow_stream_; + tensorflow::thread::ThreadPool* xla_intra_op_thread_pool_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index c05cf8c37d84f4120df344db939551e26a0355af..b2ef8ed486b5ab4643cb0e26fa6c18e1f3894a4b 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -244,8 +244,11 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { return InvalidArgument( - "cannot concatenate arrays with different ranks: %lld vs %lld", - ShapeUtil::Rank(*arg_shape), ShapeUtil::Rank(*shape)); + "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld " + "(%s)", + ShapeUtil::Rank(*arg_shape), + ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), + ShapeUtil::HumanString(*shape).c_str()); } if (arg_shape->element_type() != shape->element_type()) { return InvalidArgument( @@ -309,6 +312,10 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "the rank of the operand and the padding configuration do not match."); } + if (operand_shape.element_type() != padding_value_shape.element_type()) { + return InvalidArgument( + "the element types of the operands to pad do not match"); + } std::vector dimensions(ShapeUtil::Rank(operand_shape)); for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { dimensions[i] = operand_shape.dimensions(i) + @@ -338,7 +345,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // Check if both element types are the same. if (lhs.element_type() != rhs.element_type()) { - return fail("element types mismatch"); + return fail("element types do not match"); } if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 || @@ -633,26 +640,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_DCHECK_OK(ShapeUtil::ValidateShape(ehs)); switch (operation) { case TRIOP_CLAMP: - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation")); - if (((ShapeUtil::Compatible(lhs, rhs) || ShapeUtil::Rank(lhs) == 0) && - (ShapeUtil::Compatible(rhs, ehs) || ShapeUtil::Rank(ehs) == 0))) { - return rhs; - } - if (ShapeUtil::Rank(rhs) == 0) { - if (ShapeUtil::Compatible(lhs, ehs)) { - return lhs; - } - return ShapeUtil::Rank(ehs) == 0 ? lhs : ehs; - } - return Unimplemented("not yet implemented: %s, %s %s", - lhs.ShortDebugString().c_str(), - ehs.ShortDebugString().c_str(), - rhs.ShortDebugString().c_str()); + return InferClampShape(lhs, rhs, ehs); case TRIOP_SELECT: return InferSelectShape(lhs, rhs, ehs); case TRIOP_UPDATE: @@ -1332,6 +1320,41 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return ShapeUtil::PermuteDimensions(InversePermutation(dimensions), operand); } +// TODO(b/36794510): Make broadcast semantics more consistent, by supporting +// "degenerate" cases, as with binary elementwise ops. +/* static */ StatusOr ShapeInference::InferClampShape( + const Shape& min, const Shape& operand, const Shape& max) { + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max")); + if (!ShapeUtil::SameElementType(min, operand) || + !ShapeUtil::SameElementType(max, operand)) { + return InvalidArgument("clamp op with different operand types: %s, %s, %s", + ShapeUtil::HumanString(min).c_str(), + ShapeUtil::HumanString(operand).c_str(), + ShapeUtil::HumanString(max).c_str()); + } + if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) && + (ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) { + return operand; + } + if (ShapeUtil::IsScalar(operand)) { + if (ShapeUtil::Compatible(min, max)) { + return min; + } else if (ShapeUtil::IsScalar(min)) { + return max; + } else if (ShapeUtil::IsScalar(max)) { + return min; + } + } + return Unimplemented( + "not yet implemented: %s, %s %s", min.ShortDebugString().c_str(), + max.ShortDebugString().c_str(), operand.ShortDebugString().c_str()); +} + +// TODO(b/36794510): Make broadcast semantics more consistent, by supporting +// "degenerate" cases, as with binary elementwise ops, as well as scalar +// broadcast from all operands, not just the predicate. /* static */ StatusOr ShapeInference::InferSelectShape( const Shape& pred, const Shape& on_true, const Shape& on_false) { if (!ShapeUtil::Compatible(on_true, on_false)) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index ced2f4d0017e26b8f6d54b78f240dedecdbc79f3..c2223423e9223ba8ad995212415f219eea48e2a6 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -190,6 +190,10 @@ class ShapeInference { BinaryOperation operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); + // Helper for inferring the shape of Clamp ops. + static StatusOr InferClampShape(const Shape& min, const Shape& operand, + const Shape& max); + // Helper for inferring the shape of Select ops. static StatusOr InferSelectShape(const Shape& pred, const Shape& on_true, diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 5a1ae6b0024c65c9a451f1500146dc81408b8684..7cff042a48db436b3d165e8eaedc5a3f3c76b15e 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -20,12 +20,16 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace { +using ::testing::ContainsRegex; +using ::testing::HasSubstr; + class ShapeInferenceTest : public ::testing::Test { protected: // Some handy scalar shapes. @@ -128,23 +132,21 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { auto inferred_status_error1 = ShapeInference::InferTernaryOpShape( TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH( - inferred_status_error1.status().error_message(), - testing::ContainsRegex("operands to select must be the same shape")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("operands to select must be the same shape")); auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("pred operand must have PRED")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("pred operand must have PRED")); auto inferred_status_error3 = ShapeInference::InferTernaryOpShape( TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH( - inferred_status_error3.status().error_message(), - testing::ContainsRegex("with non-scalar predicate with dimensionality")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("with non-scalar predicate with dimensionality")); // Tuples have a TUPLE element type and cannot be the pred of a select. auto inferred_status_error4 = ShapeInference::InferTernaryOpShape( @@ -152,9 +154,101 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { ShapeUtil::MakeTupleShape({f32_, f32_}), ShapeUtil::MakeTupleShape({f32_, f32_})); ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_MATCH( - inferred_status_error4.status().error_message(), - testing::ContainsRegex("pred operand must have PRED element type")); + ASSERT_THAT(inferred_status_error4.status().error_message(), + HasSubstr("pred operand must have PRED element type")); +} + +TEST_F(ShapeInferenceTest, ClampAllMatrix) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, + matrix_64_48_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampAllScalar) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampMinScalar) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampMaxScalar) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampOperandScalar) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampMinMatrix) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampMaxMatrix) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampOperandMatrix) { + auto inferred_status = ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_); + ASSERT_IS_OK(inferred_status.status()); + ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); +} + +TEST_F(ShapeInferenceTest, ClampBadShapes) { + // Type mismatch + ASSERT_FALSE(ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_) + .ok()); + // Dimension mismatch + ASSERT_FALSE( + ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, + vector_64_, vector_32_, vector_32_) + .ok()); + ASSERT_FALSE( + ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, + vector_32_, vector_64_, vector_32_) + .ok()); + ASSERT_FALSE( + ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, + vector_32_, vector_32_, vector_64_) + .ok()); + // Dimension mismatch, where one operand is a scalar + ASSERT_FALSE(ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape( + TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_) + .ok()); } TEST_F(ShapeInferenceTest, VariadicOpTuplify) { @@ -205,8 +299,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) { operand_shape_, select_program_shape_, window_, source_shape_fail, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_MATCH(inferred_status_fail.status().error_message(), - testing::ContainsRegex("source shape does not match")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + HasSubstr("source shape does not match")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { @@ -216,9 +310,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_MATCH( - inferred_status_fail.status().error_message(), - testing::ContainsRegex("select function must take 2 parameters")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + HasSubstr("select function must take 2 parameters")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { @@ -228,8 +321,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_MATCH(inferred_status_fail.status().error_message(), - testing::ContainsRegex("select function must have rank-0 PRED")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + HasSubstr("select function must have rank-0 PRED")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { @@ -239,8 +332,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_MATCH(inferred_status_fail.status().error_message(), - testing::ContainsRegex("select function's first parameter")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + HasSubstr("select function's first parameter")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { @@ -250,8 +343,8 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { operand_shape_, select_program_shape_fail, window_, source_shape_, init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); - ASSERT_MATCH(inferred_status_fail.status().error_message(), - testing::ContainsRegex("select function's second parameter")); + ASSERT_THAT(inferred_status_fail.status().error_message(), + HasSubstr("select function's second parameter")); } TEST_F(ShapeInferenceTest, Convolve) { @@ -405,8 +498,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { auto inferred_status = ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); ASSERT_FALSE(inferred_status.ok()); - ASSERT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("each dimension exactly once")); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("each dimension exactly once")); } TEST_F(ShapeInferenceTest, MapThatChangesElementType) { @@ -443,43 +536,42 @@ TEST_F(ShapeInferenceTest, Map) { auto no_args_error = ShapeInference::InferMapShape( {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); ASSERT_FALSE(no_args_error.ok()); - ASSERT_MATCH(no_args_error.status().error_message(), - testing::ContainsRegex("expects at least one argument")); + ASSERT_THAT(no_args_error.status().error_message(), + HasSubstr("expects at least one argument")); auto args_diff_shapes_error = ShapeInference::InferMapShape( {&vector_32_, &vector_64_}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); ASSERT_FALSE(args_diff_shapes_error.ok()); - ASSERT_MATCH( - args_diff_shapes_error.status().error_message(), - testing::ContainsRegex("requires all operands to have the same shape")); + ASSERT_THAT(args_diff_shapes_error.status().error_message(), + HasSubstr("requires all operands to have the same shape")); auto arity_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_)); ASSERT_FALSE(arity_error.ok()); - ASSERT_MATCH(arity_error.status().error_message(), - testing::ContainsRegex("function arity must match")); + ASSERT_THAT(arity_error.status().error_message(), + HasSubstr("function arity must match")); auto output_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_)); ASSERT_FALSE(output_shape_error.ok()); - ASSERT_MATCH(output_shape_error.status().error_message(), - testing::ContainsRegex("result has to be a scalar")); + ASSERT_THAT(output_shape_error.status().error_message(), + HasSubstr("result has to be a scalar")); auto param_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_)); ASSERT_FALSE(param_shape_error.ok()); - ASSERT_MATCH(param_shape_error.status().error_message(), - testing::ContainsRegex("parameter has to be a scalar")); + ASSERT_THAT(param_shape_error.status().error_message(), + HasSubstr("parameter has to be a scalar")); auto param_element_type_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_, s32_}, f32_)); ASSERT_FALSE(param_element_type_error.ok()); - ASSERT_MATCH(param_element_type_error.status().error_message(), - testing::ContainsRegex("parameter type has to match argument")); + ASSERT_THAT(param_element_type_error.status().error_message(), + HasSubstr("parameter type has to match argument")); Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_); @@ -490,26 +582,26 @@ TEST_F(ShapeInferenceTest, Map) { auto inferred_status_error1 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("arity must match number of arguments")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("arity must match number of arguments")); auto inferred_status_error2 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_)); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("has to be a scalar")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("has to be a scalar")); auto inferred_status_error3 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_)); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex("has to be a scalar")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("has to be a scalar")); auto inferred_status_error5 = ShapeInference::InferMapShape( {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_)); ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_MATCH(inferred_status_error5.status().error_message(), - testing::ContainsRegex("parameter type has to match argument")); + ASSERT_THAT(inferred_status_error5.status().error_message(), + HasSubstr("parameter type has to match argument")); } TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) { @@ -563,8 +655,8 @@ TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) { ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4}, to_apply); EXPECT_FALSE(inferred_status.ok()); - EXPECT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("out-of-bounds dimension")); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("out-of-bounds dimension")); } TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { @@ -573,8 +665,8 @@ TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) { ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); - EXPECT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("take 2 parameters")); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("take 2 parameters")); } TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { @@ -583,8 +675,8 @@ TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) { ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{0}, to_apply); EXPECT_FALSE(inferred_status.ok()); - EXPECT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("first parameter shape differs")); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("first parameter shape differs")); } TEST_F(ShapeInferenceTest, InferSliceShapeRank2) { @@ -726,8 +818,8 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { auto inferred_status = ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {}); ASSERT_FALSE(inferred_status.ok()); - ASSERT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("dot only supports rank")); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("dot only supports rank")); } // 3D 2D: error @@ -735,8 +827,8 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { auto inferred_status = ShapeInference::InferBinaryOpShape( BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {}); ASSERT_FALSE(inferred_status.ok()); - ASSERT_MATCH(inferred_status.status().error_message(), - testing::ContainsRegex("dot only supports rank")); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("dot only supports rank")); } // vector vector -> scalar @@ -848,46 +940,43 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { auto inferred_status_error1 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {}); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("automatic")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("automatic")); // broadcast_dimension out of bounds for tensor's rank auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {3}); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH( - inferred_status_error2.status().error_message(), - testing::ContainsRegex("broadcast dimension number .* too large")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + ContainsRegex("broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension auto inferred_status_error3 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {0}); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex("broadcast dimension 0 mismatch")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("broadcast dimension 0 mismatch")); // broadcast_dimensions list too long auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_MATCH( - inferred_status_error4.status().error_message(), - testing::ContainsRegex("size of broadcast_dimensions has to match")); + ASSERT_THAT(inferred_status_error4.status().error_message(), + HasSubstr("size of broadcast_dimensions has to match")); // there's a dimension above the rank of the tensor auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0}); ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_MATCH( - inferred_status_error5.status().error_message(), - testing::ContainsRegex("broadcast dimension number .* too large")); + ASSERT_THAT(inferred_status_error5.status().error_message(), + ContainsRegex("broadcast dimension number .* too large")); // broadcasting dimensions don't match in this order auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); ASSERT_FALSE(inferred_status_error6.ok()); - ASSERT_MATCH(inferred_status_error6.status().error_message(), - testing::ContainsRegex("broadcast dimension 0 mismatch")); + ASSERT_THAT(inferred_status_error6.status().error_message(), + HasSubstr("broadcast dimension 0 mismatch")); // The following two tests make sure that broadcasting dimensions are listed // in a proper (strictly increasing) order, even if the lower-rank array @@ -895,14 +984,14 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { auto inferred_status_error7 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); ASSERT_FALSE(inferred_status_error7.ok()); - ASSERT_MATCH(inferred_status_error7.status().error_message(), - testing::ContainsRegex("broadcast dimensions order is wrong")); + ASSERT_THAT(inferred_status_error7.status().error_message(), + HasSubstr("broadcast dimensions order is wrong")); auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); ASSERT_FALSE(inferred_status_error8.ok()); - ASSERT_MATCH(inferred_status_error8.status().error_message(), - testing::ContainsRegex("broadcast dimensions order is wrong")); + ASSERT_THAT(inferred_status_error8.status().error_message(), + HasSubstr("broadcast dimensions order is wrong")); } // Tests for the while instruction with proper shapes. @@ -927,30 +1016,30 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) { auto inferred_status_error1 = ShapeInference::InferWhileShape(bad_shape_1, body, result_shape); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("condition must take 1 arguments")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("condition must take 1 arguments")); auto bad_shape_2 = ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape); auto inferred_status_error2 = ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("body must take 1 arguments")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("body must take 1 arguments")); auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_); auto inferred_status_error3 = ShapeInference::InferWhileShape(bad_shape_3, body, result_shape); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex("condition must return a boolean")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("condition must return a boolean")); auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_); auto inferred_status_error4 = ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape); ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_MATCH(inferred_status_error4.status().error_message(), - testing::ContainsRegex("parameter of condition and body")); + ASSERT_THAT(inferred_status_error4.status().error_message(), + HasSubstr("parameter of condition and body")); } // Tests for the concatenate instruction with proper shapes. @@ -980,49 +1069,44 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { auto inferred_status_error1 = ShapeInference::InferConcatOpShape({}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH( - inferred_status_error1.status().error_message(), - testing::ContainsRegex("Concatenate expects at least one argument")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("Concatenate expects at least one argument")); auto inferred_status_error2 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex( - "dimension to concatenate along out of bounds: -1")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("dimension to concatenate along out of bounds: -1")); auto inferred_status_error3 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex( - "dimension to concatenate along out of bounds: 1")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("dimension to concatenate along out of bounds: 1")); Shape tuple = ShapeUtil::MakeTupleShape({vector_32_}); auto inferred_status_error4 = ShapeInference::InferConcatOpShape( {&vector_32_, &tuple}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error4.ok()); - ASSERT_MATCH( + ASSERT_THAT( inferred_status_error4.status().error_message(), - testing::ContainsRegex( - "Expected non-tuple argument for operand of concatenation.")); + HasSubstr("Expected non-tuple argument for operand of concatenation.")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); auto inferred_status_error5 = ShapeInference::InferConcatOpShape( {&vector_32_, &vector_s32}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_MATCH(inferred_status_error5.status().error_message(), - testing::ContainsRegex( - "cannot concatenate arrays with different element types")); + ASSERT_THAT( + inferred_status_error5.status().error_message(), + HasSubstr("cannot concatenate arrays with different element types")); auto inferred_status_error6 = ShapeInference::InferConcatOpShape( {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error6.ok()); - ASSERT_MATCH( - inferred_status_error6.status().error_message(), - testing::ContainsRegex("cannot concatenate arrays that differ in " - "dimensions other than the one being " - "concatenated")); + ASSERT_THAT(inferred_status_error6.status().error_message(), + HasSubstr("cannot concatenate arrays that differ in " + "dimensions other than the one being " + "concatenated")); } TEST_F(ShapeInferenceTest, Pad) { @@ -1063,27 +1147,27 @@ TEST_F(ShapeInferenceTest, ReverseInvalidDimension) { auto inferred_status_error0 = ShapeInference::InferReverseShape(input_shape, {0, 2}); ASSERT_FALSE(inferred_status_error0.ok()); - ASSERT_MATCH(inferred_status_error0.status().error_message(), - testing::ContainsRegex("out-of-bounds")); + ASSERT_THAT(inferred_status_error0.status().error_message(), + HasSubstr("out-of-bounds")); auto inferred_status_error1 = ShapeInference::InferReverseShape(input_shape, {0, -1}); ASSERT_FALSE(inferred_status_error1.ok()); - ASSERT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("out-of-bounds")); + ASSERT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("out-of-bounds")); auto inferred_status_error2 = ShapeInference::InferReverseShape(input_shape, {0, 0}); ASSERT_FALSE(inferred_status_error2.ok()); - ASSERT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("duplicated")); + ASSERT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("duplicated")); Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape}); auto inferred_status_error3 = ShapeInference::InferReverseShape(tuple_shape, {0}); ASSERT_FALSE(inferred_status_error3.ok()); - ASSERT_MATCH(inferred_status_error3.status().error_message(), - testing::ContainsRegex("Expected non-tuple argument")); + ASSERT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("Expected non-tuple argument")); } TEST_F(ShapeInferenceTest, Call) { @@ -1103,20 +1187,20 @@ TEST_F(ShapeInferenceTest, Call) { auto inferred_status_error0 = ShapeInference::InferCallShape( {}, ShapeUtil::MakeProgramShape({f32_}, f32_)); EXPECT_FALSE(inferred_status_error0.ok()); - EXPECT_MATCH(inferred_status_error0.status().error_message(), - testing::ContainsRegex("arity must match")); + EXPECT_THAT(inferred_status_error0.status().error_message(), + HasSubstr("arity must match")); auto inferred_status_error1 = ShapeInference::InferCallShape( {&f32_}, ShapeUtil::MakeProgramShape({}, f32_)); EXPECT_FALSE(inferred_status_error1.ok()); - EXPECT_MATCH(inferred_status_error1.status().error_message(), - testing::ContainsRegex("arity must match")); + EXPECT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("arity must match")); auto inferred_status_error2 = ShapeInference::InferCallShape( {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_)); EXPECT_FALSE(inferred_status_error2.ok()); - EXPECT_MATCH(inferred_status_error2.status().error_message(), - testing::ContainsRegex("parameter must match argument")); + EXPECT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("parameter must match argument")); } TEST_F(ShapeInferenceTest, Transpose) { diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index b052bb814693c2e9364c94154ca223fe98526622..83e893a14a6d95e3741af57d34eadef4e5c088d9 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -99,13 +99,6 @@ class TransferManager { // region for a host-to-device transfer. virtual int64 GetByteSizeRequirement(const Shape& shape) = 0; - // Returns whether tuple elements are distinct buffers (in which case each of - // the elements of a tuple should be deallocated, in addition to the tuple's - // buffer itself). - // - // TODO(b/36256956) Ideally tuple elements could always be distinct buffers. - virtual bool TupleElementsAreDistinctBuffers() const { return true; } - // Transfer a memory block of the given size from the device source into the // 'destination' buffer. // diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 07e0ce89f6ad2ba194832096de2399ab618422a4..a0c88c6bbc23972bb6a0f3729e51ee0eaee72bc7 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -21,7 +21,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" @@ -30,43 +32,55 @@ namespace xla { namespace { -bool IsOperandFoldableToDot(const HloInstruction& hlo) { - return hlo.IsRank2Transpose() && - hlo.user_count() == 1; // The dot is its only user. -} - -bool CanFoldOperandsIntoDot( +TransposeFolding::OperandIndices CanFoldOperandsIntoDot( const HloInstruction& dot, - const TransposeFolding::IsTransposableGemmFn& is_transposable_gemm) { + const TransposeFolding::TransposableGemmOperandsFn& + transposable_gemm_operands) { if (HloOpcode::kDot != dot.opcode()) { - return false; + return {}; } - if (!is_transposable_gemm(dot)) { - return false; + TransposeFolding::OperandIndices operand_set; + for (int64 i = 0; i < dot.operand_count(); ++i) { + auto& operand = *dot.operand(i); + if (operand.IsRank2Transpose() && operand.user_count() == 1) { + operand_set.push_back(i); + } } - const HloInstruction* lhs = dot.operand(0); - const HloInstruction* rhs = dot.operand(1); - bool lhs_foldable = IsOperandFoldableToDot(*lhs); - bool rhs_foldable = IsOperandFoldableToDot(*rhs); - if (!lhs_foldable && !rhs_foldable) { - return false; + return transposable_gemm_operands(dot, operand_set); +} + +TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution( + const HloInstruction& convolution, + const TransposeFolding::TransposableConvOperandsFn& + transposable_conv_operands) { + if (HloOpcode::kConvolution != convolution.opcode()) { + return {}; } - return true; + + // We only support folding the RHS. + const int64 kRhsOperandIndex = 1; + auto& operand = *convolution.operand(kRhsOperandIndex); + if (operand.opcode() == HloOpcode::kTranspose && operand.user_count() == 1) { + return transposable_conv_operands(convolution, {kRhsOperandIndex}); + } + + return {}; } +using InstructionOperandsPair = + std::pair; + // Folds the operands of `dot` that are foldable transposes. `computation` is -// the parent HLO computation of `dot`. `module` is the parent HloModule of -// `computation`. +// the parent HLO computation of `dot`. // // Returns whether the module is changed. -bool FoldTransposeIntoDot(HloInstruction* dot, HloComputation* computation) { +bool FoldTransposeIntoDot(InstructionOperandsPair pair) { + auto* dot = pair.first; std::vector instructions_to_fuse(1, dot); - for (HloInstruction* operand : dot->operands()) { - if (IsOperandFoldableToDot(*operand)) { - instructions_to_fuse.push_back(operand); - } + for (const int64 operand_index : pair.second) { + instructions_to_fuse.push_back(dot->mutable_operand(operand_index)); } // Early-exit if no operands are foldable. @@ -74,33 +88,100 @@ bool FoldTransposeIntoDot(HloInstruction* dot, HloComputation* computation) { return false; } - computation->CreateFusionInstruction( + dot->parent()->CreateFusionInstruction( instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot); return true; } +// Folds the operands of `convolution` that are foldable transposes. +// `computation` is the parent HLO computation of `convolution`. +// +// Returns whether the module is changed. +bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { + auto& convolution = *pair.first; + + // We only support fusing the RHS transpose into convolution. + // + // ConvolutionDimensionNumbers doesn't make enough of a distinction between + // the output and the activations. + // + // TODO(b/37125184): Support transposing the LHS too. + if (pair.second.size() != 1 || pair.second.front() != 1) { + return false; + } + + const ConvolutionDimensionNumbers& dnums = + convolution.convolution_dimension_numbers(); + HloInstruction& transpose = *convolution.mutable_operand(1); + CHECK_EQ(transpose.opcode(), HloOpcode::kTranspose); + const auto& transpose_dimensions = transpose.dimensions(); + HloInstruction& transpose_operand = *transpose.mutable_operand(0); + + // Everything remains the same except for the kernel dimension numbers. We + // need to apply the transpose permutation to the original shape to figure out + // what the new logical dimensions are. + ConvolutionDimensionNumbers new_dnums = dnums; + new_dnums.set_kernel_input_feature_dimension( + transpose_dimensions[dnums.kernel_input_feature_dimension()]); + new_dnums.set_kernel_output_feature_dimension( + transpose_dimensions[dnums.kernel_output_feature_dimension()]); + for (auto& kernel_spatial_dimension : + *new_dnums.mutable_kernel_spatial_dimensions()) { + kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension]; + } + + auto new_conv = HloInstruction::CreateConvolve( + convolution.shape(), convolution.mutable_operand(0), &transpose_operand, + convolution.window(), new_dnums); + TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( + &convolution, std::move(new_conv))); + + return true; +} + } // namespace -TransposeFolding::TransposeFolding(IsTransposableGemmFn is_transposable_gemm) - : is_transposable_gemm_(std::move(is_transposable_gemm)) {} +TransposeFolding::TransposeFolding( + TransposableGemmOperandsFn transposable_gemm_operands, + TransposableConvOperandsFn transposable_conv_operands) + : transposable_gemm_operands_(std::move(transposable_gemm_operands)), + transposable_conv_operands_(std::move(transposable_conv_operands)) {} StatusOr TransposeFolding::Run(HloModule* module) { // Modifying the graph while traversing is dangerous, so we find all folding // opportunities before actually folding them. - HloComputation* entry_computation = module->entry_computation(); - - std::vector foldable_dots; - auto visit_fn = [this, &foldable_dots](HloInstruction* instruction) { - if (CanFoldOperandsIntoDot(*instruction, is_transposable_gemm_)) { - foldable_dots.emplace_back(instruction); + std::vector> foldable_dots; + std::vector> foldable_convolutions; + auto visit_fn = [this, &foldable_dots, + &foldable_convolutions](HloInstruction* instruction) { + { + OperandIndices operand_indices = + CanFoldOperandsIntoDot(*instruction, transposable_gemm_operands_); + if (!operand_indices.empty()) { + foldable_dots.emplace_back(instruction, operand_indices); + } + } + { + OperandIndices operand_indices = CanFoldOperandsIntoConvolution( + *instruction, transposable_conv_operands_); + if (!operand_indices.empty()) { + foldable_convolutions.emplace_back( + std::make_pair(instruction, operand_indices)); + } } return tensorflow::Status::OK(); }; - TF_RETURN_IF_ERROR(entry_computation->root_instruction()->Accept(visit_fn)); + + for (auto& comp : module->computations()) { + TF_RETURN_IF_ERROR(comp->Accept(visit_fn)); + } bool changed = false; - for (HloInstruction* dot : foldable_dots) { - changed |= FoldTransposeIntoDot(dot, entry_computation); + for (InstructionOperandsPair& pair : foldable_dots) { + changed |= FoldTransposeIntoDot(pair); + } + for (InstructionOperandsPair& pair : foldable_convolutions) { + changed |= FoldTransposeIntoConvolution(pair); } return changed; } diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h index d857c04ed8d0c0d9d6c005db0f29ab0c5abd3bb2..71e8446452f072c22bb730cbda65a1743a95cd4c 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.h +++ b/tensorflow/compiler/xla/service/transpose_folding.h @@ -25,16 +25,37 @@ namespace xla { // operator is implemented by a GEMM kernel that can transpose its inputs. class TransposeFolding : public HloPassInterface { public: - // IsTransposableGemmFn should return true iff the instruction argument is - // implemented as a GEMM kernel that supports transposing its arguments. - typedef std::function IsTransposableGemmFn; - explicit TransposeFolding(IsTransposableGemmFn is_transposable_gemm); + using OperandIndices = std::vector; + + // Returns the set of foldable operands for a given HLO and some candidate + // operands. + using FoldableOperands = std::function; + using TransposableGemmOperandsFn = FoldableOperands; + using TransposableConvOperandsFn = FoldableOperands; + + // Helper function to explicitly not fold transposes. + static OperandIndices NeverFoldTranspose(const HloInstruction&, + const OperandIndices&) { + return {}; + } + // transposable_gemm_operands returns the set of operands it wants to fold if + // the instruction argument is implemented as a GEMM kernel that supports + // transposing its arguments. + // + // transposable_conv_operands returns the set of operands it wants to fold if + // the instruction argument is implemented as a convolution that supports + // transposing its arguments. + explicit TransposeFolding( + TransposableGemmOperandsFn transposable_gemm_operands, + TransposableConvOperandsFn transposable_conv_operands); tensorflow::StringPiece name() const override { return "transpose-folding"; } StatusOr Run(HloModule* module) override; private: - IsTransposableGemmFn is_transposable_gemm_; + TransposableGemmOperandsFn transposable_gemm_operands_; + TransposableConvOperandsFn transposable_conv_operands_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 09f932e29e61a24b178e7ced0d2643aa484bea02..c72d127ea86e4e9daf99dff4335c538c081f0605 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -16,16 +16,19 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transpose_folding.h" #include -#include +#include #include +#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" @@ -35,12 +38,20 @@ namespace xla { class TransposeFoldingTest : public ::testing::Test { protected: void FoldTranspose(HloModule* module) { - TransposeFolding transpose_folding(gpu::ImplementedAsGemm); + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + [](const HloInstruction& convolution, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }); EXPECT_IS_OK(transpose_folding.Run(module).status()); } }; -TEST_F(TransposeFoldingTest, FoldTranspose) { +TEST_F(TransposeFoldingTest, FoldDotTranspose) { auto builder = HloComputation::Builder("entry_computation"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), @@ -61,7 +72,7 @@ TEST_F(TransposeFoldingTest, FoldTranspose) { FoldTranspose(&module); // Instructions after folding: x, y, and the fusion. - std::set instruction_set; + std::unordered_set instruction_set; for (auto& instruction : entry_computation->instructions()) { instruction_set.insert(instruction.get()); } @@ -77,7 +88,7 @@ TEST_F(TransposeFoldingTest, FoldTranspose) { EXPECT_EQ(4, fusion->fused_instructions().size()); } -TEST_F(TransposeFoldingTest, FoldTransposeConstant) { +TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { auto builder = HloComputation::Builder("entry_computation"); // 2x1 HloInstruction* const0 = builder.AddInstruction( @@ -115,7 +126,7 @@ TEST_F(TransposeFoldingTest, FoldTransposeConstant) { entry_computation->root_instruction()->fused_instructions().size()); } -TEST_F(TransposeFoldingTest, FuseWithConstantOperands) { +TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { auto builder = HloComputation::Builder("entry"); // (1.0 + 2.0) * (2.0 - 3.0) HloInstruction* const1 = builder.AddInstruction( @@ -139,11 +150,219 @@ TEST_F(TransposeFoldingTest, FuseWithConstantOperands) { EXPECT_EQ(call, entry_computation->root_instruction()); HloComputation* callee_computation = call->to_apply(); // The arguments to the call should be const1, const2, and const3. - EXPECT_MATCH(call->operands(), testing::UnorderedMatcher( - const1, const2, const3)); + EXPECT_THAT(call->operands(), + ::testing::UnorderedElementsAre(const1, const2, const3)); // The callee should contain 3 parameters and 3 binary operators. EXPECT_EQ(6, callee_computation->instructions().size()); } +TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), + /*name=*/"y")); + HloInstruction* transpose_y = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, + /*lhs=*/x, /*rhs=*/transpose_y)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(dot)); + + HloInstruction* call = module.OutlineExpressionFromComputation( + {transpose_y, dot}, "outlined", entry_computation); + + FoldTranspose(&module); + + // Instructions after folding: x, y, and the fusion. + std::unordered_set instruction_set; + for (auto& instruction : entry_computation->instructions()) { + instruction_set.insert(instruction.get()); + } + CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(call)) + << "call is not in entry_computation."; + CHECK(instruction_set.empty()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* fusion = + call->called_computations().front()->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); + + // The fusion instruction should contain two parameters, one transpose and + // one dot. + EXPECT_EQ(4, fusion->fused_instructions().size()); +} + +// Test that a two dimension swap of the kernel gets folded into convolution. +TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}), + /*name=*/"y")); + HloInstruction* transpose_y = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 0, 2, 3})); + auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + Window window; + for (int i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + dim->set_stride(1); + dim->set_size( + transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); + } + StatusOr conv_shape = ShapeInference::InferConvolveShape( + x->shape(), transpose_y->shape(), window, dnums); + EXPECT_IS_OK(conv_shape); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(conv)); + FoldTranspose(&module); + + // Instructions after folding: x, y, and the convolution. + std::unordered_set instruction_set; + for (auto& instruction : entry_computation->instructions()) { + instruction_set.insert(instruction.get()); + } + CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + CHECK_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.kernel_input_feature_dimension(), + new_conv->convolution_dimension_numbers() + .kernel_output_feature_dimension()); + EXPECT_EQ(dnums.kernel_output_feature_dimension(), + new_conv->convolution_dimension_numbers() + .kernel_input_feature_dimension()); +} + +// Test that a complex transpose of the kernel gets folded into convolution. +TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {1, 2, 1, 3}), + /*name=*/"y")); + HloInstruction* transpose_y = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 3, 0, 2})); + auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + Window window; + for (int i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + dim->set_stride(1); + dim->set_size( + transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); + } + StatusOr conv_shape = ShapeInference::InferConvolveShape( + x->shape(), transpose_y->shape(), window, dnums); + EXPECT_IS_OK(conv_shape); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(conv)); + FoldTranspose(&module); + + // Instructions after folding: x, y, and the convolution. + std::unordered_set instruction_set; + for (auto& instruction : entry_computation->instructions()) { + instruction_set.insert(instruction.get()); + } + CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + CHECK_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.kernel_input_feature_dimension(), + new_conv->convolution_dimension_numbers() + .kernel_output_feature_dimension()); + EXPECT_EQ(dnums.kernel_spatial_dimensions(1), + new_conv->convolution_dimension_numbers() + .kernel_input_feature_dimension()); + EXPECT_EQ( + dnums.kernel_output_feature_dimension(), + new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(0)); + EXPECT_EQ( + dnums.kernel_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1)); +} + +// Test that a transpose of the activations does not get folded into +// convolution. +TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), + /*name=*/"y")); + HloInstruction* transpose_x = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 2, 3})); + auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + Window window; + for (int i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + dim->set_stride(1); + dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); + } + StatusOr conv_shape = ShapeInference::InferConvolveShape( + transpose_x->shape(), y->shape(), window, dnums); + EXPECT_IS_OK(conv_shape); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(conv)); + FoldTranspose(&module); + + // Instructions after folding: transpose_x, y, and the convolution. + std::unordered_set instruction_set; + for (auto& instruction : entry_computation->instructions()) { + instruction_set.insert(instruction.get()); + } + CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(transpose_x)) + << "transpose_x is not in entry_computation."; + CHECK_EQ(1, instruction_set.erase(conv)) + << "transpose_x is not in entry_computation."; + CHECK_EQ(0, instruction_set.size()) + << "entry_computation should contain exactly 4 instructions."; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 98c51b48f9022c5f2d1e23b59a6ce775f3a48e0b..554adaf0e32f7cb896e07a59d5235ff84a11bb92 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -131,10 +131,9 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index, } /* static */ StatusOr> -TuplePointsToAnalysis::Run(const HloModule* module, - const bool include_loop_fusion_instructions) { +TuplePointsToAnalysis::Run(const HloModule* module) { std::unique_ptr analysis( - new TuplePointsToAnalysis(module, include_loop_fusion_instructions)); + new TuplePointsToAnalysis(module)); TF_RETURN_IF_ERROR(analysis->Analyze()); return std::move(analysis); } @@ -145,17 +144,14 @@ Status TuplePointsToAnalysis::Analyze() { TF_RETURN_IF_ERROR(computation->Accept(this)); TF_RETURN_IF_ERROR( PopulateDefinedBuffersAndAliases(computation->instructions())); - if (include_loop_fusion_instructions_) { - // Run points-to analysis on loop fusion instructions in 'computation'. - for (auto& instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kFusion || - instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) { - continue; - } - TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); - TF_RETURN_IF_ERROR(PopulateDefinedBuffersAndAliases( - instruction->fused_instructions())); + // Run points-to analysis on fusion instructions in 'computation'. + for (auto& instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kFusion) { + continue; } + TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + TF_RETURN_IF_ERROR( + PopulateDefinedBuffersAndAliases(instruction->fused_instructions())); } } @@ -482,9 +478,7 @@ string TuplePointsToAnalysis::ToString() const { for (const HloInstruction* instruction : computation->MakeInstructionPostOrder()) { InstructionToString(instruction, &output); - if (include_loop_fusion_instructions_ && - instruction->opcode() == HloOpcode::kFusion && - instruction->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (instruction->opcode() == HloOpcode::kFusion) { for (auto& fused : instruction->fused_instructions()) { InstructionToString(fused.get(), &output); } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index a384529171a7371c848ca8949d22cb6717d83a78..85a71b56ce5e9fb1a3441c302e18bd1fa7b68864 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -148,12 +148,9 @@ std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias); // the potential sources of each buffer in each instruction's output. class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { public: - // Runs points-to analysis on 'module'. If 'include_loop_fusion_instructions' - // is true, includes fused instructions from each loop fusion instruction - // in 'module' in the points-to analysis. + // Runs points-to analysis on 'module'. static StatusOr> Run( - const HloModule* module, - const bool include_loop_fusion_instructions = false); + const HloModule* module); // Return the points-to set of an instruction. This describes the potential // sources of each buffer in the instruction's output. @@ -218,10 +215,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { string ToString() const; private: - explicit TuplePointsToAnalysis(const HloModule* module, - const bool include_loop_fusion_instructions) - : module_(module), - include_loop_fusion_instructions_(include_loop_fusion_instructions) {} + explicit TuplePointsToAnalysis(const HloModule* module) : module_(module) {} // Perform the analysis. Should be called immediately after constructing the // object and before calling GetPointsToSet. @@ -261,9 +255,6 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { // The module this analysis is performed on. const HloModule* module_; - // Whether to run points-to analysis on loop fusion instructions in 'module_'. - const bool include_loop_fusion_instructions_; - // A map containing a PointsToSet for every HLO instruction. tensorflow::gtl::FlatMap> points_to_; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 4a4a6e64ffae265bc143cfd7adb9f7d53b2b0359..87e1b058b79c0dc327cc1ad63a8cffa97c190df4 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -19,18 +19,25 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { +using ::testing::UnorderedElementsAreArray; +using ::testing::UnorderedElementsAre; + class TuplePointsToAnalysisTest : public HloTestBase { protected: // Builds a module with the given entry computation and runs points to @@ -45,11 +52,10 @@ class TuplePointsToAnalysisTest : public HloTestBase { module_->AddEntryComputation(std::move(computation)); } - void RunAnalysis(const bool include_loop_fusion_instructions = false) { + void RunAnalysis() { CHECK_NOTNULL(module_.get()); - points_to_analysis_ = TuplePointsToAnalysis::Run( - module_.get(), include_loop_fusion_instructions) - .ConsumeValueOrDie(); + points_to_analysis_ = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); } // Returns the LogicalBuffer defined at the given instruction and @@ -70,7 +76,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { const std::vector& points_to_set, tensorflow::gtl::ArraySlice buffers) { std::vector vec(buffers.begin(), buffers.end()); - EXPECT_MATCH(points_to_set, testing::UnorderedElementsAre(vec)); + EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec)); } // Checks that the given points-to set contains exactly (unordered) the @@ -107,20 +113,14 @@ class TuplePointsToAnalysisTest : public HloTestBase { for (auto& pair : expected) { expected_aliases.push_back(BufferAlias(*buffer, pair.first, pair.second)); } - EXPECT_MATCH(points_to_analysis_->GetBufferAliases(*buffer), - testing::UnorderedElementsAre(expected_aliases)); + EXPECT_THAT(points_to_analysis_->GetBufferAliases(*buffer), + UnorderedElementsAreArray(expected_aliases)); } std::unique_ptr module_; std::unique_ptr points_to_analysis_; }; -// Expect the given std::set as A contains exactly the given -// HloInstruction*s as __VA_ARGS__. -#define EXPECT_ISET(A, ...) \ - EXPECT_MATCH(testing::SetToVec(A), \ - testing::UnorderedMatcher(__VA_ARGS__)) - TEST_F(TuplePointsToAnalysisTest, SimpleTuple) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( @@ -146,8 +146,8 @@ TEST_F(TuplePointsToAnalysisTest, SimpleTuple) { EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(tuple).size()); EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous()); - EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), - tuple); + EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), + UnorderedElementsAre(tuple)); ExpectHasTopLevelBuffers( points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), @@ -205,9 +205,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTuple) { ExpectHasTopLevelBuffers( points_to_analysis_->GetPointsToSet(inner_tuple).element({}), {inner_tuple}); - EXPECT_ISET( + EXPECT_THAT( points_to_analysis_->GetPointsToSet(inner_tuple).tuple_sources({}), - inner_tuple); + UnorderedElementsAre(inner_tuple)); EXPECT_EQ(5, points_to_analysis_->GetPointsToSet(tuple).size()); EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous()); @@ -215,10 +215,10 @@ TEST_F(TuplePointsToAnalysisTest, NestedTuple) { points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), {constant1, constant2, constant3, inner_tuple, tuple}); - EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), - tuple); - EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({0}), - inner_tuple); + EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), + UnorderedElementsAre(tuple)); + EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({0}), + UnorderedElementsAre(inner_tuple)); EXPECT_TRUE( points_to_analysis_->GetPointsToSet(tuple).tuple_sources({1}).empty()); @@ -262,7 +262,8 @@ TEST_F(TuplePointsToAnalysisTest, GetTupleElement) { {constant1, constant2, inner_tuple}); ExpectHasTopLevelBuffers(points_to_set.element({}), {inner_tuple}); - EXPECT_ISET(points_to_set.tuple_sources({}), inner_tuple); + EXPECT_THAT(points_to_set.tuple_sources({}), + UnorderedElementsAre(inner_tuple)); } TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) { @@ -460,8 +461,10 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { ExpectHasTopLevelBuffers(points_to_set.element({0, 1}), {constant2}); // Verify tuple sources. - EXPECT_ISET(points_to_set.tuple_sources({}), tuple1, tuple2); - EXPECT_ISET(points_to_set.tuple_sources({0}), inner_tuple1, inner_tuple2); + EXPECT_THAT(points_to_set.tuple_sources({}), + UnorderedElementsAre(tuple1, tuple2)); + EXPECT_THAT(points_to_set.tuple_sources({0}), + UnorderedElementsAre(inner_tuple1, inner_tuple2)); EXPECT_EQ(0, points_to_set.tuple_sources({0, 0}).size()); EXPECT_EQ(0, points_to_set.tuple_sources({0, 1}).size()); } @@ -489,8 +492,8 @@ TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) { EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(tuple).size()); EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous()); - EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), - tuple); + EXPECT_THAT(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}), + UnorderedElementsAre(tuple)); ExpectHasTopLevelBuffers( points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(), @@ -603,9 +606,9 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { .ValueOrDie()); // Get computation root instruction (should be a kFusion). auto* fusion = module_->entry_computation()->root_instruction(); - EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); + EXPECT_THAT(fusion, op::Fusion(tuple_param0)); // Run points-to analysis (should include fused instructions from 'fusion'). - RunAnalysis(/*include_loop_fusion_instructions=*/true); + RunAnalysis(); // Check points-to set of fusion parameter associated with 'tuple_param0'. auto* fusion_param = GetFusionParameterForOperand(fusion, tuple_param0); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index a77788e0b63b984328c0ea52ebbb94cb8583e6e3..e9fcc9fa6666bb2e3c24252e1c0f5e8d763a5d48 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -1510,6 +1510,7 @@ void ConstantVisitor(const SessionComputation& session_computation, is_constant); // TODO(b/32495713): We aren't checking the condition and body // computations themselves. + *is_constant = false; break; } @@ -1927,6 +1928,12 @@ HloInstruction* ComputationLowerer::Visit( const OperationRequest& request = session_computation_.requests().at(handle.handle()); + auto add_instruction = [&](std::unique_ptr instruction) { + HloInstruction* hlo_instruction = + hlo_builder_.AddInstruction(std::move(instruction)); + hlo_instruction->set_metadata(request.request().metadata()); + return hlo_instruction; + }; HloInstruction* hlo_instruction; switch (request.request().op_case()) { case OpRequest::kRngRequest: { @@ -1935,7 +1942,7 @@ HloInstruction* ComputationLowerer::Visit( for (const ComputationDataHandle& param : rng_request.parameter()) { parameters.push_back(Visit(param, visited)); } - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateRng( + hlo_instruction = add_instruction(HloInstruction::CreateRng( request.output_shape(), rng_request.distribution(), parameters)); break; } @@ -1943,9 +1950,8 @@ HloInstruction* ComputationLowerer::Visit( case OpRequest::kConstantRequest: { const ConstantRequest& constant_request = request.request().constant_request(); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CloneToUnique(constant_request.literal()))); + hlo_instruction = add_instruction(HloInstruction::CreateConstant( + LiteralUtil::CloneToUnique(constant_request.literal()))); break; } @@ -1954,17 +1960,15 @@ HloInstruction* ComputationLowerer::Visit( request.request().get_tuple_element_request(); HloInstruction* operand = Visit(get_tuple_element_request.operand(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateGetTupleElement( - request.output_shape(), operand, - get_tuple_element_request.index())); + hlo_instruction = add_instruction(HloInstruction::CreateGetTupleElement( + request.output_shape(), operand, get_tuple_element_request.index())); break; } case OpRequest::kSliceRequest: { const SliceRequest& slice_request = request.request().slice_request(); HloInstruction* operand = Visit(slice_request.operand(), visited); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateSlice( + hlo_instruction = add_instruction(HloInstruction::CreateSlice( request.output_shape(), operand, AsInt64Slice(slice_request.start_indices()), AsInt64Slice(slice_request.limit_indices()))); @@ -1978,10 +1982,9 @@ HloInstruction* ComputationLowerer::Visit( HloInstruction* start_indices = Visit(dynamic_slice_request.start_indices(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateDynamicSlice( - request.output_shape(), operand, start_indices, - AsInt64Slice(dynamic_slice_request.slice_sizes()))); + hlo_instruction = add_instruction(HloInstruction::CreateDynamicSlice( + request.output_shape(), operand, start_indices, + AsInt64Slice(dynamic_slice_request.slice_sizes()))); break; } @@ -1995,7 +1998,7 @@ HloInstruction* ComputationLowerer::Visit( HloInstruction* start_indices = Visit(dynamic_update_slice_request.start_indices(), visited); hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + add_instruction(HloInstruction::CreateDynamicUpdateSlice( request.output_shape(), operand, update, start_indices)); break; } @@ -2009,9 +2012,8 @@ HloInstruction* ComputationLowerer::Visit( HloInstruction* operand = Visit(handle, visited); operands.push_back(operand); } - hlo_instruction = hlo_builder_.AddInstruction( - HloInstruction::CreateConcatenate(request.output_shape(), operands, - concatenate_request.dimension())); + hlo_instruction = add_instruction(HloInstruction::CreateConcatenate( + request.output_shape(), operands, concatenate_request.dimension())); break; } @@ -2020,10 +2022,9 @@ HloInstruction* ComputationLowerer::Visit( request.request().convolve_request(); HloInstruction* lhs = Visit(convolve_request.lhs(), visited); HloInstruction* rhs = Visit(convolve_request.rhs(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateConvolve( - request.output_shape(), lhs, rhs, convolve_request.window(), - convolve_request.dimension_numbers())); + hlo_instruction = add_instruction(HloInstruction::CreateConvolve( + request.output_shape(), lhs, rhs, convolve_request.window(), + convolve_request.dimension_numbers())); break; } @@ -2032,17 +2033,15 @@ HloInstruction* ComputationLowerer::Visit( request.request().cross_replica_sum_request(); HloInstruction* operand = Visit(cross_replica_sum_request.operand(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateCrossReplicaSum( - request.output_shape(), operand)); + hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( + request.output_shape(), operand)); break; } case OpRequest::kInfeedRequest: { const InfeedRequest& infeed_request = request.request().infeed_request(); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateInfeed( - request.output_shape(), infeed_request.config())); + hlo_instruction = add_instruction(HloInstruction::CreateInfeed( + request.output_shape(), infeed_request.config())); break; } @@ -2050,9 +2049,8 @@ HloInstruction* ComputationLowerer::Visit( const OutfeedRequest& outfeed_request = request.request().outfeed_request(); HloInstruction* operand = Visit(outfeed_request.operand(), visited); - hlo_instruction = hlo_builder_.AddInstruction( - HloInstruction::CreateOutfeed(outfeed_request.shape(), operand, - outfeed_request.outfeed_config())); + hlo_instruction = add_instruction(HloInstruction::CreateOutfeed( + outfeed_request.shape(), operand, outfeed_request.outfeed_config())); break; } @@ -2068,7 +2066,7 @@ HloInstruction* ComputationLowerer::Visit( request.embedded_computation_versions(0); HloComputation* map_computation = ResolveComputation(map_request.to_apply(), map_version); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateMap( + hlo_instruction = add_instruction(HloInstruction::CreateMap( request.output_shape(), operands, map_computation)); break; } @@ -2082,10 +2080,9 @@ HloInstruction* ComputationLowerer::Visit( request.embedded_computation_versions(0); HloComputation* reduce_computation = ResolveComputation(reduce_request.to_apply(), reduce_version); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateReduce( - request.output_shape(), operand, init_value, - AsInt64Slice(reduce_request.dimensions()), reduce_computation)); + hlo_instruction = add_instruction(HloInstruction::CreateReduce( + request.output_shape(), operand, init_value, + AsInt64Slice(reduce_request.dimensions()), reduce_computation)); break; } @@ -2100,10 +2097,9 @@ HloInstruction* ComputationLowerer::Visit( request.embedded_computation_versions(0); HloComputation* reduce_window_computation = ResolveComputation( reduce_window_request.to_apply(), reduce_window_version); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateReduceWindow( - request.output_shape(), operand, init_value, - reduce_window_request.window(), reduce_window_computation)); + hlo_instruction = add_instruction(HloInstruction::CreateReduceWindow( + request.output_shape(), operand, init_value, + reduce_window_request.window(), reduce_window_computation)); break; } @@ -2125,11 +2121,10 @@ HloInstruction* ComputationLowerer::Visit( select_and_scatter_request.select(), select_version); HloComputation* scatter_computation = ResolveComputation( select_and_scatter_request.scatter(), scatter_version); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateSelectAndScatter( - request.output_shape(), operand, select_computation, - select_and_scatter_request.window(), source, init_value, - scatter_computation)); + hlo_instruction = add_instruction(HloInstruction::CreateSelectAndScatter( + request.output_shape(), operand, select_computation, + select_and_scatter_request.window(), source, init_value, + scatter_computation)); break; } @@ -2150,9 +2145,8 @@ HloInstruction* ComputationLowerer::Visit( ShapeUtil::Rank(request.output_shape()) - ShapeUtil::Rank(operand->shape())); } - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( - request.output_shape(), operand, broadcast_dimensions)); + hlo_instruction = add_instruction(HloInstruction::CreateBroadcast( + request.output_shape(), operand, broadcast_dimensions)); break; } @@ -2164,14 +2158,13 @@ HloInstruction* ComputationLowerer::Visit( if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) { transposed = operand; } else { - transposed = - hlo_builder_.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions(InversePermutation(AsInt64Slice( - reshape_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(reshape_request.dimensions()))); + transposed = add_instruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions( + InversePermutation(AsInt64Slice(reshape_request.dimensions())), + operand->shape()), + operand, AsInt64Slice(reshape_request.dimensions()))); } - hlo_instruction = hlo_builder_.AddInstruction( + hlo_instruction = add_instruction( HloInstruction::CreateReshape(request.output_shape(), transposed)); break; } @@ -2180,12 +2173,11 @@ HloInstruction* ComputationLowerer::Visit( const TransposeRequest& transpose_request = request.request().transpose_request(); HloInstruction* operand = Visit(transpose_request.operand(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions(InversePermutation(AsInt64Slice( - transpose_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(transpose_request.dimensions()))); + hlo_instruction = add_instruction(HloInstruction::CreateTranspose( + ShapeUtil::PermuteDimensions( + InversePermutation(AsInt64Slice(transpose_request.dimensions())), + operand->shape()), + operand, AsInt64Slice(transpose_request.dimensions()))); break; } @@ -2193,10 +2185,9 @@ HloInstruction* ComputationLowerer::Visit( const ReverseRequest& reverse_request = request.request().reverse_request(); HloInstruction* operand = Visit(reverse_request.operand(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateReverse( - request.output_shape(), operand, - AsInt64Slice(reverse_request.dimensions()))); + hlo_instruction = add_instruction(HloInstruction::CreateReverse( + request.output_shape(), operand, + AsInt64Slice(reverse_request.dimensions()))); break; } @@ -2205,7 +2196,7 @@ HloInstruction* ComputationLowerer::Visit( HloInstruction* operand = Visit(pad_request.operand(), visited); HloInstruction* padding_value = Visit(pad_request.padding_value(), visited); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreatePad( + hlo_instruction = add_instruction(HloInstruction::CreatePad( request.output_shape(), operand, padding_value, pad_request.padding_config())); break; @@ -2213,7 +2204,7 @@ HloInstruction* ComputationLowerer::Visit( case OpRequest::kRecvRequest: { const RecvRequest& recv_request = request.request().recv_request(); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateRecv( + hlo_instruction = add_instruction(HloInstruction::CreateRecv( request.output_shape(), recv_request.channel_handle().handle())); break; } @@ -2221,10 +2212,9 @@ HloInstruction* ComputationLowerer::Visit( case OpRequest::kParameterRequest: { const ParameterRequest& parameter_request = request.request().parameter_request(); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateParameter( - parameter_request.parameter(), request.output_shape(), - parameter_request.name())); + hlo_instruction = add_instruction(HloInstruction::CreateParameter( + parameter_request.parameter(), request.output_shape(), + parameter_request.name())); break; } @@ -2232,7 +2222,7 @@ HloInstruction* ComputationLowerer::Visit( const ConvertRequest& convert_request = request.request().convert_request(); HloInstruction* operand = Visit(convert_request.operand(), visited); - hlo_instruction = hlo_builder_.AddInstruction( + hlo_instruction = add_instruction( HloInstruction::CreateConvert(request.output_shape(), operand)); break; } @@ -2249,7 +2239,7 @@ HloInstruction* ComputationLowerer::Visit( HloComputation* body = ResolveComputation(while_request.body(), body_version); HloInstruction* init = Visit(while_request.init(), visited); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateWhile( + hlo_instruction = add_instruction(HloInstruction::CreateWhile( request.output_shape(), condition, body, init)); break; } @@ -2261,9 +2251,8 @@ HloInstruction* ComputationLowerer::Visit( HloInstruction* rhs = Visit(ternary_op_request.rhs(), visited); HloInstruction* ehs = Visit(ternary_op_request.ehs(), visited); auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateTernary( - request.output_shape(), hlo_opcode, lhs, rhs, ehs)); + hlo_instruction = add_instruction(HloInstruction::CreateTernary( + request.output_shape(), hlo_opcode, lhs, rhs, ehs)); break; } @@ -2278,9 +2267,8 @@ HloInstruction* ComputationLowerer::Visit( } auto hlo_opcode = VariadicOperationToHloOpcode(variadic_op_request.varop()); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateVariadic( - request.output_shape(), hlo_opcode, operands)); + hlo_instruction = add_instruction(HloInstruction::CreateVariadic( + request.output_shape(), hlo_opcode, operands)); break; } @@ -2295,7 +2283,7 @@ HloInstruction* ComputationLowerer::Visit( request.embedded_computation_versions(0); HloComputation* call_computation = ResolveComputation(call_request.to_apply(), call_version); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateCall( + hlo_instruction = add_instruction(HloInstruction::CreateCall( request.output_shape(), operands, call_computation)); break; } @@ -2307,9 +2295,8 @@ HloInstruction* ComputationLowerer::Visit( for (const ComputationDataHandle& operand : cc_request.operands()) { operands.push_back(Visit(operand, visited)); } - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateCustomCall( - cc_request.shape(), operands, cc_request.call_target_name())); + hlo_instruction = add_instruction(HloInstruction::CreateCustomCall( + cc_request.shape(), operands, cc_request.call_target_name())); break; } @@ -2318,7 +2305,7 @@ HloInstruction* ComputationLowerer::Visit( request.request().unary_op_request(); HloInstruction* operand = Visit(unary_op_request.operand(), visited); auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop()); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateUnary( + hlo_instruction = add_instruction(HloInstruction::CreateUnary( request.output_shape(), hlo_opcode, operand)); break; } @@ -2346,23 +2333,22 @@ HloInstruction* ComputationLowerer::Visit( // identical to the HLO broadcast semantics so the broadcast_dimensions // field can just be passed to the instruction builder. HloInstruction* broadcasted_operand = - hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( + add_instruction(HloInstruction::CreateBroadcast( broadcast_shape, operand_to_broadcast, AsInt64Slice(binary_op_request.broadcast_dimensions()))); lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; } - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateBinary( - request.output_shape(), hlo_opcode, lhs, rhs)); + hlo_instruction = add_instruction(HloInstruction::CreateBinary( + request.output_shape(), hlo_opcode, lhs, rhs)); break; } case OpRequest::kTraceRequest: { const TraceRequest& trace_request = request.request().trace_request(); HloInstruction* operand = Visit(trace_request.operand(), visited); - hlo_instruction = hlo_builder_.AddInstruction( + hlo_instruction = add_instruction( HloInstruction::CreateTrace(trace_request.tag(), operand)); operand->set_tracing(hlo_instruction); break; @@ -2371,7 +2357,7 @@ HloInstruction* ComputationLowerer::Visit( case OpRequest::kSendRequest: { const SendRequest& send_request = request.request().send_request(); HloInstruction* operand = Visit(send_request.operand(), visited); - hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateSend( + hlo_instruction = add_instruction(HloInstruction::CreateSend( operand, send_request.channel_handle().handle())); break; } @@ -2382,7 +2368,6 @@ HloInstruction* ComputationLowerer::Visit( default: LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); } - hlo_instruction->set_metadata(request.request().metadata()); (*visited)[handle.handle()] = hlo_instruction; return hlo_instruction; } diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index e67254328ad6973ee63a83d45cd3c2618e39ff56..cf04cfde5003d70e26ce0a1543039c18c19282c9 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -17,12 +17,16 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { @@ -55,6 +59,9 @@ TEST_F(UserComputationTest, SimpleComputation) { param_request.set_name("param0"); TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle param_handle, computation.AddParameterInstruction(param_request)); + OpMetadata metadata; + metadata.set_op_name("meta"); + TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata)); OutfeedRequest outfeed_request; *outfeed_request.mutable_operand() = constant_handle; @@ -89,8 +96,7 @@ TEST_F(UserComputationTest, SimpleComputation) { EXPECT_EQ(3, hlo_computation->instruction_count()); // The root of the instruction should be the parameter instruction (not the // outfeed). - EXPECT_EQ(HloOpcode::kParameter, - hlo_computation->root_instruction()->opcode()); + EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); } { @@ -114,8 +120,7 @@ TEST_F(UserComputationTest, SimpleComputation) { computation.BuildHloComputation( version_at_param.version, hlo_resolver)); EXPECT_EQ(2, hlo_computation->instruction_count()); - EXPECT_EQ(HloOpcode::kParameter, - hlo_computation->root_instruction()->opcode()); + EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); } { // Test the computation at the latest version, but lowered with @@ -132,8 +137,9 @@ TEST_F(UserComputationTest, SimpleComputation) { EXPECT_EQ(1, hlo_computation->instruction_count()); // The root of the instruction should be the parameter instruction (not the // outfeed). - EXPECT_EQ(HloOpcode::kParameter, - hlo_computation->root_instruction()->opcode()); + EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); + EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(), + "meta"); } } diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 2159386152b34e4f9b59ca14faa756e37551d724..c8851d2ca512450b4022e0f70d55399323b2fa08 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -21,7 +21,10 @@ limitations under the License. namespace xla { -// Defines the interface for an XLA service. +// Defines the interface for an XLA service on the client side. This service +// helps abstract around the actual implementation of a service - the service +// can be local (running in the same process), or remote - in which case an RPC +// stub is used as the implementation. class ServiceInterface { public: ServiceInterface() {} diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 6963a68d10d527acebde65f30f9caf87608950cb..aa4341d18e1e6ef0dba5a4bcc057d9ef43d9bfb0 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -33,22 +33,65 @@ limitations under the License. namespace xla { +namespace internal { + +// Internal representation of each node in a ShapeTree. +template +struct ShapeTreeNode { + // Data corresponding to this node. + T data; + + // Children of this node. + std::vector> children; + + explicit ShapeTreeNode(const T& data) : data(data) {} + + ShapeTreeNode(const ShapeTreeNode& other) + : data(other.data), children(other.children.size()) { + for (size_t i = 0; i < children.size(); ++i) { + children[i] = MakeUnique(*other.children[i]); + } + } + + ShapeTreeNode& operator=(const ShapeTreeNode& other) { + if (this != &other) { + data = other.data; + children.resize(other.children.size()); + for (size_t i = 0; i < children.size(); ++i) { + children[i] = MakeUnique(*other.children[i]); + } + } + return *this; + } +}; + +} // namespace internal + // A ShapeTree is a recursive data structure which mirrors the structure of a -// XLA shape and holds a value of type T for each array in the shape. For -// array shapes, a ShapeTree trivially holds a single value of type T. For tuple -// shapes which can be an arbitrary tree with arrays at the leaves, a ShapeTree -// is an identically structured tree with data elements of type T at the leaves. +// XLA shape and holds a value of type T for each subshape (i.e. tuple or array) +// in the shape. For array shapes, a ShapeTree trivially holds a single value of +// type T. +// +// For tuple shapes which can be an arbitrary tree with arrays at the leaves, a +// ShapeTree is an identically structured tree with data elements of type T at +// every node. I.e. the root is a tuple by definition, all interior nodes are +// also tuples, and all leaves are arrays. // // Like the Shape data structure, this is a tree and tuple elements cannot be -// duplicated. That is, every distinct element position in the Shape has a -// unique T object. +// duplicated. That is, every distinct ShapeIndex in the Shape has a unique T +// object. template class ShapeTree { public: - explicit ShapeTree(const Shape& shape); + // Default constructor creates a tree with a nil shape (i.e. an empty tuple). + ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} + // Create ShapeTree with the given shape, and default T values for all nodes. + explicit ShapeTree(const Shape& shape) : ShapeTree(shape, T()) {} + // Create ShapeTree with the given shape, and init_value for all nodes. ShapeTree(const Shape& shape, const T& init_value); - ShapeTree(const ShapeTree& other); - ShapeTree& operator=(const ShapeTree& other); + + ShapeTree(const ShapeTree& other) = default; + ShapeTree& operator=(const ShapeTree& other) = default; // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). @@ -56,12 +99,12 @@ class ShapeTree { T* mutable_element(const ShapeIndex& index); // Return the shape represented with this ShapeTree. - const Shape& shape() const { return *shape_; } + const Shape& shape() const { return shape_; } // Returns true if the node at the given index is a leaf node (an array // shape). bool IsLeaf(const ShapeIndex& index) const { - return Lookup(index).elements_.empty(); + return Lookup(index)->children.empty(); } // Recursively traverses the shape and calls the given function at each @@ -76,183 +119,125 @@ class ShapeTree { // // If any call to the given function returns a non-OK status, then traversal // is aborted and the status value is returned. - using VisitorFunction = std::function; - tensorflow::Status ForEachElement(VisitorFunction func) const; + Status ForEachElement(const VisitorFunction& func) const; - using MutableVisitorFunction = std::function; - tensorflow::Status ForEachMutableElement(MutableVisitorFunction func); + Status ForEachMutableElement(const MutableVisitorFunction& func); private: - // Private default constructor for non-root nodes of the tree. - ShapeTree() = default; + using Node = internal::ShapeTreeNode; + + // Initialize node->children based on 'shape'. All children are assigned the + // the given 'init_value'. + void InitChildren(const Shape& shape, const T& init_value, Node* node); // Helpers for traversing the shape via ForEachElement. The helpers // recursively traverse the subtree rooted at "index" (defined as in // ShapeUtil::GetSubshape). - static tensorflow::Status ForEachHelperMutable(ShapeIndex* index, - ShapeTree* shape_tree, - MutableVisitorFunction func); - static tensorflow::Status ForEachHelper(ShapeIndex* index, - const ShapeTree& shape_tree, - VisitorFunction func); - - // Copy all the data elements (of type T) from "other" into "this". "this" - // must have the same tree structure as "other" prior to calling this method. - void CopyDataElements(const ShapeTree& other); - - // Recursive helper for constructing a subtree beneath "this" node. - void BuildTree(const Shape& shape); + static Status ForEachHelper(const VisitorFunction& func, const Node& node, + ShapeIndex* index); + static Status ForEachMutableHelper(const MutableVisitorFunction& func, + Node* node, ShapeIndex* index); // Return the tree node at the given index. - ShapeTree& Lookup(const ShapeIndex& index); - const ShapeTree& Lookup(const ShapeIndex& index) const; - - // The data corresponding to the array at this node. - T data_; + Node* Lookup(const ShapeIndex& index); + const Node* Lookup(const ShapeIndex& index) const; - // The XLA shape mirrored in this ShapeTree. Only the root of the - // ShapeTree has this member set. - std::unique_ptr shape_; + // The root node, which contains all other nodes. + Node root_; - // The children of this node in the tree. - std::vector> elements_; + // The XLA shape mirrored in this ShapeTree. + Shape shape_; }; template -void ShapeTree::BuildTree(const Shape& shape) { +void ShapeTree::InitChildren(const Shape& shape, const T& init_value, + Node* node) { if (ShapeUtil::IsTuple(shape)) { for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - elements_.emplace_back(new ShapeTree()); - elements_.back()->BuildTree(shape.tuple_shapes(i)); + node->children.emplace_back(new Node(init_value)); + InitChildren(shape.tuple_shapes(i), init_value, + node->children.back().get()); } } } -template -ShapeTree::ShapeTree(const Shape& shape) : shape_(MakeUnique(shape)) { - // The shape_ field is just used to hold the structure of the shape. It should - // not be relied upon to store layout information. - LayoutUtil::ClearLayout(shape_.get()); - BuildTree(*shape_); -} - template ShapeTree::ShapeTree(const Shape& shape, const T& init_value) - : shape_(MakeUnique(shape)) { - LayoutUtil::ClearLayout(shape_.get()); - BuildTree(*shape_); - TF_CHECK_OK(ForEachMutableElement( - [&init_value](const ShapeIndex& /*index*/, bool /*is_leaf*/, bool* data) { - *data = init_value; - return tensorflow::Status::OK(); - })); -} - -template -ShapeTree::ShapeTree(const ShapeTree& other) - : shape_(MakeUnique(other.shape())) { - LayoutUtil::ClearLayout(shape_.get()); - BuildTree(*shape_); - CopyDataElements(other); -} - -template -ShapeTree& ShapeTree::operator=(const ShapeTree& other) { - if (this == &other) { - return *this; - } - elements_.clear(); - shape_ = MakeUnique(other.shape()); - LayoutUtil::ClearLayout(shape_.get()); - - BuildTree(*shape_); - CopyDataElements(other); - return *this; -} - -template -void ShapeTree::CopyDataElements(const ShapeTree& other) { - CHECK(ShapeUtil::Compatible(shape(), other.shape())); - TF_CHECK_OK(ForEachMutableElement( - [&other](const ShapeIndex& index, bool /*is_leaf*/, T* data) { - *data = other.element(index); - return tensorflow::Status::OK(); - })); + : root_(init_value), shape_(shape) { + // The shape_ field is just used to hold the structure of the shape. + // It should not be relied upon to store layout information. + LayoutUtil::ClearLayout(&shape_); + InitChildren(shape_, init_value, &root_); } template const T& ShapeTree::element(const ShapeIndex& index) const { - return Lookup(index).data_; + return Lookup(index)->data; } template T* ShapeTree::mutable_element(const ShapeIndex& index) { - return &Lookup(index).data_; + return &Lookup(index)->data; } template -ShapeTree& ShapeTree::Lookup(const ShapeIndex& index) { - ShapeTree* node = this; - for (auto& i : index) { +internal::ShapeTreeNode* ShapeTree::Lookup(const ShapeIndex& index) { + Node* node = &root_; + for (const int64 i : index) { CHECK_GE(i, 0); - CHECK_LT(i, node->elements_.size()); - node = node->elements_[i].get(); + CHECK_LT(i, node->children.size()); + node = node->children[i].get(); } - return *node; + return node; } template -const ShapeTree& ShapeTree::Lookup(const ShapeIndex& index) const { - return const_cast*>(this)->Lookup(index); +const internal::ShapeTreeNode* ShapeTree::Lookup( + const ShapeIndex& index) const { + return const_cast(this)->Lookup(index); } /* static */ template -tensorflow::Status ShapeTree::ForEachHelperMutable( - ShapeIndex* index, ShapeTree* shape_tree, - ShapeTree::MutableVisitorFunction func) { - TF_RETURN_IF_ERROR( - func(*index, shape_tree->elements_.empty(), &shape_tree->data_)); - for (int i = 0; i < shape_tree->elements_.size(); ++i) { +Status ShapeTree::ForEachHelper(const VisitorFunction& func, + const Node& node, ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(*index, node.children.empty(), node.data)); + for (int64 i = 0; i < node.children.size(); ++i) { index->push_back(i); - TF_RETURN_IF_ERROR( - ForEachHelperMutable(index, shape_tree->elements_[i].get(), func)); + TF_RETURN_IF_ERROR(ForEachHelper(func, *node.children[i], index)); index->pop_back(); } - - return tensorflow::Status::OK(); + return Status::OK(); } /* static */ template -tensorflow::Status ShapeTree::ForEachHelper( - ShapeIndex* index, const ShapeTree& shape_tree, - ShapeTree::VisitorFunction func) { - TF_RETURN_IF_ERROR( - func(*index, shape_tree.elements_.empty(), shape_tree.data_)); - for (int i = 0; i < shape_tree.elements_.size(); ++i) { +Status ShapeTree::ForEachMutableHelper(const MutableVisitorFunction& func, + Node* node, ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(*index, node->children.empty(), &node->data)); + for (int64 i = 0; i < node->children.size(); ++i) { index->push_back(i); - TF_RETURN_IF_ERROR(ForEachHelper(index, *shape_tree.elements_[i], func)); + TF_RETURN_IF_ERROR( + ForEachMutableHelper(func, node->children[i].get(), index)); index->pop_back(); } - - return tensorflow::Status::OK(); + return Status::OK(); } template -tensorflow::Status ShapeTree::ForEachElement( - ShapeTree::VisitorFunction func) const { +Status ShapeTree::ForEachElement(const VisitorFunction& func) const { ShapeIndex index; - return ForEachHelper(&index, *this, func); + return ForEachHelper(func, root_, &index); } template -tensorflow::Status ShapeTree::ForEachMutableElement( - ShapeTree::MutableVisitorFunction func) { +Status ShapeTree::ForEachMutableElement(const MutableVisitorFunction& func) { ShapeIndex index; - return ForEachHelperMutable(&index, this, func); + return ForEachMutableHelper(func, &root_, &index); } } // namespace xla diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index d37f536b755d1feca57360edf950329197ba2dd4..efb6f422e008221c2f7d98e066c8aa6ae7bbf426 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { @@ -35,6 +35,9 @@ class ShapeTreeTest : public ::testing::Test { array_shape_})}); } + void TestShapeConstructor(const Shape& shape, int expected_num_nodes); + void TestInitValueConstructor(const Shape& shape, int expected_num_nodes); + // An array shape (non-tuple). Shape array_shape_; @@ -45,6 +48,81 @@ class ShapeTreeTest : public ::testing::Test { Shape nested_tuple_shape_; }; +TEST_F(ShapeTreeTest, DefaultConstructor) { + ShapeTree int_tree; + EXPECT_TRUE(ShapeUtil::IsNil(int_tree.shape())); + + ShapeTree bool_tree; + EXPECT_TRUE(ShapeUtil::IsNil(bool_tree.shape())); +} + +void ShapeTreeTest::TestShapeConstructor(const Shape& shape, + int expected_num_nodes) { + ShapeTree int_tree(shape); + int num_nodes = 0; + TF_CHECK_OK(int_tree.ForEachElement( + [&num_nodes](const ShapeIndex& /*index*/, bool /*is_leaf*/, int data) { + EXPECT_EQ(0, data); + ++num_nodes; + return Status::OK(); + })); + EXPECT_EQ(expected_num_nodes, num_nodes); + + ShapeTree bool_tree(shape); + num_nodes = 0; + TF_CHECK_OK(bool_tree.ForEachElement( + [&num_nodes](const ShapeIndex& /*index*/, bool /*is_leaf*/, bool data) { + EXPECT_EQ(false, data); + ++num_nodes; + return Status::OK(); + })); + EXPECT_EQ(expected_num_nodes, num_nodes); +} + +TEST_F(ShapeTreeTest, ShapeConstructor) { + TestShapeConstructor(array_shape_, 1); + TestShapeConstructor(tuple_shape_, 4); + TestShapeConstructor(nested_tuple_shape_, 10); +} + +void ShapeTreeTest::TestInitValueConstructor(const Shape& shape, + int expected_num_nodes) { + ShapeTree tree(shape, 42); + int num_nodes = 0; + TF_CHECK_OK(tree.ForEachElement( + [&num_nodes](const ShapeIndex& /*index*/, bool /*is_leaf*/, int data) { + EXPECT_EQ(42, data); + ++num_nodes; + return Status::OK(); + })); + EXPECT_EQ(expected_num_nodes, num_nodes); + + num_nodes = 0; + TF_CHECK_OK(tree.ForEachMutableElement( + [&num_nodes](const ShapeIndex& /*index*/, bool /*is_leaf*/, int* data) { + EXPECT_EQ(42, *data); + *data = num_nodes; + ++num_nodes; + return Status::OK(); + })); + EXPECT_EQ(expected_num_nodes, num_nodes); + + num_nodes = 0; + TF_CHECK_OK(tree.ForEachElement( + [&num_nodes](const ShapeIndex& /*index*/, bool /*is_leaf*/, int data) { + EXPECT_EQ(num_nodes, data); + ++num_nodes; + return Status::OK(); + })); + EXPECT_EQ(expected_num_nodes, num_nodes); +} + +TEST_F(ShapeTreeTest, InitValueConstructor) { + TestInitValueConstructor(array_shape_, 1); + TestInitValueConstructor(tuple_shape_, 4); + TestInitValueConstructor(nested_tuple_shape_, 10); +} + TEST_F(ShapeTreeTest, ArrayShape) { ShapeTree shape_tree{array_shape_}; *shape_tree.mutable_element({}) = 42; @@ -57,6 +135,15 @@ TEST_F(ShapeTreeTest, ArrayShape) { // Test the copy constructor. ShapeTree copy{shape_tree}; EXPECT_EQ(123, copy.element({})); + + // Mutate the copy, and ensure the original doesn't change. + *copy.mutable_element({}) = 99; + EXPECT_EQ(99, copy.element({})); + EXPECT_EQ(123, shape_tree.element({})); + + // Test the assignment operator. + copy = shape_tree; + EXPECT_EQ(123, copy.element({})); } TEST_F(ShapeTreeTest, TupleShape) { @@ -77,7 +164,7 @@ TEST_F(ShapeTreeTest, TupleShape) { TF_CHECK_OK(shape_tree.ForEachElement( [&sum](const ShapeIndex& /*index*/, bool /*is_leaf*/, int data) { sum += data; - return tensorflow::Status::OK(); + return Status::OK(); })); EXPECT_EQ(66, sum); @@ -92,12 +179,23 @@ TEST_F(ShapeTreeTest, TupleShape) { TF_CHECK_OK(shape_tree.ForEachMutableElement( [&sum](const ShapeIndex& /*index*/, bool /*is_leaf*/, int* data) { *data = 0; - return tensorflow::Status::OK(); + return Status::OK(); })); EXPECT_EQ(0, shape_tree.element({})); EXPECT_EQ(0, shape_tree.element({0})); EXPECT_EQ(0, shape_tree.element({1})); EXPECT_EQ(0, shape_tree.element({2})); + EXPECT_EQ(1, copy.element({})); + EXPECT_EQ(42, copy.element({0})); + EXPECT_EQ(123, copy.element({1})); + EXPECT_EQ(-100, copy.element({2})); + + // Test the assignment operator. + copy = shape_tree; + EXPECT_EQ(0, copy.element({})); + EXPECT_EQ(0, copy.element({0})); + EXPECT_EQ(0, copy.element({1})); + EXPECT_EQ(0, copy.element({2})); } TEST_F(ShapeTreeTest, NestedTupleShape) { @@ -116,6 +214,23 @@ TEST_F(ShapeTreeTest, NestedTupleShape) { EXPECT_EQ(42, copy.element({0})); EXPECT_EQ(123, copy.element({1, 1})); EXPECT_EQ(-100, copy.element({2, 0, 1})); + + // Mutate the copy, and ensure the original doesn't change. + *copy.mutable_element({0}) = 1; + *copy.mutable_element({1, 1}) = 2; + *copy.mutable_element({2, 0, 1}) = 3; + EXPECT_EQ(1, copy.element({0})); + EXPECT_EQ(2, copy.element({1, 1})); + EXPECT_EQ(3, copy.element({2, 0, 1})); + EXPECT_EQ(42, shape_tree.element({0})); + EXPECT_EQ(123, shape_tree.element({1, 1})); + EXPECT_EQ(-100, shape_tree.element({2, 0, 1})); + + // Test the assignment operator. + copy = shape_tree; + EXPECT_EQ(42, copy.element({0})); + EXPECT_EQ(123, copy.element({1, 1})); + EXPECT_EQ(-100, copy.element({2, 0, 1})); } TEST_F(ShapeTreeTest, InvalidIndexingTuple) { diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 57d91e4bfc1145faa25c9b5c57422c7653d4a163..ccc1dc63e78f8cb5aeaa5664a0d6917898db26b3 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include "tensorflow/compiler/xla/index_util.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -200,7 +202,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { - shape->mutable_layout()->add_minor_to_major(ShapeUtil::Rank(*shape)); + shape->mutable_layout()->add_minor_to_major(Rank(*shape)); shape->add_dimensions(bound); TF_DCHECK_OK(ValidateShape(*shape)); } @@ -293,7 +295,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { std::vector new_elements(tuple.tuple_shapes().begin() + start, tuple.tuple_shapes().begin() + limit); - return ShapeUtil::MakeTupleShape(new_elements); + return MakeTupleShape(new_elements); } /* static */ bool ShapeUtil::IsOpaque(const Shape& shape) { @@ -307,7 +309,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { if (shape.element_type() != element_type) { return false; } - if (shape.dimensions_size() != ShapeUtil::Rank(shape)) { + if (shape.dimensions_size() != Rank(shape)) { return false; } int64 i = 0; @@ -321,7 +323,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK_EQ(shape.dimensions_size(), ShapeUtil::Rank(shape)); + CHECK_EQ(shape.dimensions_size(), Rank(shape)); return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, std::multiplies()); @@ -332,7 +334,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } /* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) { - return shape.element_type() == F32 && ShapeUtil::Rank(shape) == 0; + return shape.element_type() == F32 && Rank(shape) == 0; } /* static */ string ShapeUtil::HumanString(const Shape& shape) { @@ -430,13 +432,12 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } Shape result; if (layout_string.empty()) { - result = ShapeUtil::MakeShape(primitive_type, dimensions); + result = MakeShape(primitive_type, dimensions); } else { TF_ASSIGN_OR_RETURN(std::vector min2maj, comma_list_to_int64s(layout_string)); TF_RET_CHECK(dimensions.size() == min2maj.size()); - result = - ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj); + result = MakeShapeWithLayout(primitive_type, dimensions, min2maj); } TF_DCHECK_OK(ValidateShape(result)); return result; @@ -466,7 +467,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { /* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape, int64 dimension_number) { if (dimension_number < 0) { - dimension_number += ShapeUtil::Rank(shape); + dimension_number += Rank(shape); } CHECK_GE(dimension_number, 0); return dimension_number; @@ -518,7 +519,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } int64 allocated_element_count; if (shape.layout().padded_dimensions_size() > 0) { - CHECK_EQ(ShapeUtil::Rank(shape), shape.layout().padded_dimensions_size()); + CHECK_EQ(Rank(shape), shape.layout().padded_dimensions_size()); allocated_element_count = 1; for (int64 dimension_size : shape.layout().padded_dimensions()) { allocated_element_count *= dimension_size; @@ -534,9 +535,9 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { const Shape& shape) { if (shape.element_type() == TUPLE) { // Tuple shape. - if (ShapeUtil::Rank(shape) != 0) { + if (Rank(shape) != 0) { return InvalidArgument("tuples must be rank-0; got rank %lld", - ShapeUtil::Rank(shape)); + Rank(shape)); } if (shape.dimensions_size() != 0) { return InvalidArgument("tuples must not have dimensions specified"); @@ -556,13 +557,13 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { return InvalidArgument("shape has invalid element type: %s", shape.ShortDebugString().c_str()); } - if (ShapeUtil::Rank(shape) != shape.dimensions_size()) { + if (Rank(shape) != shape.dimensions_size()) { return InvalidArgument( "shape's rank is mismatched with dimension count; rank=%lld " "dimensions_size=%d", - ShapeUtil::Rank(shape), shape.dimensions_size()); + Rank(shape), shape.dimensions_size()); } - for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) { + for (int64 i = 0; i < Rank(shape); ++i) { int64 dimension = shape.dimensions(i); if (dimension < 0) { return InvalidArgument( @@ -675,7 +676,7 @@ namespace { // Helper for ForEachSubshape which visits the subshapes of the given shape in // DFS pre-order starting with the index. Status ForEachSubshapeHelper(const Shape& shape, - const ShapeUtil::VisitorFunction func, + const ShapeUtil::VisitorFunction& func, ShapeIndex* index) { TF_RETURN_IF_ERROR(func(shape, *index)); if (ShapeUtil::IsTuple(shape)) { @@ -692,7 +693,7 @@ Status ForEachSubshapeHelper(const Shape& shape, // Helper for ForEachMutableSubshape which visits the subshapes of the given // shape in DFS pre-order starting with the index. Status ForEachMutableSubshapeHelper( - Shape* shape, const ShapeUtil::MutatingVisitorFunction func, + Shape* shape, const ShapeUtil::MutatingVisitorFunction& func, ShapeIndex* index) { TF_RETURN_IF_ERROR(func(shape, *index)); if (ShapeUtil::IsTuple(*shape)) { @@ -709,13 +710,13 @@ Status ForEachMutableSubshapeHelper( } // namespace /* static */ Status ShapeUtil::ForEachSubshape(const Shape& shape, - VisitorFunction func) { + const VisitorFunction& func) { ShapeIndex index; return ForEachSubshapeHelper(shape, func, &index); } /* static */ Status ShapeUtil::ForEachMutableSubshape( - Shape* shape, MutatingVisitorFunction func) { + Shape* shape, const MutatingVisitorFunction& func) { ShapeIndex index; return ForEachMutableSubshapeHelper(shape, func, &index); } @@ -728,9 +729,17 @@ Status ForEachMutableSubshapeHelper( new_shape.add_dimensions(dim); } if (shape.has_layout()) { - new_shape.mutable_layout()->clear_minor_to_major(); + Layout* new_layout = new_shape.mutable_layout(); + new_layout->clear_minor_to_major(); for (auto index : Permute(permutation, shape.layout().minor_to_major())) { - new_shape.mutable_layout()->add_minor_to_major(index); + new_layout->add_minor_to_major(index); + } + if (shape.layout().padded_dimensions_size() > 0) { + new_layout->clear_padded_dimensions(); + for (auto dim : + Permute(permutation, shape.layout().padded_dimensions())) { + new_layout->add_padded_dimensions(dim); + } } } return new_shape; @@ -783,8 +792,7 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, auto unmodified_dim_pair = i < unmodified_dims.size() ? unmodified_dims[i] - : std::make_pair(ShapeUtil::Rank(shape_pre), - ShapeUtil::Rank(shape_post)); + : std::make_pair(Rank(shape_pre), Rank(shape_post)); if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) { return nil; } @@ -859,9 +867,8 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return false; } - CHECK_EQ(ShapeUtil::ElementsIn(input_shape), - ShapeUtil::ElementsIn(output_shape)); - if (ShapeUtil::ElementsIn(input_shape) == 0) { + CHECK_EQ(ElementsIn(input_shape), ElementsIn(output_shape)); + if (ElementsIn(input_shape) == 0) { return true; } @@ -975,21 +982,17 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, // as input_shape/output_shape and the dimension-0-major layout. These two // shapes are used for conversion between logical linear indices and // multi-dimensional indices. - Shape input_shape_dim0_major = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - input_shape.element_type(), AsInt64Slice(input_shape.dimensions())); - Shape output_shape_dim0_major = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - output_shape.element_type(), - AsInt64Slice(output_shape.dimensions())); - - for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); - ++input_dim) { + Shape input_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout( + input_shape.element_type(), AsInt64Slice(input_shape.dimensions())); + Shape output_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout( + output_shape.element_type(), AsInt64Slice(output_shape.dimensions())); + + for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) { if (input_shape.dimensions(input_dim) <= 1) { continue; } - std::vector input_unit_index(ShapeUtil::Rank(input_shape), 0); + std::vector input_unit_index(Rank(input_shape), 0); input_unit_index[input_dim] = 1; int64 logical_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major, @@ -1013,6 +1016,140 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, check_input_unit_indices(output_shape, input_shape); } +/* static */ tensorflow::gtl::optional ShapeUtil::AlignLayouts( + const Shape& input_shape, const Shape& output_shape) { + int64 input_rank = Rank(input_shape); + int64 output_rank = Rank(output_shape); + + // First, calculate an alignment of the dimensions. A consecutive sequence of + // input dimensions and output dimensions belong to the same alignment part if + // the products of their dimension bounds are the same. In the easiest case, + // an alignment part consists of one input dimension and one output dimension + // which both have the same dimension bound. An alignment part specifies which + // dimensions need to be kept together in a physical layout if we want a + // reshape to be a bitcast. The order of the alignment parts is defined by the + // physical layout of the input shape, so when we construct the layout for the + // output shape we just process the alignment parts in this order, and then + // layout the dimensions belonging to each part in descending (major to minor) + // order. + + // Stores the input and output dimension numbers where each alignment part + // starts. + std::vector> alignment; + alignment.push_back({0, 0}); + + // Stores a mapping from the input dimension to the alignment part it belongs + // to. + std::vector dimension_to_alignment_index(input_rank); + int64 input_dimension_product = 1, output_dimension_product = 1; + for (int64 i = 0, j = 0; i < input_rank || j < output_rank;) { + // Check if we have reached the end of an alignment part. + if (input_dimension_product == output_dimension_product && + input_dimension_product > 1) { + alignment.push_back({i, j}); + input_dimension_product = output_dimension_product = 1; + } + if (input_dimension_product < output_dimension_product || + j == output_rank) { + if (i == input_rank) { + return tensorflow::gtl::nullopt; + } + dimension_to_alignment_index[i] = alignment.size() - 1; + input_dimension_product *= input_shape.dimensions(i); + ++i; + } else { + output_dimension_product *= output_shape.dimensions(j); + ++j; + } + } + if (input_dimension_product != output_dimension_product) { + return tensorflow::gtl::nullopt; + } + // We also need to store an end element so that we know where the last + // alignment part ends. + alignment.push_back({input_rank, output_rank}); + + // Now check if the physical layout can potentially be aligned to the output + // shape by changing the physical layout of the output shape. We need to check + // that all dimension numbers that belong to the same alignment part appear + // consecutively, and are in descending order. However we can ignore any + // trivial dimension bounds of 1, because they can be placed anywhere. + auto input_dimension_numbers = input_shape.layout().minor_to_major(); + std::vector output_layout; + output_layout.reserve(output_rank); + for (int64 i = 0; i < input_rank;) { + int64 current_dimension_number = input_dimension_numbers[i]; + + // Skip trivial dimensions with a bound of 1. + if (input_shape.dimensions(current_dimension_number) == 1) { + ++i; + continue; + } + + // Calculate the number of non-trivial dimension bounds in the input shape + // belonging to the current alignment part. + const int64 current_alignment_index = + dimension_to_alignment_index[current_dimension_number]; + // Because of the special end element that we added, we can be sure that + // 'current_alignment_index' is < alignment.size() - 1. + CHECK_LT(current_alignment_index, alignment.size() - 1); + int64 num_non_trivial_dimensions_in_alignment_part = 0; + for (int64 j = alignment[current_alignment_index].first; + j < alignment[current_alignment_index + 1].first; ++j) { + if (input_shape.dimensions(j) != 1) { + ++num_non_trivial_dimensions_in_alignment_part; + } + } + + // Check that the following 'num_non_trivial_dimensions_in_alignment_part' + // dimension numbers (ignoring dimension numbers with dimension bound 1) are + // in descending order and belong to the current alignment part. + for (int64 j = 0; j < num_non_trivial_dimensions_in_alignment_part; + ++i, ++j) { + if (i == input_rank) { + return tensorflow::gtl::nullopt; + } + // Skip trivial dimensions with a bound of 1. + if (input_shape.dimensions(input_dimension_numbers[i]) == 1) { + --j; + continue; + } + // If the current dimension number belongs to a different alignment part, + // or the dimension numbers are not in descending order, we can return + // early. + if (dimension_to_alignment_index[input_dimension_numbers[i]] != + current_alignment_index || + input_dimension_numbers[i] > current_dimension_number) { + return tensorflow::gtl::nullopt; + } + current_dimension_number = input_dimension_numbers[i]; + } + + // The output dimension numbers that belong to the current alignment part + // need to appear in the same descending order as in the input. Again, we + // can skip dimensions with a bound of 1. + for (int64 j = alignment[current_alignment_index + 1].second - 1; + j >= alignment[current_alignment_index].second; --j) { + if (output_shape.dimensions(j) != 1) { + output_layout.push_back(j); + } + } + } + // Now add all the dimensions with dimension bound 1 at the end of + // 'output_layout'. + for (int64 i = 0; i < output_rank; ++i) { + if (output_shape.dimensions(i) == 1) { + output_layout.push_back(i); + } + } + CHECK_EQ(output_layout.size(), output_rank); + Shape output_shape_with_layout = MakeShapeWithLayout( + output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), + output_layout); + CHECK(ReshapeIsBitcast(input_shape, output_shape_with_layout)); + return output_shape_with_layout; +} + /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); @@ -1047,4 +1184,31 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return shape; } +/* static */ void ShapeUtil::ForEachIndex( + const Shape& shape, tensorflow::gtl::ArraySlice base, + tensorflow::gtl::ArraySlice count, + tensorflow::gtl::ArraySlice incr, + const IndexVisitorFunction& visitor_function) { + DCHECK_EQ(Rank(shape), base.size()); + DCHECK_EQ(incr.size(), base.size()); + DCHECK_EQ(count.size(), base.size()); + const Layout& layout = shape.layout(); + int64 rank = layout.minor_to_major_size(); + // Allows handling R0 arrays, such that the visitor function will be called + // once with the proper empty indexes. + int64 n = -1; + std::vector indexes(base.begin(), base.end()); + while (n < rank && visitor_function(indexes)) { + // Increments dimensions in minor to major order. + for (n = 0; n < rank; ++n) { + int64 dim = layout.minor_to_major(n); + indexes[dim] += incr[dim]; + if (indexes[dim] < base[dim] + count[dim]) { + break; + } + indexes[dim] = base[dim]; + } + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 68e138e6aca9d2cf157466eca1ea6960e3c448e8..aaf8e84cfecb89080d690c66acd4f8d50ee17d56 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -299,13 +300,14 @@ class ShapeUtil { // pre-order starting with the entire shape (index {}). using VisitorFunction = std::function; - static Status ForEachSubshape(const Shape& shape, VisitorFunction func); + static Status ForEachSubshape(const Shape& shape, + const VisitorFunction& func); // Mutating variant of ForEachSubshape. using MutatingVisitorFunction = std::function; static Status ForEachMutableSubshape(Shape* shape, - MutatingVisitorFunction func); + const MutatingVisitorFunction& func); // Removes all degenerate dimensions (size one) from the given shape. The // stripped minor_to_major preserves the relative ordering of non-degenerate @@ -377,6 +379,15 @@ class ShapeUtil { static bool ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape); + // Find a physical layout for 'output_shape' such that + // ShapeUtil::ReshapeIsBitcast(input_shape, output_shape_with_layout) returns + // true (where 'output_shape_with_layout' is 'output_shape' with the found + // layout). The layout of 'input_shape' is kept fixed. Returns + // 'output_shape_with_layout' if such a layout can be found, and an error + // otherwise. + static tensorflow::gtl::optional AlignLayouts( + const Shape& input_shape, const Shape& output_shape); + // Returns a shape with the given dimension deleted. // For example: // • `DeleteDimension(1, T[m, n, k]) = T[m, k]` @@ -390,6 +401,19 @@ class ShapeUtil { static Shape FilterDimensions(const std::function& p, Shape shape); + // Iterates through all the shape indexes, in minor to major order, starting + // from the base indexes, incrementing by the incr steps, up to count + // (index[i] < base[i] + count[i]), and calls the visitor_function with the + // current index. + // The visitor_function visitor function should return true if it wants to + // continue, or false otherwise. + using IndexVisitorFunction = std::function&)>; + static void ForEachIndex(const Shape& shape, + tensorflow::gtl::ArraySlice base, + tensorflow::gtl::ArraySlice count, + tensorflow::gtl::ArraySlice incr, + const IndexVisitorFunction& visitor_function); + private: // Validates all of the non-layout properties of the shape -- this is a helper // used by both the layout-optional and layout-required public method. diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 9e6b243611b57d38339a8f6460c655255f60899d..73538b8b88ecf14c00854d3c31715af8189bc21d 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -16,14 +16,17 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/platform/test.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { +using ::testing::ElementsAre; + TEST(ShapeUtilTest, GetDimensionHelperCanNegativeIndex) { Shape matrix = ShapeUtil::MakeShape(F32, {2, 3}); EXPECT_EQ(3, ShapeUtil::GetDimension(matrix, -1)); @@ -446,21 +449,21 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) { TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) { // All output dimensions should be unmodified. One of the input dimensions is // modified because the input rank is larger by one. - EXPECT_EQ(3, - ShapeUtil::DimensionsUnmodifiedByReshape( - ShapeUtil::MakeShape(S32, {1, 1, 1, 1}), - ShapeUtil::MakeShape(S32, {1, 1, 1})) - .size()); + EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape( + ShapeUtil::MakeShape(S32, {1, 1, 1, 1}), + ShapeUtil::MakeShape(S32, {1, 1, 1})), + ElementsAre(std::make_pair(0, 0), std::make_pair(1, 1), + std::make_pair(2, 2))); } TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1_to_1x1x1x1) { // All input dimensions should be unmodified. One of the output dimensions is // modified because the output rank is larger by one. - EXPECT_EQ(3, - ShapeUtil::DimensionsUnmodifiedByReshape( - ShapeUtil::MakeShape(S32, {1, 1, 1}), - ShapeUtil::MakeShape(S32, {1, 1, 1, 1})) - .size()); + EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape( + ShapeUtil::MakeShape(S32, {1, 1, 1}), + ShapeUtil::MakeShape(S32, {1, 1, 1, 1})), + ElementsAre(std::make_pair(0, 0), std::make_pair(1, 1), + std::make_pair(2, 2))); } TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_4x1x3x5x6x7_to_2x6x1x5x1x42) { @@ -468,11 +471,10 @@ TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_4x1x3x5x6x7_to_2x6x1x5x1x42) { // 4, 1, 3, 5, 6, 7 // | // 2, 6, 1, 5, 1, 42 - EXPECT_TRUE( - ContainersEqual(ShapeUtil::DimensionsUnmodifiedByReshape( - ShapeUtil::MakeShape(S32, {4, 1, 3, 5, 6, 7}), - ShapeUtil::MakeShape(S32, {2, 6, 1, 5, 1, 42})), - std::vector>({{3, 3}}))); + EXPECT_THAT(ShapeUtil::DimensionsUnmodifiedByReshape( + ShapeUtil::MakeShape(S32, {4, 1, 3, 5, 6, 7}), + ShapeUtil::MakeShape(S32, {2, 6, 1, 5, 1, 42})), + ElementsAre(std::make_pair(3, 3))); } TEST(ShapeUtilTest, ReshapeIsBitcast_3x4_6x2) { @@ -521,5 +523,58 @@ TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); } +TEST(AlignmentTest, AlignLayoutsWithoutTrivialDimensions) { + Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, + {3, 2, 1, 0, 4}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 7, 5, 11})); + EXPECT_TRUE(aligned_shape); + EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), + ElementsAre(4, 3, 2, 1, 0, 5)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); + + aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {3, 2, 4, 35, 11})); + EXPECT_TRUE(aligned_shape); + EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), + ElementsAre(3, 2, 1, 0, 4)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); +} + +TEST(AlignmentTest, AlignLayoutsWithTrivialDimensions) { + Shape input = + ShapeUtil::MakeShapeWithLayout(xla::F32, {1, 3, 8, 1, 5, 7, 1, 11, 1, 1}, + {5, 0, 4, 2, 1, 3, 6, 7, 9, 8}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {1, 4, 1, 3, 2, 7, 5, 11, 1})); + EXPECT_TRUE(aligned_shape); + EXPECT_THAT(aligned_shape.value().layout().minor_to_major(), + ElementsAre(6, 5, 4, 3, 1, 7, 0, 2, 8)); + EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(input, aligned_shape.value())); +} + +// A test case where the consecutive elements of the input shape belonging to +// the same layout part are not in descending order. +TEST(AlignmentTest, AlignLayoutsWithoutTrivialDimensionsWrongInputLayout) { + // Same physical layout as in AlignLayoutsWithoutTrivialDimensions, except + // that the first two dimension numbers are exchanged. + Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, + {2, 3, 1, 0, 4}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 7, 5, 11})); + EXPECT_FALSE(aligned_shape); +} + +// A test case where the physical layout of the input shape does not place all +// dimensions that belong to the same alignment part consecutively. +TEST(AlignmentTest, + AlignLayoutsWithoutTrivialDimensionsNonConsecutiveAlignmentPart) { + Shape input = ShapeUtil::MakeShapeWithLayout(xla::F32, {3, 8, 5, 7, 11}, + {3, 2, 1, 0, 4}); + auto aligned_shape = ShapeUtil::AlignLayouts( + input, ShapeUtil::MakeShape(xla::F32, {4, 3, 2, 5, 77})); + EXPECT_FALSE(aligned_shape); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/status_macros_test.cc b/tensorflow/compiler/xla/status_macros_test.cc index 4e7b9161db5c7e01a4b80da49bdded025eaf298a..dead17cdfa1e9f19e0ecfbc071e74e159ae82b5f 100644 --- a/tensorflow/compiler/xla/status_macros_test.cc +++ b/tensorflow/compiler/xla/status_macros_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/test.h" namespace xla { @@ -40,15 +40,15 @@ Status RetCheckSuccess() { TEST(StatusMacros, RetCheckFailing) { Status status = RetCheckFail(); EXPECT_EQ(status.code(), tensorflow::error::INTERNAL); - EXPECT_MATCH(status.error_message(), - xla::testing::ContainsRegex("RET_CHECK failure.*2 > 3")); + EXPECT_THAT(status.error_message(), + ::testing::ContainsRegex("RET_CHECK failure.*2 > 3")); } TEST(StatusMacros, RetCheckFailingWithExtraMessage) { Status status = RetCheckFailWithExtraMessage(); EXPECT_EQ(status.code(), tensorflow::error::INTERNAL); - EXPECT_MATCH(status.error_message(), - xla::testing::ContainsRegex("RET_CHECK.*2 > 3 extra message")); + EXPECT_THAT(status.error_message(), + ::testing::ContainsRegex("RET_CHECK.*2 > 3 extra message")); } TEST(StatusMacros, RetCheckSucceeding) { @@ -73,7 +73,7 @@ Status ReturnStatusError() { return (tensorflow::errors::Internal("foobar")); } using StatusReturningFunction = std::function; -StatusOr CallStatusReturningFunction(StatusReturningFunction func) { +StatusOr CallStatusReturningFunction(const StatusReturningFunction& func) { TF_RETURN_IF_ERROR(func()); return 42; } diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index d98eb2793363ac855b43f88eb4201f34a3b7693b..d3bc3e9225fd65b9ded18e970ecb7c81588078fe 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" namespace xla { diff --git a/tensorflow/compiler/xla/test.h b/tensorflow/compiler/xla/test.h new file mode 100644 index 0000000000000000000000000000000000000000..87a8c5f3a528289d47c1729ae6719aae47037c36 --- /dev/null +++ b/tensorflow/compiler/xla/test.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPLIER_XLA_TEST_H_ +#define TENSORFLOW_COMPLIER_XLA_TEST_H_ + +// This header includes gmock.h and enables the use of gmock matchers in tests +// in third_party/tensorflow/compiler/xla. +// +// Test including this header can use the macros EXPECT_THAT(...) and +// ASSERT_THAT(...) in combination with gmock matchers. +// Example: +// std::vector vec = Foo(); +// EXPECT_THAT(vec, ::testing::ElementsAre(1,2,3)); +// +// For more details on gmock matchers see: +// https://github.com/google/googletest/blob/master/googlemock/docs/CheatSheet.md#matchers +// +// The advantages of using gmock matchers instead of self defined matchers are +// better error messages, more maintainable tests and more test coverage. +// +// Note that while the use of gmock matchers is allowed in the xla project, the +// use of mocks is disallowed in the whole tensorflow project! + +#include "tensorflow/core/platform/platform.h" + +#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) +#include "testing/base/public/gmock.h" +#else +#include +#include +#endif + +#include "tensorflow/core/platform/test.h" + +#endif // TENSORFLOW_COMPLIER_XLA_TEST_H_ diff --git a/tensorflow/compiler/xla/test_helpers.cc b/tensorflow/compiler/xla/test_helpers.cc deleted file mode 100644 index 02abfdeab80ee34c79e8d54b825937d6fc4b4053..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/test_helpers.cc +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/regexp.h" - -namespace xla { -namespace testing { - -AssertionResult::AssertionResult(const AssertionResult& other) - : success_(other.success_), - message_(other.message_ != nullptr ? new std::string(*other.message_) - : static_cast(nullptr)) { -} - -// Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. -AssertionResult AssertionResult::operator!() const { - AssertionResult negation(!success_); - if (message_ != nullptr) negation << *message_; - return negation; -} - -AssertionResult& AssertionResult::operator=(const AssertionResult& ar) { - success_ = ar.success_; - message_.reset(ar.message_ != nullptr ? new std::string(*ar.message_) - : nullptr); - return *this; -} - -AssertionResult AssertionFailure() { return AssertionResult(false); } - -AssertionResult AssertionSuccess() { return AssertionResult(true); } - -std::function ContainsRegex( - const tensorflow::StringPiece regex) { - return [regex](const tensorflow::StringPiece to_test) { - if (RE2::PartialMatch( - tensorflow::RegexpStringPiece(to_test.data(), to_test.size()), - tensorflow::RegexpStringPiece(regex.data(), regex.size()))) { - return true; - } else { - LOG(ERROR) << "Expected to find " << regex << " in " << to_test; - return false; - } - }; -} - -std::function HasSubstr( - const tensorflow::StringPiece part) { - return [part](const tensorflow::StringPiece whole) { - return whole.contains(part); - }; -} - -} // namespace testing -} // namespace xla diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h index f923d9f36c878c1ae4e37f052a84e9c2a279b4ed..634cdb5aa29651b08090ff99f0a6cafb9facb645 100644 --- a/tensorflow/compiler/xla/test_helpers.h +++ b/tensorflow/compiler/xla/test_helpers.h @@ -39,286 +39,6 @@ class Literal; namespace testing { -class AssertionResult { - public: - explicit AssertionResult(bool success) : success_(success) {} - - // Returns true iff the assertion succeeded. - operator bool() const { return success_; } // NOLINT - - // Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. - AssertionResult operator!() const; - - // Returns the text streamed into this AssertionResult. Test assertions - // use it when they fail (i.e., the predicate's outcome doesn't match the - // assertion's expectation). When nothing has been streamed into the - // object, returns an empty string. - const char* message() const { - return message_ != nullptr ? message_->c_str() : ""; - } - - // Streams a custom failure message into this object. - template - AssertionResult& operator<<(const T& value) { - AppendMessage(::testing::Message() << value); - return *this; - } - - // Allows streaming basic output manipulators such as endl or flush into - // this object. - AssertionResult& operator<<( - std::ostream& (*basic_manipulator)(std::ostream& stream)) { - AppendMessage(::testing::Message() << basic_manipulator); - return *this; - } - - // Copy operator. - AssertionResult(const AssertionResult& ar); - - // Assignment operator. - AssertionResult& operator=(const AssertionResult&); - - private: - // Appends the contents of message to message_. - void AppendMessage(const ::testing::Message& a_message) { - if (message_ == nullptr) message_.reset(new std::string); - message_->append(a_message.GetString().c_str()); - } - - bool success_ = false; - - // Stores the message describing the condition in case the - // expectation construct is not satisfied with the predicate's - // outcome. Referenced via a pointer to avoid taking too much stack - // frame space with test assertions. - std::unique_ptr message_; -}; - -AssertionResult AssertionFailure(); - -AssertionResult AssertionSuccess(); - -std::function ContainsRegex( - const tensorflow::StringPiece regex); - -std::function HasSubstr( - const tensorflow::StringPiece part); - -// Matcher for a vector of same-type values for which operator= is -// defined. -template -std::function& actual)> VectorMatcher( - const std::vector& expected) { - return [expected](const std::vector& actual) -> AssertionResult { - int len = expected.size(); - if (actual.size() != len) { - return AssertionFailure() << "Actual values len of " << actual.size() - << " != expected.size " << len; - } - for (int i = 0; i < len; ++i) { - if (actual[i] != expected[i]) { - return AssertionFailure() << "Element " << i << " actual " << actual[i] - << " != " << expected[i]; - } - } - return AssertionSuccess(); - }; -} - -// Approximate matcher for a vector of floats or similar. -template -std::function& actual)> -ApproxVectorMatcher(const std::vector& expected, float abs_diff, - float rel_diff) { - return [abs_diff, rel_diff, - expected](const std::vector& actual) -> AssertionResult { - int len = expected.size(); - if (actual.size() != len) { - AssertionResult ar = AssertionFailure() << "Actual values len of " - << actual.size() - << " != expected.size " << len; - LOG(ERROR) << ar.message(); - return ar; - } - for (int i = 0; i < len; ++i) { - T diff = actual[i] - expected[i]; - if (diff < 0) { - diff *= -1; - } - if (diff > abs_diff) { - T rdiff = (expected[i] != 0 ? diff / expected[i] : 0.0 * expected[i]); - if (rdiff > rel_diff) { - AssertionResult ar = AssertionFailure() - << "Element " << i << " actual " << actual[i] - << " != " << expected[i] - << "( abs_diff = " << diff - << ", rel_diff = " << rdiff << ")"; - LOG(ERROR) << ar.message(); - return ar; - } - } - } - return AssertionSuccess(); - }; -} - -// Matches a vector of same-type values against another, succeeding so -// long as they have the same length and every value in 'actual' -// matches one in 'expected.' Does not verify an exhaustive -// one-to-one mapping between the two. -template -std::function& actual)> -UnorderedElementsAre(const std::vector& expected) { - return [expected](const std::vector& actual) -> AssertionResult { - if (actual.size() != expected.size()) { - return AssertionFailure() << "sizes don't match"; - } - for (auto a : actual) { - bool found = false; - for (auto e : expected) { - if (a == e) { - found = true; - break; - } - } - if (!found) { - return AssertionFailure() << "actual element " << a - << " not in expected"; - } - } - return AssertionSuccess(); - }; -} - -// Overloaded cover functions for UnorderedElementsAre, for the numbers -// of values used in practice. -template -std::function& actual)> UnorderedMatcher( - T a) { - std::vector expected; - expected.push_back(a); - return testing::UnorderedElementsAre(expected); -} - -template -std::function& actual)> UnorderedMatcher( - T a, T b) { - std::vector expected; - expected.push_back(a); - expected.push_back(b); - return testing::UnorderedElementsAre(expected); -} - -template -std::function& actual)> UnorderedMatcher( - T a, T b, T c) { - std::vector expected; - expected.push_back(a); - expected.push_back(b); - expected.push_back(c); - return testing::UnorderedElementsAre(expected); -} - -template -std::function& actual)> UnorderedMatcher( - T a, T b, T c, T d) { - std::vector expected; - expected.push_back(a); - expected.push_back(b); - expected.push_back(c); - expected.push_back(d); - return testing::UnorderedElementsAre(expected); -} - -template -std::function& actual)> UnorderedMatcher( - T a, T b, T c, T d, T e) { - std::vector expected; - expected.push_back(a); - expected.push_back(b); - expected.push_back(c); - expected.push_back(d); - expected.push_back(e); - return testing::UnorderedElementsAre(expected); -} - -template -std::function& actual)> UnorderedMatcher( - T a, T b, T c, T d, T e, T f) { - std::vector expected; - expected.push_back(a); - expected.push_back(b); - expected.push_back(c); - expected.push_back(d); - expected.push_back(e); - expected.push_back(f); - return testing::UnorderedElementsAre(expected); -} - -// Overloaded cover functions for VectorMatcher for the numbers of -// elements used in practice. -template -std::function& actual)> OrderedMatcher( - T a) { - std::vector expected; - expected.push_back(a); - return testing::VectorMatcher(expected); -} - -template -std::function& actual)> OrderedMatcher( - T a, T b) { - std::vector expected; - expected.push_back(a); - expected.push_back(b); - return testing::VectorMatcher(expected); -} - -template -std::function& actual)> OrderedMatcher( - T a, T b, T c) { - std::vector expected; - expected.push_back(a); - expected.push_back(b); - expected.push_back(c); - return testing::VectorMatcher(expected); -} - -template -std::function& actual)> OrderedMatcher( - T a, T b, T c, T d) { - std::vector expected; - expected.push_back(a); - expected.push_back(b); - expected.push_back(c); - expected.push_back(d); - return testing::VectorMatcher(expected); -} - -// Convert a RepeatedField to a flat vector. -template -std::vector PBToVec(const tensorflow::protobuf::RepeatedField rf) { - return std::vector(rf.begin(), rf.end()); -} - -// Convert a List to a flat vector. -template -std::vector ListToVec(const std::list& l) { - return std::vector(l.begin(), l.end()); -} - -// Convert a Set to a flat vector. -template -std::vector SetToVec(const std::set& c) { - return std::vector(c.begin(), c.end()); -} - -// Convert an Array to a flat vector. -template -std::vector Array2DToVec(const Array2D& a) { - return std::vector(a.data(), a.data() + a.num_elements()); -} - namespace internal_status { inline const ::tensorflow::Status& GetStatus( const ::tensorflow::Status& status) { @@ -347,9 +67,4 @@ inline const ::tensorflow::Status& GetStatus(const StatusOr& status) { ASSERT_EQ(tensorflow::Status::OK(), \ xla::testing::internal_status::GetStatus(expression)) -// Macros that apply a Matcher to a Value, returning an -// AssertionResult which gets digested by a standard gunit macro. -#define EXPECT_MATCH(V, M) EXPECT_TRUE((M)((V))) -#define ASSERT_MATCH(V, M) ASSERT_TRUE(M(V)) - #endif // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index c7cbbdf4999970b0a09660ddadc31a068c752a55..e0c2b9ab09c28a7b7a31917b9250bdca8016d1e0 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -69,6 +69,7 @@ cc_library( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -103,6 +104,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -198,11 +200,13 @@ cc_library( "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service:pool", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//third_party/eigen3", ], ) @@ -889,6 +893,7 @@ xla_test( name = "copy_test", srcs = ["copy_test.cc"], deps = [ + ":client_library_test_base", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", @@ -1204,12 +1209,12 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", - "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -1359,6 +1364,7 @@ cc_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/service:computation_tracker", "//tensorflow/compiler/xla/service:local_service", + "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index d18511a6b4a98d42640ed22f6aa69c2e66465f8a..319cd2c6fd18e328435613de86fa2ad1d84f90aa 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -29,13 +29,13 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/casts.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -242,6 +242,150 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { ComputeAndCompareR1(&builder, {}, {}, error_spec_); } +TEST_F(ArrayElementwiseOpTest, DivS32s) { + // clang-format off + // Some interesting values to test. + std::vector vals = { + INT32_MIN, INT32_MIN + 1, INT32_MIN + 2, -0x40000000, -0x3fffffff, + -271181, -1309, -17, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 17, 26, 101, + 7919, 0x40000000, INT32_MAX - 2, INT32_MAX - 1, INT32_MAX}; + // clang-format on + + std::vector dividends, divisors, quotients, remainders; + for (int32 divisor : vals) { + if (divisor != 0) { + for (int32 dividend : vals) { + // Avoid integer overflow. + if (dividend != INT32_MIN || divisor != -1) { + dividends.push_back(dividend); + divisors.push_back(divisor); + quotients.push_back(dividend / divisor); + remainders.push_back(dividend % divisor); + } + } + } + } + + { + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle dividend; + ComputationDataHandle divisor; + auto dividend_data = + CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); + auto divisor_data = + CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); + builder.Div(dividend, divisor); + + ComputeAndCompareR1(&builder, quotients, + {dividend_data.get(), divisor_data.get()}); + } + + // Test with a compile-time constant divisor. + { + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle dividend; + auto dividend_data = + CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); + builder.Div(dividend, builder.ConstantR1(divisors)); + + ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); + } + + { + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle dividend; + ComputationDataHandle divisor; + auto dividend_data = + CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); + auto divisor_data = + CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); + builder.Rem(dividend, divisor); + + ComputeAndCompareR1(&builder, remainders, + {dividend_data.get(), divisor_data.get()}); + } + + // Test with a compile-time constant divisor. + { + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle dividend; + auto dividend_data = + CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); + builder.Rem(dividend, builder.ConstantR1(divisors)); + + ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); + } +} + +TEST_F(ArrayElementwiseOpTest, DivU32s) { + // clang-format off + // Some interesting values to test. + std::vector vals = { + 0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0xABCDEF12, 0xCAFEBEEF, 0x80000000, + 0x80000001, UINT32_MAX - 2, UINT32_MAX - 1, UINT32_MAX}; + // clang-format on + + std::vector dividends, divisors, quotients, remainders; + for (uint32 divisor : vals) { + if (divisor != 0) { + for (uint32 dividend : vals) { + dividends.push_back(dividend); + divisors.push_back(divisor); + quotients.push_back(dividend / divisor); + remainders.push_back(dividend % divisor); + } + } + } + + { + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle dividend; + ComputationDataHandle divisor; + auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", + &builder, ÷nd); + auto divisor_data = + CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); + builder.Div(dividend, divisor); + + ComputeAndCompareR1(&builder, quotients, + {dividend_data.get(), divisor_data.get()}); + } + + { + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle dividend; + auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", + &builder, ÷nd); + builder.Div(dividend, builder.ConstantR1(divisors)); + + ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); + } + + { + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle dividend; + ComputationDataHandle divisor; + auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", + &builder, ÷nd); + auto divisor_data = + CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); + builder.Rem(dividend, divisor); + + ComputeAndCompareR1(&builder, remainders, + {dividend_data.get(), divisor_data.get()}); + } + + { + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle dividend; + auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", + &builder, ÷nd); + builder.Rem(dividend, builder.ConstantR1(divisors)); + + ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); + } +} + XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1( @@ -486,6 +630,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { ComputeAndCompareR1(&builder, {}, {}); } +TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { + // Disable fast-math because we're operating on NaNs. + SetFastMathDisabled(true); + + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = builder.ConstantR1({10.0f, 25.5f, 1.0f, 10.0f, NAN}); + auto compare = builder.Ne(lhs, rhs); + + ComputeAndCompareR1(&builder, {true, false, true, true, true}, {}); +} + TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); @@ -620,12 +776,14 @@ TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { TEST_F(ArrayElementwiseOpTest, PowF32s) { SetFastMathDisabled(true); ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR1({4.0f, 2.0f, 2.0f, NAN, 6.0f}); - auto rhs = builder.ConstantR1({2.0f, -2.0f, 3.0f, 10.0f, NAN}); + auto lhs = + builder.ConstantR1({4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f}); + auto rhs = + builder.ConstantR1({2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f}); auto minimum = builder.Pow(lhs, rhs); - ComputeAndCompareR1(&builder, {16.0f, 0.25f, 8.0f, NAN, NAN}, {}, - error_spec_); + ComputeAndCompareR1( + &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { @@ -1667,9 +1825,9 @@ TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { auto concatenated = builder.Add(x, x); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); - EXPECT_MATCH(computation_status.status().ToString(), - testing::ContainsRegex( - "Expected non-opaque argument for lhs of binary operation")); + EXPECT_THAT(computation_status.status().ToString(), + ::testing::ContainsRegex( + "Expected non-opaque argument for lhs of binary operation")); } // Regression test for b/31927799. "slice - y" is fused and requires implicit diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc index c7b533b80f1901a32324a15a8f6584e628a4ad30..a67f18a44e10249bb4674624476c617d6f5c5ce5 100644 --- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc +++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc @@ -23,12 +23,11 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { @@ -45,8 +44,8 @@ TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { StatusOr computation = builder.Build(); EXPECT_FALSE(computation.ok()); LOG(INFO) << "status received: " << computation.status(); - EXPECT_MATCH(computation.status().error_message(), - testing::HasSubstr("shape has invalid")); + EXPECT_THAT(computation.status().error_message(), + ::testing::HasSubstr("shape has invalid")); } TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 63744afb4ea72006262aad74e9b8d75a09b107e6..901bed5f1488d6df19b6b0d3a1772d07fb60bf6d 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -24,16 +24,16 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { using BroadcastSimpleTest = ClientLibraryTestBase; +using ::testing::HasSubstr; XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) { ComputationBuilder b(client_, TestName()); @@ -89,6 +89,33 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } +// Tests implicit broadcasting of PREDs. +XLA_TEST_F(BroadcastSimpleTest, LogicalAnd2DTo3D_Pred) { + ComputationBuilder b(client_, TestName()); + + Array2D x_vals(2, 1); + x_vals(0, 0) = true; + x_vals(1, 0) = false; + Array3D y_vals(2, 2, 1); + y_vals(0, 0, 0) = false; + y_vals(0, 1, 0) = false; + y_vals(1, 0, 0) = true; + y_vals(1, 1, 0) = true; + + ComputationDataHandle x, y; + auto x_data = CreateR2Parameter(x_vals, 0, "x", &b, &x); + auto y_data = CreateR3Parameter(y_vals, 1, "y", &b, &y); + b.LogicalAnd(x, y, /*broadcast_dimensions=*/{1, 2}); + + Array3D expected(2, 2, 1); + expected(0, 0, 0) = false; + expected(0, 1, 0) = false; + expected(1, 0, 0) = true; + expected(1, 1, 0) = false; + + ComputeAndCompareR3(&b, expected, {x_data.get(), y_data.get()}); +} + XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { ComputationBuilder b(client_, TestName()); b.Broadcast(b.ConstantR1({}), {2}); @@ -127,6 +154,251 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } +struct R3ImplicitBroadcastSpec { + std::array output_bounds; + std::array minor2major_layout; + std::array input_bounds; + HloOpcode op; +} kR3ImplicitBroadcastTestCases[] = { + {{{1, 1, 1}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 5}}, HloOpcode::kMaximum}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 1}}, HloOpcode::kMinimum}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 1}}, HloOpcode::kMultiply}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 5}}, HloOpcode::kAdd}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 4, 1}}, HloOpcode::kAdd}, + {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 5}}, HloOpcode::kAdd}, + {{{3, 199, 5}}, {{2, 1, 0}}, {{1, 199, 1}}, HloOpcode::kMinimum}, + {{{3, 4, 199}}, {{2, 1, 0}}, {{1, 1, 199}}, HloOpcode::kAdd}, +}; + +class BroadcastR3ImplicitTest + : public BroadcastSimpleTest, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { + const R3ImplicitBroadcastSpec& spec = GetParam(); + ComputationBuilder builder(client_, TestName()); + const Shape r3_shape = ShapeUtil::MakeShapeWithLayout( + F32, spec.output_bounds, spec.minor2major_layout); + Array3D r3_array(spec.output_bounds[0], spec.output_bounds[1], + spec.output_bounds[2]); + r3_array.FillRandom(1.0, 2.5, 56789); + auto r3_input = + LiteralUtil::Relayout(*LiteralUtil::CreateR3FromArray3D(r3_array), + LayoutUtil::MakeLayout(spec.minor2major_layout)); + std::unique_ptr r3_global_data = + client_->TransferToServer(*r3_input).ConsumeValueOrDie(); + + const Shape r3_implicit_shape = ShapeUtil::MakeShapeWithLayout( + F32, spec.input_bounds, spec.minor2major_layout); + Array3D r3_implicit_array(spec.input_bounds[0], spec.input_bounds[1], + spec.input_bounds[2]); + r3_implicit_array.FillRandom(1.0, 0.2, 56789); + auto r3_implicit_input = LiteralUtil::Relayout( + *LiteralUtil::CreateR3FromArray3D(r3_implicit_array), + LayoutUtil::MakeLayout(spec.minor2major_layout)); + std::unique_ptr r3_implicit_global_data = + client_->TransferToServer(*r3_implicit_input).ConsumeValueOrDie(); + + auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input"); + auto r3_parameter = builder.Parameter(1, r3_shape, "input"); + ComputationDataHandle op; + switch (spec.op) { + case HloOpcode::kMinimum: { + auto tmp_op = builder.Min(r3_implicit_parameter, r3_parameter); + op.Swap(&tmp_op); + break; + } + case HloOpcode::kMaximum: { + auto tmp_op = builder.Max(r3_implicit_parameter, r3_parameter); + op.Swap(&tmp_op); + break; + } + case HloOpcode::kMultiply: { + auto tmp_op = builder.Mul(r3_implicit_parameter, r3_parameter); + op.Swap(&tmp_op); + break; + } + default: { + // Default to Add + auto tmp_op = builder.Add(r3_implicit_parameter, r3_parameter); + op.Swap(&tmp_op); + } + } + + Array3D expected_array(spec.output_bounds[0], spec.output_bounds[1], + spec.output_bounds[2]); + auto Each = ([&](tensorflow::gtl::ArraySlice indices, float* value) { + float r3_implicit = r3_implicit_array(indices[0] % spec.input_bounds[0], + indices[1] % spec.input_bounds[1], + indices[2] % spec.input_bounds[2]); + float r3 = r3_array(indices[0], indices[1], indices[2]); + switch (spec.op) { + case HloOpcode::kMinimum: { + *value = std::min(r3_implicit, r3); + break; + } + case HloOpcode::kMaximum: { + *value = std::max(r3_implicit, r3); + break; + } + case HloOpcode::kMultiply: { + *value = r3_implicit * r3; + break; + } + default: { + // Default to Add + *value = r3_implicit + r3; + break; + } + } + }); + + int n1 = expected_array.n1(); + int n2 = expected_array.n2(); + int n3 = expected_array.n3(); + for (int64 i = 0; i < n1; i++) { + for (int64 j = 0; j < n2; j++) { + for (int64 k = 0; k < n3; k++) { + Each({i, j, k}, &expected_array(i, j, k)); + } + } + } + auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); + ComputeAndCompareLiteral( + &builder, *expected, + {r3_implicit_global_data.get(), r3_global_data.get()}, + ErrorSpec(1e-7, 1e-7)); +} + +INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances, + BroadcastR3ImplicitTest, + ::testing::ValuesIn(kR3ImplicitBroadcastTestCases)); + +// r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1: +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { + ComputationBuilder b(client_, TestName()); + ComputationDataHandle r1h; + ComputationDataHandle r3h; + + Array3D r1d = {{{1}}, {{2}}}; + Array3D r3d = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}; + auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h); + auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h); + + b.Add(r3h, r1h); + + auto expected = + LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); + + ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, + ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1}, {2}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { + ComputationBuilder b(client_, TestName()); + auto r1 = + b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { + ComputationBuilder b(client_, TestName()); + auto r1 = + b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1}}})); + auto r3 = b.ConstantLiteral( + *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + b.Add(r3, r1); + + auto expected = + LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}})); + auto r2 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + b.Add(r2, r1); + + auto expected = LiteralUtil::CreateR2({{2, 4}, {4, 6}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { + ComputationBuilder b(client_, TestName()); + auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1}, {2}})); + auto r2 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + b.Add(r2, r1); + + auto expected = LiteralUtil::CreateR2({{2, 3}, {5, 6}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { ComputationBuilder b(client_, TestName()); auto r1 = b.ConstantR1({10, 20}); @@ -220,8 +492,8 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); - EXPECT_MATCH(result_status.status().error_message(), - testing::ContainsRegex("broadcast dimension 0 mismatch")); + EXPECT_THAT(result_status.status().error_message(), + HasSubstr("broadcast dimension 0 mismatch")); } XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { @@ -233,9 +505,8 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); - EXPECT_MATCH( - result_status.status().error_message(), - testing::ContainsRegex("binary op BINOP_ADD with incompatible shapes")); + EXPECT_THAT(result_status.status().error_message(), + HasSubstr("binary op BINOP_ADD with incompatible shapes")); } XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { @@ -247,9 +518,8 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); - EXPECT_MATCH( - result_status.status().error_message(), - testing::ContainsRegex("binary op BINOP_ADD with incompatible shapes")); + EXPECT_THAT(result_status.status().error_message(), + HasSubstr("binary op BINOP_ADD with incompatible shapes")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 1796a732e543b7f040adf6055e349d72cfcfad6e..16d4282466c8a6db5e3e34bfa9deb86fd339c27b 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -265,6 +265,37 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); } +TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { + auto builder = HloComputation::Builder(TestName()); + Array3D input_vals(2, 3, 4); + input_vals.FillRandom(1.0); + + Array4D expected(2, 3, 4, 5); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 4; ++k) { + for (int m = 0; m < 5; ++m) { + expected(i, j, k, m) = input_vals(i, j, k); + } + } + } + } + auto input = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR3FromArray3D(input_vals))); + + // Broadcast vector in dimensions 2 and 3. + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2})); + + // Create HLO module, compile, and execute. + auto hlo_module = MakeUnique(TestName()); + hlo_module->AddEntryComputation(builder.Build()); + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + + LiteralTestUtil::ExpectNear( + *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index 0b5e6d512771cc6aebfd92af81bfdfa56d176088..9b96173aaa01199bdaf18d4b56d9f118432b2655 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 675c9fccb007f5a0a16b50618e849d3740877403..1bb1a1d6b4e4ce79413642b542cec8dd64ecba86 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -22,15 +22,17 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { +using ::testing::ContainsRegex; + class CheckExecutionArityTest : public ClientLibraryTestBase {}; TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { @@ -60,15 +62,15 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { ASSERT_FALSE(result_one_arg.ok()); ASSERT_EQ(result_one_arg.status().code(), tensorflow::error::INVALID_ARGUMENT); - ASSERT_MATCH(result_one_arg.status().error_message(), - testing::ContainsRegex("takes 2")); + ASSERT_THAT(result_one_arg.status().error_message(), + ContainsRegex("takes 2")); auto result_zero_args = client_->Execute(computation, {}); ASSERT_FALSE(result_zero_args.ok()); ASSERT_EQ(result_zero_args.status().code(), tensorflow::error::INVALID_ARGUMENT); - ASSERT_MATCH(result_zero_args.status().error_message(), - testing::ContainsRegex("takes 2")); + ASSERT_THAT(result_zero_args.status().error_message(), + ContainsRegex("takes 2")); } XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { @@ -99,22 +101,22 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { status = client_->Execute(computation, {f32_4_data.get(), f32_4_data.get()}); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); - ASSERT_MATCH(status.status().error_message(), - testing::ContainsRegex("expects parameter 0")); + ASSERT_THAT(status.status().error_message(), + ContainsRegex("expects parameter 0")); // Shape mismatch in parameter 1 (rank) status = client_->Execute(computation, {f32_data.get(), f32_data.get()}); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); - ASSERT_MATCH(status.status().error_message(), - testing::ContainsRegex("expects parameter 1")); + ASSERT_THAT(status.status().error_message(), + ContainsRegex("expects parameter 1")); // Shape mismatch in parameter 1 (element type) status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()}); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); - ASSERT_MATCH(status.status().error_message(), - testing::ContainsRegex("expects parameter 1")); + ASSERT_THAT(status.status().error_message(), + ContainsRegex("expects parameter 1")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc index e6f3225bb79fca99f189d1e7ae7913715c5c2246..d5acea32ef700dc802dd7900b7ec8d454112f3e8 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.cc +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/path.h" @@ -43,12 +42,13 @@ void CodegenTestBase::CompileAndVerifyIr(std::unique_ptr hlo_module, std::unique_ptr CodegenTestBase::CompileToExecutable( std::unique_ptr hlo_module) { - auto module_config = MakeUnique( + HloModuleConfig module_config( hlo_module->entry_computation()->ComputeProgramShape()); - module_config->set_fast_math_disabled(fast_math_disabled_); + module_config.set_fast_math_disabled(fast_math_disabled_); + hlo_module->set_config(module_config); return backend_->compiler() - ->Compile(std::move(hlo_module), std::move(module_config), - test_hlo_dumper_, backend_->default_stream_executor()) + ->Compile(std::move(hlo_module), test_hlo_dumper_, + backend_->default_stream_executor()) .ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 709ce5029c82d52fe7a577d1e4cf7ea6ec07cecb..1d998fe33ebf71a2b35f99a51038e874edacc046 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -17,43 +17,81 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -class ComputeConstantTest : public ClientLibraryTestBase { +// An enumerator for the client types that we want to iterate over in +// the various tests. +enum class ClientType { kLocal, kCompileOnly }; +ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly}; + +class ComputeConstantTest : public ::testing::Test { public: + explicit ComputeConstantTest( + perftools::gputools::Platform* platform = nullptr, + tensorflow::gtl::ArraySlice disabled_pass_names = {}) + : platform_(platform) { + legacy_flags::HloPassPipelineFlags* flags = + legacy_flags::GetHloPassPipelineFlags(); + flags->xla_disable_hlo_passes = + tensorflow::str_util::Join(disabled_pass_names, ","); + } + + string TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } + + Client* ClientOrDie(::perftools::gputools::Platform* platform, + ClientType client_type) { + if (client_type == ClientType::kLocal) { + StatusOr result = + ClientLibrary::GetOrCreateLocalClient(platform); + TF_CHECK_OK(result.status()) + << "could not create LocalClient for testing"; + return result.ValueOrDie(); + } else if (client_type == ClientType::kCompileOnly) { + StatusOr result = + ClientLibrary::GetOrCreateCompileOnlyClient(platform); + TF_CHECK_OK(result.status()) + << "could not create CompileOnlyClient for testing"; + return result.ValueOrDie(); + } + LOG(FATAL) << "invalid client_type value"; + } + StatusOr> ComputeConstantLiteral( - ComputationDataHandle operand, ComputationBuilder* builder, - Layout* output_layout = nullptr) { + Client* client, const ComputationDataHandle& operand, + ComputationBuilder* builder, Layout* output_layout = nullptr) { TF_ASSIGN_OR_RETURN(auto remote_computed, builder->ComputeConstant(operand, output_layout)); - TF_ASSIGN_OR_RETURN(auto computed, client_->Transfer(*remote_computed)); + TF_ASSIGN_OR_RETURN(auto computed, client->Transfer(*remote_computed)); return std::move(computed); } template - StatusOr ComputeConstantScalar(ComputationDataHandle operand, + StatusOr ComputeConstantScalar(Client* client, + const ComputationDataHandle& operand, ComputationBuilder* builder) { - TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(operand, builder)); + TF_ASSIGN_OR_RETURN(auto literal, + ComputeConstantLiteral(client, operand, builder)); return LiteralUtil::Get(*literal, {}); } @@ -64,168 +102,193 @@ class ComputeConstantTest : public ClientLibraryTestBase { return result.ok() ? result.ValueOrDie() : false; } - template - void ExpectConstantComputedScalar(ComputationDataHandle operand, - Scalar expected, - ComputationBuilder* builder) { - Scalar computed = ComputeConstantScalar(operand, builder); - ASSERT_TRUE(computed.ok()) << computed.status(); - std::unique_ptr expected_literal = LiteralUtil::CreateR0(expected); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); - } + perftools::gputools::Platform* platform_; }; TEST_F(ComputeConstantTest, ScalarInt32Literal) { - ComputationBuilder b(client_, TestName()); - auto computation = b.ConstantR0(42); - EXPECT_TRUE(IsConstant(computation, &b)); - - auto value = ComputeConstantScalar(computation, &b); - ASSERT_TRUE(value.ok()) << value.status(); - EXPECT_EQ(value.ValueOrDie(), 42); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = b.ConstantR0(42); + EXPECT_TRUE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(client, computation, &b); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 42); + } } TEST_F(ComputeConstantTest, ScalarFloatAdd) { - ComputationBuilder b(client_, TestName()); - auto computation = - b.Add(b.ConstantR0(42.5f), b.ConstantR0(1.5f)); - EXPECT_TRUE(IsConstant(computation, &b)); - - auto value = ComputeConstantScalar(computation, &b); - ASSERT_TRUE(value.ok()) << value.status(); - EXPECT_EQ(value.ValueOrDie(), 44.0f); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = + b.Add(b.ConstantR0(42.5f), b.ConstantR0(1.5f)); + EXPECT_TRUE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(client, computation, &b); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 44.0f); + } } TEST_F(ComputeConstantTest, ScalarRng) { - ComputationBuilder b(client_, TestName()); - auto computation = - b.RngUniform(b.ConstantR0(1.1f), b.ConstantR0(2.1f), - ShapeUtil::MakeShape(F32, {})); - EXPECT_FALSE(IsConstant(computation, &b)); - - auto value = ComputeConstantScalar(computation, &b); - ASSERT_FALSE(value.ok()) - << "computing a RNG value should not be considered a constant"; + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = + b.RngUniform(b.ConstantR0(1.1f), b.ConstantR0(2.1f), + ShapeUtil::MakeShape(F32, {})); + EXPECT_FALSE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(client, computation, &b); + ASSERT_FALSE(value.ok()) + << "computing a RNG value should not be considered a constant"; + } } TEST_F(ComputeConstantTest, DirectParam) { - ComputationBuilder b(client_, TestName()); - auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); - EXPECT_FALSE(IsConstant(computation, &b)); - - auto value = ComputeConstantScalar(computation, &b); - EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on parameter")) - << value.status(); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); + EXPECT_FALSE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(client, computation, &b); + EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) + .contains("depends on parameter")) + << value.status(); + } } TEST_F(ComputeConstantTest, IndirectParam) { - ComputationBuilder b(client_, TestName()); - auto computation = - b.Add(b.ConstantR0(1.0f), - b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); - EXPECT_FALSE(IsConstant(computation, &b)); - - auto value = ComputeConstantScalar(computation, &b); - EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on parameter")) - << value.status(); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = + b.Add(b.ConstantR0(1.0f), + b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); + EXPECT_FALSE(IsConstant(computation, &b)); + + auto value = ComputeConstantScalar(client, computation, &b); + EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) + .contains("depends on parameter")) + << value.status(); + } } // Test computation of an expression interspersed with param nodes but // the expression does not depend on the param nodes. TEST_F(ComputeConstantTest, UnrelatedParam) { - ComputationBuilder b(client_, TestName()); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); - auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); - auto constant_4 = b.Add(b.ConstantR0(2.5f), b.ConstantR0(1.5f)); - auto not_constant_a = b.Add(constant_4, param_a); + auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); + auto constant_4 = + b.Add(b.ConstantR0(2.5f), b.ConstantR0(1.5f)); + auto not_constant_a = b.Add(constant_4, param_a); - auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1"); - auto constant_9 = b.Mul(b.ConstantR0(2.0f), b.ConstantR0(4.5f)); - auto not_constant_b = b.Add(param_b, constant_9); + auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1"); + auto constant_9 = + b.Mul(b.ConstantR0(2.0f), b.ConstantR0(4.5f)); + auto not_constant_b = b.Add(param_b, constant_9); - auto constant_13 = b.Add(constant_4, constant_9); - b.Add(not_constant_b, b.Add(constant_13, not_constant_a)); + auto constant_13 = b.Add(constant_4, constant_9); + b.Add(not_constant_b, b.Add(constant_13, not_constant_a)); - EXPECT_TRUE(IsConstant(constant_13, &b)); + EXPECT_TRUE(IsConstant(constant_13, &b)); - auto value = ComputeConstantScalar(constant_13, &b); - ASSERT_TRUE(value.ok()) << value.status(); - EXPECT_EQ(value.ValueOrDie(), 13.0f); + auto value = ComputeConstantScalar(client, constant_13, &b); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 13.0f); + } } TEST_F(ComputeConstantTest, NonScalarAdd) { - ComputationBuilder b(client_, TestName()); + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); - auto computation = - b.Add(b.ConstantR1({1, 2}), b.ConstantR1({3, 4})); - EXPECT_TRUE(IsConstant(computation, &b)); + auto computation = + b.Add(b.ConstantR1({1, 2}), b.ConstantR1({3, 4})); + EXPECT_TRUE(IsConstant(computation, &b)); - auto computed = ComputeConstantLiteral(computation, &b); - ASSERT_TRUE(computed.ok()) << computed.status(); - std::unique_ptr expected_literal = - LiteralUtil::CreateR1({4, 6}); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + auto computed = ComputeConstantLiteral(client, computation, &b); + ASSERT_TRUE(computed.ok()) << computed.status(); + std::unique_ptr expected_literal = + LiteralUtil::CreateR1({4, 6}); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + } } TEST_F(ComputeConstantTest, IntegerDivide) { - ComputationBuilder b(client_, TestName()); - auto computation = b.Div(b.ConstantR0(15), b.ConstantR0(3)); - EXPECT_TRUE(IsConstant(computation, &b)); - - auto computed = ComputeConstantLiteral(computation, &b); - ASSERT_TRUE(computed.ok()) << computed.status(); - std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); -} + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto computation = b.Div(b.ConstantR0(15), b.ConstantR0(3)); + EXPECT_TRUE(IsConstant(computation, &b)); -XLA_TEST_F(ComputeConstantTest, Layout) { - ComputationBuilder b(client_, TestName()); - - std::vector> layouts = {{0, 1}, {1, 0}}; - for (const std::vector& layout : layouts) { - auto layout_proto = LayoutUtil::MakeLayout(layout); - auto computed = - ComputeConstantLiteral(b.Add(b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})), - &b, &layout_proto); + auto computed = ComputeConstantLiteral(client, computation, &b); ASSERT_TRUE(computed.ok()) << computed.status(); - - std::unique_ptr expected_literal = - test_utils::CreateR2LiteralWithLayout({{11, 22}, {33, 44}}, - layout); - LiteralTestUtil::AssertEqualShapesAndLayouts( - expected_literal->shape(), computed.ValueOrDie()->shape()); + std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); } } +XLA_TEST_F(ComputeConstantTest, Layout) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + + std::vector> layouts = {{0, 1}, {1, 0}}; + for (const std::vector& layout : layouts) { + auto layout_proto = LayoutUtil::MakeLayout(layout); + auto computed = ComputeConstantLiteral( + client, + b.Add(b.ConstantR2({{1, 2}, {3, 4}}), + b.ConstantR2({{10, 20}, {30, 40}})), + &b, &layout_proto); + ASSERT_TRUE(computed.ok()) << computed.status(); + + std::unique_ptr expected_literal = + test_utils::CreateR2LiteralWithLayout({{11, 22}, {33, 44}}, + layout); + LiteralTestUtil::AssertEqualShapesAndLayouts( + expected_literal->shape(), computed.ValueOrDie()->shape()); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + } + } +} + // This test is permanently disabled on CPU because it requires that the // backend used for execution is different than the backend used for // ComputeConstant which is always cpu. TEST_F(ComputeConstantTest, DISABLED_ON_CPU(ReuseComputedConstant)) { // Compute a trivial constant, then try to use the value in an Execute // call. This should fail because the constant resides on the CPU and the - // Execute call is executed on a different backend. - ComputationBuilder constant_b(client_, TestName()); + // Execute call is executed on a different backend. This test only makes + // sense with LocalClient, since CompileOnlyClient does not support + // execution. + Client* client = ClientOrDie(platform_, ClientType::kLocal); + ComputationBuilder constant_b(client, TestName()); auto constant = constant_b.ConstantR0(42); auto handle = constant_b.ComputeConstant(constant).ConsumeValueOrDie(); - auto literal = client_->Transfer(*handle).ConsumeValueOrDie(); + auto literal = client->Transfer(*handle).ConsumeValueOrDie(); LiteralTestUtil::ExpectR0Equal(42, *literal); // Build trivial computation which takes one parameter. - ComputationBuilder b(client_, TestName()); + ComputationBuilder b(client, TestName()); b.Neg(b.Parameter(0, ShapeUtil::MakeShape(S32, {}), "param0")); auto computation = b.Build().ConsumeValueOrDie(); // Try to use value from ComputeConstant in Execute. - auto execute_status = client_->Execute(computation, {handle.get()}); + auto execute_status = client->Execute(computation, {handle.get()}); EXPECT_FALSE(execute_status.ok()); - EXPECT_MATCH( + EXPECT_THAT( execute_status.status().error_message(), - testing::ContainsRegex("argument 0 is on device Host:0 but computation " - "will be executed on device")); + ::testing::ContainsRegex("argument 0 is on device Host:0 but computation " + "will be executed on device")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 9a48b19b96aea829ded626ddb4ac64c0fa42b64c..63bfac441d3c1f7aa257a7f9fc81df98f47551d5 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -34,6 +35,7 @@ namespace xla { namespace { using ConcatTest = ClientLibraryTestBase; +using ::testing::HasSubstr; // Concatenate expects at least one argument. XLA_TEST_F(ConcatTest, Concat_Nothing) { @@ -41,9 +43,8 @@ XLA_TEST_F(ConcatTest, Concat_Nothing) { auto concatenated = builder.ConcatInDim({}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); - EXPECT_MATCH( - computation_status.status().ToString(), - testing::ContainsRegex("Concatenate expects at least one argument")); + EXPECT_THAT(computation_status.status().ToString(), + HasSubstr("Concatenate expects at least one argument")); } // Concatenate with one argument works. @@ -56,6 +57,15 @@ XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } +XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto concatenated = builder.ConcatInDim({a}, 0); + + std::vector expected = {}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + // Show that we can't concatenate R0 with R0 because we can't name the dimension // to concatenate on. XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { @@ -65,9 +75,8 @@ XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { auto concatenated = builder.ConcatInDim({a, b}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); - EXPECT_MATCH(computation_status.status().ToString(), - testing::ContainsRegex( - "dimension to concatenate along out of bounds: 0")); + EXPECT_THAT(computation_status.status().ToString(), + HasSubstr("dimension to concatenate along out of bounds: 0")); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { @@ -404,10 +413,9 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) { auto concatenated = builder.ConcatInDim({x, y}, 0); StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); - EXPECT_MATCH( + EXPECT_THAT( computation_status.status().ToString(), - testing::ContainsRegex( - "Expected non-opaque argument for operand of concatenation")); + HasSubstr("Expected non-opaque argument for operand of concatenation")); } XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 9f38dc4b365672733ed773043f77bc4a3e8405ef..4aff6dc7d57f635fbb8a14c2bdeb5581e00119c9 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -25,12 +25,11 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -43,8 +42,8 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { auto dimension_numbers_status = ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); - ASSERT_MATCH(dimension_numbers_status.status().error_message(), - testing::ContainsRegex("input are not unique")); + ASSERT_THAT(dimension_numbers_status.status().error_message(), + ::testing::HasSubstr("input are not unique")); } // Tests the convolution operation with invalid weight dimension numbers. @@ -52,8 +51,8 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { auto dimension_numbers_status = ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 2, 3, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); - ASSERT_MATCH(dimension_numbers_status.status().error_message(), - testing::ContainsRegex("weight are not unique")); + ASSERT_THAT(dimension_numbers_status.status().error_message(), + ::testing::HasSubstr("weight are not unique")); } XLA_TEST_F(ConvolutionDimensionNumbersTest, diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 29e29505333b64926cdd0b3e9fe7ef3407eaaec2..8ea97e67d640d97baa70cddf60f3336a8849552a 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -256,6 +257,22 @@ XLA_TEST_F(CopyOpTest, CopyConstantR4Layout0312_MultipleTilesPerLayer) { TestCopyConstantLayoutR4(2, 14, 5, 35, {0, 3, 1, 2}); } +using CopyOpClientTest = ClientLibraryTestBase; + +XLA_TEST_F(CopyOpClientTest, Copy0x0) { + Shape in_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {0, 1}); + Shape out_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {1, 0}); + auto empty = LiteralUtil::CreateFromShape(in_shape); + + ComputationBuilder builder(client_, TestName()); + auto param0 = builder.Parameter(0, in_shape, "input"); + auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie(); + + auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) + .ConsumeValueOrDie(); + LiteralTestUtil::ExpectEqual(*empty, *actual); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index dc54c9defec2394049c38781a8d02fc8bd05158a..f7dcf68c1b63a2efeb226965dd3a09963e876f2a 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -29,23 +29,22 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" -extern "C" void __attribute__((visibility("default"))) -R0F32Add2(float* out, float** in) { + +extern "C" void TF_EXPORT R0F32Add2(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*)); *out = **in + 2.0f; } -extern "C" void __attribute__((visibility("default"))) -R2F32ReduceSum(float* out, float** in) { +extern "C" void TF_EXPORT R2F32ReduceSum(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; *out = array[0] + array[1] + array[2] + array[3]; } -extern "C" void __attribute__((visibility("default"))) -Add1ToValues(float* out, float** in) { +extern "C" void TF_EXPORT Add1ToValues(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; out[0] = array[0] + 1; diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index 528efd2942b0ebbba16faba2a0543a2694cd5c2a..cc3c4a2a5e115d7791e8574f4ead17f77dcd5e7c 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -21,15 +21,17 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { +using ::testing::HasSubstr; + class DeallocationTest : public ClientLibraryTestBase { protected: // Build and execute the given computation then verify the results can be @@ -50,7 +52,7 @@ TEST_F(DeallocationTest, DeallocateScalar) { builder.ConstantR0(42.0); auto global_data = ExecuteAndCheckTransfer(&builder, {}); - // A result can be transfered an arbitrary number of times. Add an extra + // A result can be transferred an arbitrary number of times. Add an extra // transfer here so we're not just testing that a second call to Transfer // fails. ASSERT_IS_OK(client_->Transfer(*global_data).status()); @@ -59,8 +61,8 @@ TEST_F(DeallocationTest, DeallocateScalar) { auto transfer_status = client_->Transfer(*global_data); ASSERT_FALSE(transfer_status.ok()); - ASSERT_MATCH(transfer_status.status().error_message(), - testing::HasSubstr("was previously deallocated")); + ASSERT_THAT(transfer_status.status().error_message(), + HasSubstr("was previously deallocated")); } TEST_F(DeallocationTest, DeallocateVector) { @@ -72,8 +74,8 @@ TEST_F(DeallocationTest, DeallocateVector) { auto transfer_status = client_->Transfer(*global_data); ASSERT_FALSE(transfer_status.ok()); - ASSERT_MATCH(transfer_status.status().error_message(), - testing::HasSubstr("was previously deallocated")); + ASSERT_THAT(transfer_status.status().error_message(), + HasSubstr("was previously deallocated")); } TEST_F(DeallocationTest, DeallocateEmptyVector) { @@ -85,8 +87,8 @@ TEST_F(DeallocationTest, DeallocateEmptyVector) { auto transfer_status = client_->Transfer(*global_data); ASSERT_FALSE(transfer_status.ok()); - ASSERT_MATCH(transfer_status.status().error_message(), - testing::HasSubstr("was previously deallocated")); + ASSERT_THAT(transfer_status.status().error_message(), + HasSubstr("was previously deallocated")); } XLA_TEST_F(DeallocationTest, DeallocateTuple) { @@ -99,8 +101,8 @@ XLA_TEST_F(DeallocationTest, DeallocateTuple) { auto transfer_status = client_->Transfer(*global_data); ASSERT_FALSE(transfer_status.ok()); - ASSERT_MATCH(transfer_status.status().error_message(), - testing::HasSubstr("was previously deallocated")); + ASSERT_THAT(transfer_status.status().error_message(), + HasSubstr("was previously deallocated")); } XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { @@ -114,8 +116,8 @@ XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { auto transfer_status = client_->Transfer(*global_data); ASSERT_FALSE(transfer_status.ok()); - ASSERT_MATCH(transfer_status.status().error_message(), - testing::HasSubstr("was previously deallocated")); + ASSERT_THAT(transfer_status.status().error_message(), + HasSubstr("was previously deallocated")); } XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { @@ -130,8 +132,8 @@ XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { auto transfer_status = client_->Transfer(*global_data); ASSERT_FALSE(transfer_status.ok()); - ASSERT_MATCH(transfer_status.status().error_message(), - testing::HasSubstr("was previously deallocated")); + ASSERT_THAT(transfer_status.status().error_message(), + HasSubstr("was previously deallocated")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 57a7c61b141f3e8c5cf3ecc7e34043a79129c01b..60ce2b1b58c6a3b93d394b1fcd8066313ec30e9d 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -34,6 +35,10 @@ limitations under the License. namespace xla { namespace { +using ::testing::ContainsRegex; +using ::testing::ElementsAre; +using ::testing::HasSubstr; + class DeconstructTupleTest : public ClientLibraryTestBase { protected: // Build and execute the given computation then verify the results can be @@ -63,9 +68,9 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { auto handles = result_status.ConsumeValueOrDie(); std::vector copy(4); ASSERT_IS_OK(client_->TransferInProcess(*handles[0], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + EXPECT_THAT(copy, ElementsAre(1.0, 2.0, 3.0, 4.0)); ASSERT_IS_OK(client_->TransferInProcess(*handles[1], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); + EXPECT_THAT(copy, ElementsAre(2.0, 4.0, 6.0, 8.0)); } TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { @@ -85,16 +90,16 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { std::vector copy(4); ASSERT_IS_OK(client_->TransferInProcess(*handles1[0], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + EXPECT_THAT(copy, ElementsAre(1.0, 2.0, 3.0, 4.0)); ASSERT_IS_OK(client_->TransferInProcess(*handles1[1], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); + EXPECT_THAT(copy, ElementsAre(2.0, 4.0, 6.0, 8.0)); handles1[0].reset(); handles1[1].reset(); ASSERT_IS_OK(client_->TransferInProcess(*handles2[0], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + EXPECT_THAT(copy, ElementsAre(1.0, 2.0, 3.0, 4.0)); ASSERT_IS_OK(client_->TransferInProcess(*handles2[1], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); + EXPECT_THAT(copy, ElementsAre(2.0, 4.0, 6.0, 8.0)); } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { @@ -114,13 +119,13 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { std::vector copy(4); ASSERT_IS_OK(client_->TransferInProcess(*handles[0], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + EXPECT_THAT(copy, ElementsAre(1.0, 2.0, 3.0, 4.0)); ASSERT_IS_OK(client_->TransferInProcess(*handles[1], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); + EXPECT_THAT(copy, ElementsAre(2.0, 4.0, 6.0, 8.0)); ASSERT_IS_OK(client_->TransferInProcess(*handles[2], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); + EXPECT_THAT(copy, ElementsAre(2.0, 4.0, 6.0, 8.0)); ASSERT_IS_OK(client_->TransferInProcess(*handles[3], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + EXPECT_THAT(copy, ElementsAre(1.0, 2.0, 3.0, 4.0)); } TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { @@ -140,17 +145,17 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { std::vector copy(4); ASSERT_IS_OK(client_->TransferInProcess(*handles[0], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + EXPECT_THAT(copy, ElementsAre(1.0, 2.0, 3.0, 4.0)); ASSERT_IS_OK(client_->TransferInProcess(*handles[1], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({2.0, 4.0, 6.0, 8.0})); + EXPECT_THAT(copy, ElementsAre(2.0, 4.0, 6.0, 8.0)); ASSERT_IS_OK(client_->TransferInProcess(*handles[2], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + EXPECT_THAT(copy, ElementsAre(1.0, 2.0, 3.0, 4.0)); /// Try deallocating one of the repeated elements, then copy handles[0].reset(); ASSERT_IS_OK(client_->TransferInProcess(*handles[2], ©[0])); - EXPECT_MATCH(copy, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + EXPECT_THAT(copy, ElementsAre(1.0, 2.0, 3.0, 4.0)); } TEST_F(DeconstructTupleTest, DeconstructNonTuple) { @@ -160,8 +165,8 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { auto result_status = client_->DeconstructTuple(*global_data); EXPECT_FALSE(result_status.ok()); - EXPECT_MATCH(result_status.status().error_message(), - testing::ContainsRegex("global data handle .* is not a tuple")); + EXPECT_THAT(result_status.status().error_message(), + ContainsRegex("global data handle .* is not a tuple")); } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { @@ -189,9 +194,8 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { auto result_status = client_->DeconstructTuple(*global_data); EXPECT_FALSE(result_status.ok()); - EXPECT_MATCH( - result_status.status().error_message(), - testing::ContainsRegex("deconstructing nested tuples not yet supported")); + EXPECT_THAT(result_status.status().error_message(), + HasSubstr("deconstructing nested tuples not yet supported")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 1d1fb337075855372ae54ac3c7e9abf55a6c32f1..cdb4498f4ed1e4f7fb2ad7a29a1cec4e26b76ed3 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -108,7 +109,7 @@ class DynamicSliceTest : public ClientLibraryTestBase { template void RunR1(const std::vector& input_values, const std::vector slice_starts, - const std::vector slice_sizes, + const std::vector& slice_sizes, const std::vector& expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -126,7 +127,7 @@ class DynamicSliceTest : public ClientLibraryTestBase { template void RunR2(const Array2D& input_values, const std::vector slice_starts, - const std::vector slice_sizes, + const std::vector& slice_sizes, const Array2D& expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -144,7 +145,7 @@ class DynamicSliceTest : public ClientLibraryTestBase { template void RunR3(const Array3D& input_values, const std::vector slice_starts, - const std::vector slice_sizes, + const std::vector& slice_sizes, const Array3D& expected_values) { ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 62878fed5549a6720a782d01c292ff143187e9a4..ca15f7395da79d7c5c05c03b6fafdca9e6953955 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -94,10 +94,10 @@ StatusOr HloTestBase::Execute( << LayoutUtil::HumanString(module_config->entry_computation_layout() .result_layout() .layout()); + hlo_module->set_config(*module_config); TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - backend_->compiler()->Compile(std::move(hlo_module), - std::move(module_config), test_hlo_dumper_, + backend_->compiler()->Compile(std::move(hlo_module), test_hlo_dumper_, backend_->default_stream_executor())); se::Stream stream(backend_->default_stream_executor()); @@ -111,8 +111,9 @@ StatusOr HloTestBase::Execute( backend_->eigen_intra_op_thread_pool_device()); HloExecutionProfile hlo_execution_profile; - ServiceExecutableRunOptions service_run_options(run_options, - backend_->StreamBorrower()); + ServiceExecutableRunOptions service_run_options( + run_options, backend_->StreamBorrower(), + backend_->inter_op_thread_pool()); TF_ASSIGN_OR_RETURN( se::DeviceMemoryBase result, executable->ExecuteOnStream(&service_run_options, arguments, @@ -123,9 +124,7 @@ StatusOr HloTestBase::Execute( *result_shape = executable->result_shape(); - // TODO(b/36256956) Ideally tuple elements could always be distinct buffers. - if (ShapeUtil::IsTuple(*result_shape) && - backend_->transfer_manager()->TupleElementsAreDistinctBuffers()) { + if (ShapeUtil::IsTuple(*result_shape)) { // We must record element buffers of tuples as well to avoid leaks. DCHECK(!ShapeUtil::IsNestedTuple(*result_shape)); TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 6119473d8158fe87b3611a3edc3490058556288a..d94602ffda2ea6cba8d734b0d814ae5d0dbbd28d 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -65,7 +65,7 @@ class HloTestBase : public ::testing::Test { perftools::gputools::DeviceMemoryBase TransferToDevice( const Literal& literal); - // Transfers the array refered to by the given handle from the device and + // Transfers the array referred to by the given handle from the device and // returns as a Literal. std::unique_ptr TransferFromDevice( const Shape& shape, perftools::gputools::DeviceMemoryBase device_base); @@ -84,28 +84,6 @@ class HloTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice arguments); - // Helpers for comparing ordered and unordered equality of HloInstruction - // containers. - void ExpectEqOrdered( - tensorflow::gtl::ArraySlice actual, - tensorflow::gtl::ArraySlice expected) { - std::vector expected_vec(expected.begin(), - expected.end()); - std::vector actual_vec(actual.begin(), actual.end()); - EXPECT_TRUE(testing::VectorMatcher(expected_vec)( - actual_vec)); - } - - void ExpectEqUnordered( - tensorflow::gtl::ArraySlice actual, - tensorflow::gtl::ArraySlice expected) { - std::vector expected_vec(expected.begin(), - expected.end()); - std::vector actual_vec(actual.begin(), actual.end()); - EXPECT_TRUE(testing::UnorderedElementsAre( - expected_vec)(actual_vec)); - } - string TestName() const; std::unique_ptr backend_; diff --git a/tensorflow/compiler/xla/tests/inprocess_service_test.cc b/tensorflow/compiler/xla/tests/inprocess_service_test.cc index ea0be07872f31b8e3357d91a164ce8727a159f63..97adf7ad6c974a3e258ee42b04f6be2e04c04e9d 100644 --- a/tensorflow/compiler/xla/tests/inprocess_service_test.cc +++ b/tensorflow/compiler/xla/tests/inprocess_service_test.cc @@ -26,13 +26,13 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -74,7 +74,7 @@ XLA_TEST_F(InProcessServiceTest, TransferFromServer) { std::vector result(3, 0); ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data())); - EXPECT_MATCH(result, testing::VectorMatcher({1, 42, 5})); + EXPECT_THAT(result, ::testing::ElementsAre(1, 42, 5)); } XLA_TEST_F(InProcessServiceTest, TransferToServer) { @@ -148,7 +148,7 @@ XLA_TEST_F(InProcessServiceTest, ExecuteRowMajor) { Shape shape; ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data())); - EXPECT_MATCH(result, testing::VectorMatcher({1.0, 2.0, 3.0, 4.0})); + EXPECT_THAT(result, ::testing::ElementsAre(1.0, 2.0, 3.0, 4.0)); } XLA_TEST_F(InProcessServiceTest, ExecuteColumnMajor) { @@ -159,7 +159,7 @@ XLA_TEST_F(InProcessServiceTest, ExecuteColumnMajor) { Shape shape; ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data())); - EXPECT_MATCH(result, testing::VectorMatcher({1.0, 3.0, 2.0, 4.0})); + EXPECT_THAT(result, ::testing::ElementsAre(1.0, 3.0, 2.0, 4.0)); } XLA_TEST_F(InProcessServiceTest, ExecuteAndReuseDifferentLayouts) { diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index f7bbc0f38bb501e042542cf7f0a3d4fadb3a2a23..23453db57bc4a5db0d3a4f7c327e3313333d1ae2 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/io/path.h" @@ -76,11 +76,11 @@ string Hostname() { // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT // -- on miscompare, a nice error message is given in the AssertionFailure. template -testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { +::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { auto ulhs = tensorflow::bit_cast(lhs); auto urhs = tensorflow::bit_cast(rhs); if (ulhs != urhs) { - return testing::AssertionFailure() << tensorflow::strings::Printf( + return ::testing::AssertionFailure() << tensorflow::strings::Printf( "floating values are not bitwise-equal; and equality testing " "was requested: %s=%g=%a vs %s=%g=%a", tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs)) @@ -90,33 +90,33 @@ testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { .c_str(), rhs, rhs); } - return testing::AssertionSuccess(); + return ::testing::AssertionSuccess(); } // Templated comparator that specializes for float equality comparison with the // bitwise helper above (this is the un-specialized fallback, to just use the // default gunit implementation). template -testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) { +::testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) { if (lhs == rhs) { - return testing::AssertionSuccess(); + return ::testing::AssertionSuccess(); } ::testing::Message msg; msg << "Expected equality of these values:"; msg << "\n " << lhs; msg << "\n " << rhs; - return testing::AssertionFailure() << msg; + return ::testing::AssertionFailure() << msg; } // Specializations for floating types that do bitwise comparisons when equality // comparison is requested. template <> -testing::AssertionResult CompareEqual(float lhs, float rhs) { +::testing::AssertionResult CompareEqual(float lhs, float rhs) { return CompareFloatsBitwiseEqual(lhs, rhs); } template <> -testing::AssertionResult CompareEqual(double lhs, double rhs) { +::testing::AssertionResult CompareEqual(double lhs, double rhs) { return CompareFloatsBitwiseEqual(lhs, rhs); } @@ -130,7 +130,7 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, if (dimension == expected.shape().dimensions_size()) { NativeT expected_value = LiteralUtil::Get(expected, multi_index); NativeT actual_value = LiteralUtil::Get(actual, multi_index); - testing::AssertionResult result = + ::testing::AssertionResult result = CompareEqual(expected_value, actual_value); return result; // Defines implicit coersion to bool. } @@ -159,7 +159,7 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, EXPECT_FALSE(Equal(expected, actual)); } -/* static */ testing::AssertionResult LiteralTestUtil::Equal( +/* static */ ::testing::AssertionResult LiteralTestUtil::Equal( const Literal& expected, const Literal& actual) { VLOG(1) << "expected: " << LiteralUtil::ToString(expected); VLOG(1) << "actual: " << LiteralUtil::ToString(actual); @@ -207,9 +207,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " << PrimitiveType_Name(expected.shape().element_type()); } - testing::AssertionResult result = testing::AssertionSuccess(); + ::testing::AssertionResult result = ::testing::AssertionSuccess(); if (!match) { - result = testing::AssertionFailure() + result = ::testing::AssertionFailure() << "expected: " << LiteralUtil::ToString(expected) << "\nactual: " << LiteralUtil::ToString(actual); VLOG(1) << result.message(); @@ -314,7 +314,7 @@ class NearComparator { private: // EXPECTs that the two given scalar values are within the error bound. Keeps - // track of how many mismatches have occured to keep the size of the output + // track of how many mismatches have occurred to keep the size of the output // manageable. template bool ExpectValuesNear(NativeT expected, NativeT actual) { @@ -421,12 +421,12 @@ class NearComparator { } // namespace -/* static */ testing::AssertionResult LiteralTestUtil::Near( +/* static */ ::testing::AssertionResult LiteralTestUtil::Near( const Literal& expected, const Literal& actual, const ErrorSpec& error) { NearComparator comparator(error); return comparator.ExpectNear(expected, actual) - ? testing::AssertionSuccess() - : testing::AssertionFailure() << "values were not near"; + ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure() << "values were not near"; } /* static */ void LiteralTestUtil::ExpectNear(const Literal& expected, @@ -435,14 +435,14 @@ class NearComparator { EXPECT_TRUE(Near(expected, actual, error)); } -/* static */ testing::AssertionResult LiteralTestUtil::NearTuple( +/* static */ ::testing::AssertionResult LiteralTestUtil::NearTuple( const Literal& expected, const Literal& actual, const ErrorSpec& error) { VLOG(1) << "expected: " << LiteralUtil::ToString(expected); VLOG(1) << "actual: " << LiteralUtil::ToString(actual); if (!ShapeUtil::IsTuple(expected.shape()) || !ShapeUtil::IsTuple(actual.shape())) { - return testing::AssertionFailure() + return ::testing::AssertionFailure() << "tuples expected expected shape = " << expected.shape().ShortDebugString() << " actual shape = " << actual.shape().ShortDebugString(); @@ -469,7 +469,7 @@ class NearComparator { } } - return testing::AssertionSuccess(); + return ::testing::AssertionSuccess(); } /* static */ void LiteralTestUtil::ExpectNearTuple(const Literal& expected, diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 85656a53e4400f2b0522e20a7b46922016432103..4f98083033310baf6ec95de0d2331d1aff8f3f7d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -18,12 +18,14 @@ limitations under the License. #include #include +#include #include #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -57,7 +59,7 @@ class LiteralTestUtil { // Asserts that the expected and actual literals are (bitwise) equal for all // elements in the literal. Also, asserts that the rank, dimensions sizes, and // primitive type are equal. - static testing::AssertionResult Equal( + static ::testing::AssertionResult Equal( const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; // Expects that expected and actual are Equal. @@ -101,7 +103,7 @@ class LiteralTestUtil { // Asserts that the expected and actual literals are within the given error // bound for all elements. Also, asserts that the rank, dimensions sizes, and // bounds are equivalent. Only supported for floating point values. - static testing::AssertionResult Near( + static ::testing::AssertionResult Near( const Literal& expected, const Literal& actual, const ErrorSpec& error) TF_MUST_USE_RESULT; @@ -147,7 +149,7 @@ class LiteralTestUtil { // tuples are within the given error bound. Tuples are matched recursively. // If the elements of the tuple are not floating-point types, the error spec // is ignored and exact equality is checked. - static testing::AssertionResult NearTuple( + static ::testing::AssertionResult NearTuple( const Literal& expected, const Literal& actual, const ErrorSpec& error) TF_MUST_USE_RESULT; @@ -170,6 +172,36 @@ class LiteralTestUtil { tensorflow::gtl::ArraySlice minor_to_major, const Literal& literal); + // Creates a literal with the supplied shape, and uses the provided value + // generator to populate the literal's values. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, + const std::function)>& generator); + + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation, and using the engine as entropy generator. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, typename E, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev); + + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, T mean, T stddev); + private: TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); }; @@ -269,6 +301,40 @@ template ExpectNear(*LiteralUtil::CreateR4FromArray4D(expected), actual, error); } +template +/* static */ StatusOr> +LiteralTestUtil::CreateRandomLiteral( + const Shape& shape, + const std::function)>& generator) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + TF_RET_CHECK(shape.element_type() == type); + std::unique_ptr literal = LiteralUtil::CreateFromShape(shape); + TF_RETURN_IF_ERROR(LiteralUtil::Populate( + literal.get(), [&](tensorflow::gtl::ArraySlice indexes) { + return generator(indexes); + })); + return std::move(literal); +} + +template +/* static */ StatusOr> +LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, + T stddev) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + std::normal_distribution generator(mean, stddev); + return CreateRandomLiteral( + shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { + return generator(*engine); + }); +} + +template +/* static */ StatusOr> +LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { + std::minstd_rand0 engine; + return CreateRandomLiteral(shape, &engine, mean, stddev); +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_ diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 7ea83a9e956ca8b5bb26ea6aaa844d2b63107328..52816dc72cc4d094054b2aea72f0cc63c7ff478d 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -42,7 +42,7 @@ xla::Computation Doubler(xla::Client* client) { int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); - auto client = xla::ClientLibrary::LocalClientOrDie(); + auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie(); xla::ComputationBuilder builder(client, "aot_test_helper"); auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); @@ -74,7 +74,7 @@ int main(int argc, char** argv) { llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string)); xla::Computation computation = builder.Build().ConsumeValueOrDie(); - xla::LocalClient::AheadOfTimeComputationInstance instance{ + xla::CompileOnlyClient::AotComputationInstance instance{ &computation, /*argument_layouts=*/{&opaque_shape}, &r0f32}; xla::cpu::CpuAotCompilationOptions options( diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 7fe4c9020f4c67ecc9888425cf0a2c358ad49e6d..7fcf687655a98d3ee972f8d3b784be655410a313 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -17,12 +17,19 @@ limitations under the License. #include +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -91,16 +98,34 @@ int64 TestAllocator::deallocation_count(int device_ordinal) const { return allocator_; } +// Define this in .cc file to avoid having to include eigen or forward declare +// these types in the header. +struct LocalClientTestBase::EigenThreadPoolWrapper { + explicit EigenThreadPoolWrapper() + : pool(new tensorflow::thread::ThreadPool( + tensorflow::Env::Default(), "XLAEigenTest", /*num_threads=*/2)), + wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())), + device(new Eigen::ThreadPoolDevice(wrapper.get(), + wrapper->NumThreads())) {} + + std::unique_ptr pool; + std::unique_ptr wrapper; + std::unique_ptr device; +}; + LocalClientTestBase::LocalClientTestBase( perftools::gputools::Platform* platform) : local_client_( - ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()) { + ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()), + thread_pool_wrapper_(new EigenThreadPoolWrapper()) { stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform()) .ValueOrDie()[local_client_->default_device_ordinal()]; transfer_manager_ = TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie(); } +LocalClientTestBase::~LocalClientTestBase() {} + std::unique_ptr LocalClientTestBase::LiteralToScopedShapedBuffer(const Literal& literal) { return LiteralToScopedShapedBuffer(literal, @@ -190,8 +215,7 @@ ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const { ExecutableRunOptions run_options; run_options.set_inter_op_thread_pool( local_client_->backend().inter_op_thread_pool()); - run_options.set_intra_op_thread_pool( - local_client_->backend().eigen_intra_op_thread_pool_device()); + run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get()); run_options.set_allocator(GetOrCreateAllocator(local_client_->platform())); return run_options; } diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 4e7b05cea60887eec628ce9b4848321e721030e5..e3c3bb46cf26cc742b7abb39a3e457d823d829ec 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -74,8 +74,10 @@ class TestAllocator : public StreamExecutorMemoryAllocator { // A base class for tests which exercise the LocalClient interface. class LocalClientTestBase : public ::testing::Test { protected: + struct EigenThreadPoolWrapper; explicit LocalClientTestBase( perftools::gputools::Platform* platform = nullptr); + virtual ~LocalClientTestBase(); static TestAllocator* GetOrCreateAllocator( perftools::gputools::Platform* platform); @@ -142,6 +144,8 @@ class LocalClientTestBase : public ::testing::Test { TransferManager* transfer_manager_; LocalClient* local_client_; + + std::unique_ptr thread_pool_wrapper_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 2433c5653a6562b9672eeff81192dfc3152dffed..3cfa89e2e7d8d145932f0ceca0df349da3695f38 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -529,9 +529,9 @@ TEST_F(MapTest, MapOperantionWithBuildError) { StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); - EXPECT_MATCH(computation_status.status().ToString(), - testing::HasSubstr("error from: ErrorAdd: binary op with " - "different element types: f32[] and u16[]")); + EXPECT_THAT(computation_status.status().ToString(), + ::testing::HasSubstr("error from: ErrorAdd: binary op with " + "different element types: f32[] and u16[]")); } // MapTest disables inline and algsimp. MapTestWithFullOpt runs all diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 0cd0f97b0621d771ae039f0be6bd6c67161b49a4..a0f98fcfef3b73f8ffff67ef679041197f78f5ba 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -55,7 +56,7 @@ void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims) { SetSeed(42); auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); - EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions())); + EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); LiteralUtil::EachCell(*actual, [=](tensorflow::gtl::ArraySlice, T value) { EXPECT_LE(a, value); @@ -75,7 +76,7 @@ void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { auto actual, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, &execution_options)); - EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions())); + EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); int32 sum = 0; LiteralUtil::EachCell( *actual, [&sum](tensorflow::gtl::ArraySlice, uint32 value) { @@ -193,7 +194,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) { } } -// This tests demonstrates the global seeding behaviour. +// This tests demonstrates the global seeding behavior. // * If a seed is passed in via Execute (ExecuteAndTransfer) then the output is // fixed (i.e., there is a single output for a given seed); // * If no seed is passed in then the output of every call can be different; diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 34fce21758b98c52831ac4ddb168d3e1538e9f1d..feb2b465fca6b1ffda190025568470e8daf297a3 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -61,7 +61,7 @@ namespace { class ReduceTest : public ClientLibraryTestBase { protected: ReduceTest() { - // Implementation note: layed out z >> y >> x by default. + // Implementation note: laid out z >> y >> x by default. // clang-format off literal_2d_ = LiteralUtil::CreateR2({ // x0 x1 x2 @@ -211,9 +211,9 @@ XLA_TEST_F(ReduceTest, ReduceR1_0_F32_To_R0) { RunR1ToR0Test(0); } XLA_TEST_F(ReduceTest, ReduceR1_1_F32_To_R0) { RunR1ToR0Test(1); } XLA_TEST_F(ReduceTest, ReduceR1_2_F32_To_R0) { RunR1ToR0Test(2); } XLA_TEST_F(ReduceTest, ReduceR1_16_F32_To_R0) { RunR1ToR0Test(16); } -XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) { RunR1ToR0Test(240); } XLA_TEST_F(ReduceTest, ReduceR1_128_F32_To_R0) { RunR1ToR0Test(128); } XLA_TEST_F(ReduceTest, ReduceR1_129_F32_To_R0) { RunR1ToR0Test(129); } +XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) { RunR1ToR0Test(240); } XLA_TEST_F(ReduceTest, ReduceR1_256_F32_To_R0) { RunR1ToR0Test(256); } XLA_TEST_F(ReduceTest, ReduceR1_1024_F32_To_R0) { RunR1ToR0Test(1024); } XLA_TEST_F(ReduceTest, ReduceR1_2048_F32_To_R0) { RunR1ToR0Test(2048); } @@ -221,6 +221,9 @@ XLA_TEST_F(ReduceTest, ReduceR1_16K_F32_To_R0) { RunR1ToR0Test(16 * 1024); } XLA_TEST_F(ReduceTest, ReduceR1_16KP1_F32_To_R0) { RunR1ToR0Test(16 * 1024 + 1); } +XLA_TEST_F(ReduceTest, ReduceR1_64K_F32_To_R0) { RunR1ToR0Test(64 * 1024); } +XLA_TEST_F(ReduceTest, ReduceR1_1M_F32_To_R0) { RunR1ToR0Test(1024 * 1024); } +XLA_TEST_F(ReduceTest, ReduceR1_16M_F32_To_R0) { RunR1ToR0Test(4096 * 4096); } XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R0) { RunR2ToR0Test(0, 0); } XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R0) { RunR2ToR0Test(0, 2); } diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 56501e43b5c5d965ea4305f2ca88909b253ed273..c3b768579a401706eff4a2a24d840da84080d26d 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -43,7 +43,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { public: ReduceWindowTest() : builder_(client_, TestName()) {} - void ReduceWindowAdd(ComputationDataHandle input, + void ReduceWindowAdd(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { @@ -52,7 +52,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { window_dimensions, window_strides, padding); } - void ReduceWindowMax(ComputationDataHandle input, + void ReduceWindowMax(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { @@ -61,7 +61,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { CreateScalarMax(), window_dimensions, window_strides, padding); } - void ReduceWindowMin(ComputationDataHandle input, + void ReduceWindowMin(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { @@ -182,6 +182,7 @@ TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); } + // TODO(b/31809540): Implement minor dim reduction to reduce num of reshapes. TEST_F(ReduceWindowTest, ReduceR4AmongXYMinorSmall) { Array4D input_array(2, 2, 4, 16); @@ -368,6 +369,16 @@ TEST_F(ReduceWindowTest, Add2x2In2x2Disjoint) { ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(0.0001)); } +TEST_F(ReduceWindowTest, Add1x2In2x2Same) { + Array2D input_array({{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto input = builder_.ConstantR2FromArray2D(input_array); + ReduceWindowAdd(input, {1, 2}, {1, 1}, Padding::kSame); + Array2D expected({ + {3.0f, 2.0f}, {7.0f, 4.0f}, + }); + ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(0.0001)); +} + XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x2) { Array3D input_array(2, 1, 2); input_array(0, 0, 0) = 1000; diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 18e6e2d3f1d6aedb68f83b8058517398760c39ba..c5f20b9ca1db1812f52a4d6f568ff9093016a90b 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -31,13 +31,12 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -68,6 +67,22 @@ XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { ComputeAndCompareR0(&builder, 1.0f, {}, zero_error_spec_); } +XLA_TEST_F(ReshapeTest, ScalarToSingleElementArray) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); + a = builder.Neg(a); + auto reshape = + builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); + + ComputeAndCompareR1(&builder, {-1.0f}, {param0_data.get()}, + zero_error_spec_); +} + XLA_TEST_F(ReshapeTest, Trivial0x3) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2FromArray2D(Array2D(0, 3)); @@ -76,6 +91,24 @@ XLA_TEST_F(ReshapeTest, Trivial0x3) { ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-05-15 +// with an incorrect result rank. +XLA_TEST_F(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr param0_literal = + LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0, 3}), "param0"); + auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); + + ComputeAndCompareR1(&builder, {}, {param0_data.get()}, + zero_error_spec_); +} + XLA_TEST_F(ReshapeTest, Trivial3x0) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2FromArray2D(Array2D(3, 0)); @@ -383,15 +416,15 @@ XLA_TEST_F(ReshapeTest, ToScalar) { XLA_TEST_F(ReshapeTest, BadDimensions) { ComputationBuilder b(client_, TestName()); b.Reshape(b.ConstantR1({1}), {}, {}); - EXPECT_MATCH(ExecuteToString(&b, {}), - testing::HasSubstr("dimensions not a permutation")); + EXPECT_THAT(ExecuteToString(&b, {}), + ::testing::HasSubstr("dimensions not a permutation")); } XLA_TEST_F(ReshapeTest, BadNewSizes) { ComputationBuilder b(client_, TestName()); b.Reshape(b.ConstantR1({1, 2}), {1}, {}); - EXPECT_MATCH(ExecuteToString(&b, {}), - testing::HasSubstr("mismatched element counts")); + EXPECT_THAT(ExecuteToString(&b, {}), + ::testing::HasSubstr("mismatched element counts")); } XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 134eb91a1fedf8624363c273813fe2145f64aab7..ceee24c307ed0a71200ba6b17b17a90ab009cd2d 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -245,37 +246,183 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) { ComputeAndCompareR0(&builder, 2.5f, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsS32) { - ComputationBuilder builder(client_, TestName()); - builder.Div(builder.ConstantR0(-5), builder.ConstantR0(2)); +struct DivS32Params { + int32 dividend; + int32 divisor; + int32 quotient; + int32 remainder; +}; - ComputeAndCompareR0(&builder, -2, {}); -} +void PrintTo(const DivS32Params& p, std::ostream* os) { + *os << "{" << p.dividend << ", " << p.divisor << ", " << p.quotient << ", " + << p.remainder << "}"; +} + +class DivS32Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { + DivS32Params p = GetParam(); + ComputationBuilder builder(client_, TestName()); + builder.Div(builder.ConstantR0(p.dividend), + builder.ConstantR0(p.divisor)); + + ComputeAndCompareR0(&builder, p.quotient, {}); +} + +XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) { + DivS32Params p = GetParam(); + ComputationBuilder builder(client_, TestName()); + builder.Rem(builder.ConstantR0(p.dividend), + builder.ConstantR0(p.divisor)); + + ComputeAndCompareR0(&builder, p.remainder, {}); +} + +XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) { + DivS32Params p = GetParam(); + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle dividend; + ComputationDataHandle divisor; + auto dividendd = + CreateR0Parameter(p.dividend, 0, "dividend", &builder, ÷nd); + auto divisord = + CreateR0Parameter(p.divisor, 1, "divisor", &builder, &divisor); + builder.Div(dividend, divisor); + + ComputeAndCompareR0(&builder, p.quotient, + {dividendd.get(), divisord.get()}); +} + +XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) { + DivS32Params p = GetParam(); + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle dividend; + ComputationDataHandle divisor; + auto dividendd = + CreateR0Parameter(p.dividend, 0, "dividend", &builder, ÷nd); + auto divisord = + CreateR0Parameter(p.divisor, 1, "divisor", &builder, &divisor); + builder.Rem(dividend, divisor); + + ComputeAndCompareR0(&builder, p.remainder, + {dividendd.get(), divisord.get()}); +} + +INSTANTIATE_TEST_CASE_P( + DivS32Test_Instantiation, DivS32Test, + ::testing::Values( + // Positive divisors. + DivS32Params{5, 2, 2, 1}, // + DivS32Params{-5, 2, -2, -1}, // + DivS32Params{17, 3, 5, 2}, // + DivS32Params{-17, 3, -5, -2}, // + // Negative divisors. + DivS32Params{5, -2, -2, 1}, // + DivS32Params{-5, -2, 2, -1}, // + DivS32Params{17, -3, -5, 2}, // + DivS32Params{-17, -3, 5, -2}, // + // Large positive divisors. + DivS32Params{INT32_MIN, 7919, -271181, -1309}, // + DivS32Params{INT32_MIN, INT32_MAX, -1, -1}, // + DivS32Params{INT32_MIN + 1, INT32_MAX, -1, 0}, // + DivS32Params{INT32_MIN + 2, INT32_MAX, 0, INT32_MIN + 2}, // + DivS32Params{INT32_MIN, 0x40000000, -2, 0}, // + DivS32Params{INT32_MIN + 1, 0x40000000, -1, -0x3fffffff}, // + // Large negative divisors. + DivS32Params{INT32_MIN, INT32_MIN, 1, 0}, // + DivS32Params{INT32_MIN, INT32_MIN + 1, 1, -1}, // + DivS32Params{INT32_MIN + 1, INT32_MIN, 0, INT32_MIN + 1}, // + DivS32Params{INT32_MAX, INT32_MIN, 0, INT32_MAX}, // + DivS32Params{INT32_MAX, INT32_MIN + 1, -1, 0}, // + DivS32Params{INT32_MIN, -0x40000000, 2, 0}, // + DivS32Params{INT32_MIN + 1, -0x40000000, 1, -0x3fffffff})); + +TEST_F(ScalarComputationsTest, DivU32s) { + // clang-format off + // Some interesting values to test. + std::vector vals = { + 0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX}; + // clang-format on + + Computation div_computation; + { + ComputationBuilder builder(client_, TestName()); -TEST_F(ScalarComputationsTest, RemainderTwoScalarsNegativeResultS32) { - ComputationBuilder builder(client_, TestName()); - builder.Rem(builder.ConstantR0(-5), builder.ConstantR0(2)); + ComputationDataHandle dividend = + builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); + ComputationDataHandle divisor = + builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); + builder.Div(dividend, divisor); + TF_ASSIGN_OR_ASSERT_OK(div_computation, builder.Build()); + } - ComputeAndCompareR0(&builder, -1, {}); + for (uint32 divisor : vals) { + if (divisor != 0) { + for (uint32 dividend : vals) { + auto dividend_literal = LiteralUtil::CreateR0(dividend); + auto divisor_literal = LiteralUtil::CreateR0(divisor); + TF_ASSIGN_OR_ASSERT_OK(auto dividend_data, + client_->TransferToServer(*dividend_literal)); + TF_ASSIGN_OR_ASSERT_OK(auto divisor_data, + client_->TransferToServer(*divisor_literal)); + auto actual_literal = + client_ + ->ExecuteAndTransfer(div_computation, + {dividend_data.get(), divisor_data.get()}, + &execution_options_) + .ConsumeValueOrDie(); + auto expected_literal = + LiteralUtil::CreateR0(dividend / divisor); + LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + } + } + } } -TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinS32) { - ComputationBuilder builder(client_, TestName()); - builder.Rem(builder.ConstantR0(INT_MIN), - builder.ConstantR0(7919)); +TEST_F(ScalarComputationsTest, RemU32s) { + // clang-format off + // Some interesting values to test. + std::vector vals = { + 0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX}; + // clang-format on - ComputeAndCompareR0(&builder, -1309, {}); -} + Computation rem_computation; + { + ComputationBuilder builder(client_, TestName()); -TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinVsIntMaxS32) { - ComputationBuilder builder(client_, TestName()); - builder.Rem(builder.ConstantR0(INT_MIN), - builder.ConstantR0(INT_MAX)); + ComputationDataHandle dividend = + builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); + ComputationDataHandle divisor = + builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); + builder.Rem(dividend, divisor); + TF_ASSIGN_OR_ASSERT_OK(rem_computation, builder.Build()); + } - ComputeAndCompareR0(&builder, -1, {}); + for (uint32 divisor : vals) { + if (divisor != 0) { + for (uint32 dividend : vals) { + auto dividend_literal = LiteralUtil::CreateR0(dividend); + auto divisor_literal = LiteralUtil::CreateR0(divisor); + TF_ASSIGN_OR_ASSERT_OK(auto dividend_data, + client_->TransferToServer(*dividend_literal)); + TF_ASSIGN_OR_ASSERT_OK(auto divisor_data, + client_->TransferToServer(*divisor_literal)); + auto actual_literal = + client_ + ->ExecuteAndTransfer(rem_computation, + {dividend_data.get(), divisor_data.get()}, + &execution_options_) + .ConsumeValueOrDie(); + auto expected_literal = + LiteralUtil::CreateR0(dividend % divisor); + LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + } + } + } } -TEST_F(ScalarComputationsTest, RemainderTwoScalarsPositiveResultS32) { +TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { ComputationBuilder builder(client_, TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); builder.Rem(x, builder.ConstantR0(80000)); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 4cff1990865bcf1214a403a6241accbf82f06d00..5a2333e3386acbca43e3311cb6a316e298af9612 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -247,6 +248,230 @@ TEST_F(WhileTest, WhileWithTupleResult) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } +// Tests two while nodes when the result type T is a Tuple and the second +// while node uses the result of the first while node which is used in two +// nodes. +// tuple> w0(0, vector(10, 0.0f)); +// w0 = while (get<0>(w0) < c1) { +// get<0>(w0) = get<0>(w0) + 1; +// get<1>(w0) = get<1>(w0) + vector(10, 1.0f); +// } +// tuple> w1(get<0>(w0), get<1>(w0)); +// w1 = while (get<0>(w1) < c2) { +// get<0>(w1) = get<0>(w1) + 1; +// get<1>(w1) = get<1>(w1) + vector(10, 1.0f); +// } +// result = get<1>(w0) + get<1>(w1) +TEST_F(WhileTest, TwoWhileWithTupleResult) { + std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(F32, {10})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + const int c1 = 5; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Lt(iteration, builder.ConstantR0(c1)); + TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build()); + } + + Computation condition2; + const int c2 = 7; + { + ComputationBuilder builder(client_, "condition2"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Lt(iteration, builder.ConstantR0(c2)); + TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build()); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and add a constant vector of 1.0f to + // the weight variable, both of which are tuple elements. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto weights = builder.GetTupleElement(prev, 1); + auto input = builder.ConstantR1(10, 1.f); + auto new_weights = builder.Add(weights, input); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + TF_ASSIGN_OR_ASSERT_OK(body, builder.Build()); + } + + Computation body2; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto weights = builder.GetTupleElement(prev, 1); + auto input = builder.ConstantR1(10, 1.f); + auto new_weights = builder.Add(weights, input); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + TF_ASSIGN_OR_ASSERT_OK(body2, builder.Build()); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); + auto while1 = builder.While(condition, body, init); + + auto while2 = builder.While(condition2, body2, while1); + + auto while_result1 = builder.GetTupleElement(while1, 1); + auto while_result2 = builder.GetTupleElement(while2, 1); + VLOG(2) << "while_result2 = " + << ShapeUtil::HumanString( + *builder.GetShape(while_result2).ConsumeValueOrDie()); + auto result = builder.Add(while_result1, while_result2); + VLOG(2) << "result = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + const float sum = c1 + c2; + std::vector expected(10, sum); + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +// Test while nodes that share the while body computation. +TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { + std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(F32, {10})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + const int c1 = 5; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Lt(iteration, builder.ConstantR0(c1)); + TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build()); + } + + Computation condition2; + const int c2 = 7; + { + ComputationBuilder builder(client_, "condition2"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Lt(iteration, builder.ConstantR0(c2)); + TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build()); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and add a constant vector of 1.0f to + // the weight variable, both of which are tuple elements. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto weights = builder.GetTupleElement(prev, 1); + auto input = builder.ConstantR1(10, 1.f); + auto new_weights = builder.Add(weights, input); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + TF_ASSIGN_OR_ASSERT_OK(body, builder.Build()); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); + auto while1 = builder.While(condition, body, init); + + auto while2 = builder.While(condition2, body, while1); + + auto while_result1 = builder.GetTupleElement(while1, 1); + auto while_result2 = builder.GetTupleElement(while2, 1); + VLOG(2) << "while_result2 = " + << ShapeUtil::HumanString( + *builder.GetShape(while_result2).ConsumeValueOrDie()); + auto result = builder.Add(while_result1, while_result2); + VLOG(2) << "result = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + const float sum = c1 + c2; + std::vector expected(10, sum); + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + +// Test while nodes that share the while body computation. +// TODO(b/37245345): Fails on GPU backend. +TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { + std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(F32, {10})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + const int c1 = 5; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Lt(iteration, builder.ConstantR0(c1)); + TF_ASSIGN_OR_ASSERT_OK(condition, builder.Build()); + } + + Computation condition2; + const int c2 = 7; + { + ComputationBuilder builder(client_, "condition2"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Lt(iteration, builder.ConstantR0(c2)); + TF_ASSIGN_OR_ASSERT_OK(condition2, builder.Build()); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and add a constant vector of 1.0f to + // the weight variable, both of which are tuple elements. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto weights = builder.GetTupleElement(prev, 1); + auto input = builder.ConstantR1(10, 1.f); + auto new_weights = builder.Add(weights, input); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); + TF_ASSIGN_OR_ASSERT_OK(body, builder.Build()); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); + auto while1 = builder.While(condition, body, init); + auto while2 = builder.While(condition2, body, init); + + auto while_result1 = builder.GetTupleElement(while1, 1); + auto while_result2 = builder.GetTupleElement(while2, 1); + VLOG(2) << "while_result2 = " + << ShapeUtil::HumanString( + *builder.GetShape(while_result2).ConsumeValueOrDie()); + auto result = builder.Add(while_result1, while_result2); + VLOG(2) << "result = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + const float sum = c1 + c2; + std::vector expected(10, sum); + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + // WhileTest that uses DynamicUpdateSlice instruction in body computation. // Loop state tuple element 1 has as its single user operand(0) of // DynamicUpdateSlice, which will trigger in-place dynamic slice update on GPU. @@ -315,7 +540,8 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // result += (1, U[0, 100], U[0, 100], U[0, 100], U[0, 100], U[0, 100]); // } // -// This test misuses a vector to represent a pair: +// This test misuses a vector WhileTest.WhileLoopsWithSharedBodyto represent a +// pair: // ((iteration, (random vector))). // // Note: this test currently only tests generating random values within a loop. diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc index 94d0f2646b15930f78c44fbb3d2b49fd6033a545..a167d80f73b0273739e22d94be8d90ab00839dc9 100644 --- a/tensorflow/compiler/xla/text_literal_reader_test.cc +++ b/tensorflow/compiler/xla/text_literal_reader_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc index 9dce4d13bb0e21d399795c5310e30b7ab64ea4ea..177ae4ea036af660b7a2be1d4082b30ca8fb9fac 100644 --- a/tensorflow/compiler/xla/text_literal_writer_test.cc +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -19,12 +19,12 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/test.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 46eab7f02bb12ca39e5713e7b0f96bfa178e9102..535e5b605b4f68671c9b6a8af4a12732f88e744e 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -153,6 +153,7 @@ cc_binary( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:computation_tracker", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], @@ -176,6 +177,24 @@ cc_binary( ], ) +cc_binary( + name = "dumped_computation_to_tf_graphdef", + srcs = ["dumped_computation_to_tf_graphdef.cc"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", + "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/core:lib", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 8b96e13489774539b50022808975db56c5ddc6f7..dc5a86f34e5b975fd8ba565d54e5c2c0b70bf53e 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,7 +35,7 @@ limitations under the License. namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { LocalClient* client = ClientLibrary::LocalClientOrDie(); LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); @@ -50,23 +51,38 @@ void RealMain(tensorflow::gtl::ArraySlice args) { } Computation computation = computation_status.ConsumeValueOrDie(); - std::unique_ptr program_shape = - client->GetComputationShape(computation).ConsumeValueOrDie(); + if (compile) { + std::unique_ptr program_shape = + client->GetComputationShape(computation).ConsumeValueOrDie(); - std::vector layouts; - for (int i = 0; i < program_shape->parameters_size(); ++i) { - layouts.push_back(&program_shape->parameters(i)); - } - StatusOr> executable = - local_service->CompileExecutable( - computation.handle(), layouts, &program_shape->result(), - /*device_ordinal=*/0, /*has_hybrid_result=*/true); + std::vector layouts; + for (int i = 0; i < program_shape->parameters_size(); ++i) { + layouts.push_back(&program_shape->parameters(i)); + } + StatusOr> executable = + local_service->CompileExecutable( + computation.handle(), layouts, &program_shape->result(), + /*device_ordinal=*/0, /*has_hybrid_result=*/true); + + const HloModule& module = executable.ValueOrDie()->module(); - const HloModule& module = executable.ValueOrDie()->module(); + fprintf(stdout, "HLO compiled for %s backend:\n%s\n", + local_service->backend().platform()->Name().c_str(), + module.ToString().c_str()); + } else { + const ComputationTracker& tracker = local_service->computation_tracker(); + UserComputation* user_computation = + tracker.Resolve(computation.handle()).ConsumeValueOrDie(); + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + std::unique_ptr module = + tracker + .BuildHloModule(versioned_handle, + /*config=*/nullptr) + .ConsumeValueOrDie(); - fprintf(stdout, "HLO for %s backend:\n%s\n", - local_service->backend().platform()->Name().c_str(), - module.ToString().c_str()); + fprintf(stdout, "%s\n", module->ToString().c_str()); + } } } @@ -74,10 +90,21 @@ void RealMain(tensorflow::gtl::ArraySlice args) { } // namespace xla int main(int argc, char** argv) { - tensorflow::port::InitMain(argv[0], &argc, &argv); + bool compile = false; + std::vector flag_list = { + {"compile", &compile, + "If true, compile the computation using the default client before " + "dumping the HLO. Otherwise dump the raw (uncompiled) HLO."}, + }; + const xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + QCHECK(parsed_flags_ok) << "\n" << usage; + + tensorflow::port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage; tensorflow::gtl::ArraySlice args(argv, argc); args.pop_front(); // Pop off the binary name, argv[0] - xla::tools::RealMain(args); + xla::tools::RealMain(args, compile); return 0; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc new file mode 100644 index 0000000000000000000000000000000000000000..850267d3195785a96bf8d2c80fe64fdb8aae0a91 --- /dev/null +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Usage: dumped_computation_to_tf_graph some_binary_snapshot_proto* +// +// Dumps a tensorflow GraphDef in text format for a snapshot computation. The +// dumped graph is an HLO computation with HLO instructions as nodes and can be +// visualized on Tensorboard. Upload the dumped files on Tensorboard. +// +// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// ServiceInterface::SnapshotComputation to disk. + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" +#include "tensorflow/compiler/xla/service/service.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" + +using tensorflow::Env; + +namespace xla { +namespace tools { + +void RealMain(tensorflow::gtl::ArraySlice args) { + Client* client = ClientLibrary::LocalClientOrDie(); + for (char* arg : args) { + SessionModule module; + TF_CHECK_OK( + tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); + Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + ComputationStats stats = + client->GetComputationStats(computation).ConsumeValueOrDie(); + fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); + } +} + +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + + xla::legacy_flags::ServiceFlags* flags = xla::legacy_flags::GetServiceFlags(); + flags->xla_generate_hlo_graph = ".*"; + + xla::legacy_flags::HloGraphDumperFlags* dumper_flags = + xla::legacy_flags::GetHloGraphDumperFlags(); + dumper_flags->xla_hlo_dump_as_graphdef = true; + + tensorflow::gtl::ArraySlice args(argv, argc); + args.pop_front(); // Pop off the binary name, argv[0] + xla::tools::RealMain(args); + return 0; +} diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index 8258031a2c5119d085a483a0826f7284897dcee3..ea8b4b7b989b72034f33920a7d8c1a75e15a7dd1 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -16,8 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TYPES_H_ #define TENSORFLOW_COMPILER_XLA_TYPES_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" +#include + namespace xla { using ::tensorflow::string; @@ -32,6 +35,8 @@ using ::tensorflow::uint16; using ::tensorflow::uint32; using ::tensorflow::uint64; +using ::Eigen::half; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TYPES_H_ diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index a711b5035d842cd26945b2dac1159392813d56ab..d467178cb528a93b2c1030fc72d054cc0edf95b6 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -33,7 +33,7 @@ namespace { // Adds a backtrace to the provided status iff the xla_status_add_backtrace flag // is set. This is useful for quickly tracing status errors observed coming out // of the service. -Status MaybeAddBacktrace(Status prior) { +Status MaybeAddBacktrace(const Status& prior) { DCHECK(!prior.ok()); if (legacy_flags::GetUtilFlags()->xla_status_add_backtrace) { return Status{prior.code(), @@ -153,16 +153,26 @@ string Reindent(tensorflow::StringPiece original, }); } +bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank) { + if (rank != permutation.size()) { + return false; + } + std::vector output(permutation.size(), -1); + for (auto index : permutation) { + CHECK_GE(index, 0); + CHECK_LT(index, rank); + output[index] = 0; + } + return std::find(output.begin(), output.end(), -1) == output.end(); +} + std::vector InversePermutation( tensorflow::gtl::ArraySlice input_permutation) { + DCHECK(IsPermutation(input_permutation, input_permutation.size())); std::vector output_permutation(input_permutation.size(), -1); for (size_t i = 0; i < input_permutation.size(); ++i) { output_permutation[input_permutation[i]] = i; } - DCHECK_EQ( - 0, std::count(output_permutation.begin(), output_permutation.end(), -1)); - DCHECK(std::is_permutation(input_permutation.begin(), input_permutation.end(), - output_permutation.begin())); return output_permutation; } @@ -196,6 +206,15 @@ PaddingConfig MakeNoPaddingConfig(int64 rank) { return padding_config; } +bool HasInteriorPadding(const PaddingConfig& config) { + for (const auto& dim : config.dimensions()) { + if (dim.interior_padding() != 0) { + return true; + } + } + return false; +} + string HumanReadableNumFlops(double flops, double nanoseconds) { if (nanoseconds == 0) { return "NaN FLOP/s"; diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 55a66a7499571b4979ff375a8199cb329a799ef7..42d5c1d15501fb912551a044414e6fa0c83283b8 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" @@ -38,6 +39,13 @@ limitations under the License. namespace xla { +// Ranks greater than 8 are very rare, so use InlinedVector to store +// the bounds and indices. And for the rare cases of ranks greater than 8, +// the InlinedVector will just behave like an std::vector<> and allocate the +// memory to store its values. +static constexpr int kInlineRank = 8; +using DimensionVector = tensorflow::gtl::InlinedVector; + // RAII timer that logs with a given label the wall clock time duration in human // readable form. This differs from base's ElapsedTimer primarily in that it // spits out the human-readable duration form. @@ -120,6 +128,14 @@ bool ContainersEqual(const Container1T& c1, const Container2T& c2) { std::equal(std::begin(c1), std::end(c1), std::begin(c2))); } +template +bool ContainersEqual(const Container1T& c1, + std::initializer_list il) { + tensorflow::gtl::ArraySlice c2{il}; + return ContainersEqual(c1, c2); +} + // Compares two containers for equality. Returns true iff the two containers // have the same size and all their elements compare equal using the predicate // p. Like std::equal, but forces size equality. @@ -130,6 +146,18 @@ bool ContainersEqual(const Container1T& c1, const Container2T& c2, std::equal(std::begin(c1), std::end(c1), std::begin(c2), p)); } +// Performs a copy of count values from src to dest, using different strides for +// source and destination. The source starting index is src_base, while the +// destination one is dest_base. +template +void StridedCopy(tensorflow::gtl::MutableArraySlice dest, int64 dest_base, + int64 dest_stride, tensorflow::gtl::ArraySlice src, + int64 src_base, int64 src_stride, int64 count) { + for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) { + dest[dest_base] = static_cast(src[src_base]); + } +} + // Adds some context information to the error message in a // Status. This is useful as Statuses are // propagated upwards. @@ -156,6 +184,9 @@ Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); string Reindent(tensorflow::StringPiece original, tensorflow::StringPiece indentation); +// Checks whether permutation is a permutation of the [0, rank) integer range. +bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank); + // Applies `permutation` on `input` and returns the permuted array. // For each i, output[permutation[i]] = input[i]. // @@ -166,12 +197,11 @@ template